Inference
🧩 Syntax:
Sample code bạn ấy đưa:
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
from tqdm import tqdm
from peft import PeftModelForSeq2SeqLM
import os
import torch
from datasets import load_from_disk as load
# 4-bit NF4 quantization config: double quantization, bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Start from a clean output file so appended batches don't accumulate across runs.
if os.path.exists('./luanhoilacvien_vi.txt'):
    os.remove('./luanhoilacvien_vi.txt')

output_dir = 'novel_zh2vi'
model_to_load = 'jetaudio/novel_zh2vi'

# Quantized base model + PEFT adapter; the tokenizer is loaded from the
# adapter repo (presumably identical to the base tokenizer — TODO confirm).
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    'google/madlad400-3bmt',
    quantization_config=bnb_config,
    device_map='auto',
)
model = PeftModelForSeq2SeqLM.from_pretrained(base_model, model_to_load)
tokenizer = AutoTokenizer.from_pretrained(model_to_load)

# Fix: read the source text via a context manager so the file handle is
# closed deterministically (the original open(...).read() leaked it).
with open('luanhoilacvien_cn.txt', 'r', encoding='utf8') as fin:
    text = fin.read()
def trans(texts, temp=1, top_p=0.8):
    """Translate a batch of tagged source strings with the seq2seq model.

    Args:
        texts: list of source strings, each already prefixed with the
            target-language tag (e.g. ``'<2vi>'``).
        temp: sampling temperature forwarded to ``model.generate``.
        top_p: nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        One string: the decoded translations joined with newlines,
        one line per input text.
    """
    # Fix: add truncation=True — the original padded to max_length but never
    # truncated, so inputs longer than 256 tokens exceeded the window.
    encodings = tokenizer(
        texts,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=256,
    ).to('cuda')
    # Fix: pass the attention mask explicitly so padding positions are
    # masked out during generation instead of being attended to.
    gens = model.generate(
        input_ids=encodings.input_ids,
        attention_mask=encodings.attention_mask,
        do_sample=True,
        max_length=256,
        temperature=temp,
        top_p=top_p,
    )
    return '\n'.join(
        tokenizer.decode(gen, skip_special_tokens=True) for gen in gens
    )
# Clean the raw text: strip spaces, remove the site watermark, collapse
# double blank lines, then split into individual segments.
text = (
    text.replace(' ', '')
        .replace('UU看书 www.uukanshu.net', '')
        .replace('\n\n', '\n')
        .split('\n')
)
batch_size = 32
# Chunk the segments into batches of batch_size (last batch may be shorter).
texts = [
    text[i * batch_size:(i + 1) * batch_size]
    for i in range((len(text) + batch_size - 1) // batch_size)
]
for sens in tqdm(texts):
    # Fix: splitting on '\n' leaves empty-string segments; skip them so the
    # model is not asked to "translate" blank lines into garbage.
    batch = ['<2vi>' + sen for sen in sens if sen]
    if not batch:
        continue
    t = trans(batch, temp=0.1, top_p=0.3)
    # Re-open in append mode per batch so progress survives a crash mid-run.
    with open('./luanhoilacvien_vi.txt', 'a', encoding='utf8') as fout:
        fout.write(t + '\n')
```