import atexit

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.conversation import load_conversation, save_conversation
from utils.prompt import prompt

device = 'cuda'  # device the tokenized inputs are moved to (the model itself is placed via device_map='auto')
model_id = 'mistralai/Mistral-7B-Instruct-v0.2'

print('Loading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Loaded')

# print(tokenizer.default_chat_template)

# restore any previously saved conversation and register a hook to save it on exit
messages = load_conversation(model_id)
atexit.register(lambda: save_conversation(model_id, messages))

while True:
    user_prompt = prompt('>> User: ')
    messages.append({'role': 'user', 'content': user_prompt})

    # tokenize the whole conversation with the model's chat template
    model_inputs = tokenizer.apply_chat_template(messages, return_tensors='pt').to(device)
    generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
    # decode only the newly generated tokens, not the echoed prompt,
    # so the assistant message stored in the history is just the reply
    response = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[1]:], skip_special_tokens=True)[0]

    print('>> Bot : ' + response)
    messages.append({'role': 'assistant', 'content': response})

    torch.cuda.empty_cache()
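
# ---------------------------------------------------------------------------
# Note: utils/conversation.py and utils/prompt.py are not included in this
# listing. The sketch below is only a guess at minimal implementations that
# match the calls above; the JSON file naming and storage format are
# assumptions, not the original helpers.
# ---------------------------------------------------------------------------

# utils/conversation.py (hypothetical sketch)
# import json
# import os
#
#
# def _history_path(model_id):
#     # one history file per model id, slashes replaced so it is a valid filename
#     return 'history_' + model_id.replace('/', '_') + '.json'
#
#
# def load_conversation(model_id):
#     # return the saved message list, or an empty list on first run
#     path = _history_path(model_id)
#     if os.path.exists(path):
#         with open(path) as f:
#             return json.load(f)
#     return []
#
#
# def save_conversation(model_id, messages):
#     # persist the message list so the next session can resume it
#     with open(_history_path(model_id), 'w') as f:
#         json.dump(messages, f, indent=2)
#
#
# utils/prompt.py (hypothetical sketch)
# def prompt(label):
#     # thin wrapper around input(); the real helper may add readline/history support
#     return input(label)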