from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer

# mkdir here is a helper from the project's local conversation module.
from conversation import mkdir

# Pick the chat model to download; switch to the 1.8B variant by swapping the comment.
model_id = 'Qwen/Qwen1.5-0.5B-Chat'
# model_id = 'Qwen/Qwen1.5-1.8B-Chat'

print('Downloading ' + model_id)
# Fetch the weights and tokenizer from the Hugging Face Hub, letting transformers
# choose the dtype and device placement automatically.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Downloaded')

# Wrap the model in a Trainer so that a single save_model() call writes both the
# model weights and the tokenizer files.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
)
mkdir('models')
trainer.save_model('./models/' + model_id.replace('/', '_'))
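
# Optional sanity check (a sketch added here, not part of the download step itself):
# reload the checkpoint that was just written, purely from the local path, to confirm
# the saved copy is self-contained. The names saved_path / reloaded_* are illustrative.
saved_path = './models/' + model_id.replace('/', '_')
reloaded_model = AutoModelForCausalLM.from_pretrained(saved_path, torch_dtype='auto', device_map='auto')
reloaded_tokenizer = AutoTokenizer.from_pretrained(saved_path)
print('Reloaded checkpoint from ' + saved_path)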