llm/chat_mistral.py

import atexit
import torch
from utils.conversation import load_conversation, save_conversation
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompt import prompt
device = 'cuda'  # device for the tokenized inputs (device_map='auto' places the model itself)
model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
print('Loading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Loaded')
# print(tokenizer.default_chat_template)
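# With this model's chat template, the tokenized history is rendered roughly as
# <s>[INST] user [/INST] assistant</s>[INST] user [/INST] (illustrative, not exact)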
# restore any previous conversation for this model and persist it again on exit
messages = load_conversation(model_id)
atexit.register(lambda: save_conversation(model_id, messages))
while True:
    user_prompt = prompt('>> User: ')
    messages.append({'role': 'user', 'content': user_prompt})
    # tokenize the full conversation with the model's chat template
    model_inputs = tokenizer.apply_chat_template(messages, return_tensors='pt').to(device)
    generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
    # decode only the newly generated tokens; otherwise the echoed prompt
    # (and special tokens) would be appended back into the history
    response = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[1]:], skip_special_tokens=True)[0]
    print('>> Bot : ' + response)
    messages.append({'role': 'assistant', 'content': response})
    torch.cuda.empty_cache()
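
For reference, a minimal sketch of what the utils.conversation helpers used
above might look like; this is an assumption for illustration (the real module
is not shown here), persisting the message list as JSON per model:

import json
from pathlib import Path

def _conversation_path(model_id):
    # hypothetical location: one JSON history file per model id
    return Path('conversations') / (model_id.replace('/', '__') + '.json')

def load_conversation(model_id):
    # return the saved message list, or an empty history on first run
    path = _conversation_path(model_id)
    return json.loads(path.read_text()) if path.exists() else []

def save_conversation(model_id, messages):
    # called via atexit in chat_mistral.py to persist the history
    path = _conversation_path(model_id)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(messages, indent=2))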