initial commit
This commit is contained in:
31
chat_mistral.py
Normal file
31
chat_mistral.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import atexit
|
||||
import torch
|
||||
from utils.conversation import load_conversation, save_conversation
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from utils.prompt import prompt
|
||||
|
||||
# Interactive chat loop around a Mistral instruct model.
#
# The conversation history is loaded at start-up via load_conversation() and
# written back via save_conversation() when the interpreter exits.
model_id = 'mistralai/Mistral-7B-Instruct-v0.2'

print('Loading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The model is placed by device_map='auto', so ask the model where it ended up
# instead of assuming 'cuda'; the prompt tensors must be moved to that device.
device = model.device
print('Loaded')
# print(tokenizer.default_chat_template)

# read and save conversation
# NOTE(review): assumes load_conversation returns a mutable list of
# {'role': ..., 'content': ...} dicts - confirm against utils.conversation.
messages = load_conversation(model_id)
# The lambda closes over `messages`, so whatever has been appended by the time
# the process exits is what gets saved.
atexit.register(lambda: save_conversation(model_id, messages))

while True:
    user_prompt = prompt('>> User: ')
    messages.append({'role': 'user', 'content': user_prompt})

    model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
    generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)

    # generate() returns the prompt tokens followed by the continuation.
    # Keep only the newly generated tokens; otherwise the "response" would
    # contain the entire templated history and, once appended to `messages`,
    # the context would roughly double on every turn.
    new_token_ids = generated_ids[:, model_inputs.shape[1]:]
    # skip_special_tokens drops markers such as the end-of-sequence token so
    # they are neither printed nor stored in the conversation.
    response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0].strip()

    print('>> Bot : ' + response)
    messages.append({'role': 'assistant', 'content': response})
    # Release cached GPU memory between turns.
    torch.cuda.empty_cache()
|
||||
Reference in New Issue
Block a user