import atexit

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.conversation import load_conversation, save_conversation
from utils.prompt import prompt

device = 'cuda'  # the device to move the tokenized inputs onto
model_id = 'mistralai/Mistral-7B-Instruct-v0.2'

print('Loading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Loaded')

# print(tokenizer.default_chat_template)

# Restore any saved conversation for this model and register a hook that
# persists the history when the process exits.
messages = load_conversation(model_id)
atexit.register(lambda: save_conversation(model_id, messages))

while True:
    user_prompt = prompt('>> User: ')
    messages.append({'role': 'user', 'content': user_prompt})

    # Render the full chat history with the model's chat template and tokenize it.
    model_inputs = tokenizer.apply_chat_template(messages, return_tensors='pt').to(device)
    input_length = model_inputs.shape[1]

    generated_ids = model.generate(
        model_inputs,
        max_new_tokens=100,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # Mistral's tokenizer defines no pad token
    )

    # Decode only the newly generated tokens. Decoding the whole sequence would
    # echo the prompt (and special tokens) back into the assistant message and
    # corrupt the conversation history on the next turn.
    response = tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0]
    print('>> Bot : ' + response)
    messages.append({'role': 'assistant', 'content': response})

    torch.cuda.empty_cache()
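
# ---------------------------------------------------------------------------
# For reference: a minimal sketch of what the `utils.conversation` and
# `utils.prompt` helpers imported above might look like. Their actual
# implementations are not part of this file, so the JSON-on-disk layout and
# the plain input() wrapper below are assumptions, not the repo's real code.
# ---------------------------------------------------------------------------
#
# import json
# from pathlib import Path
#
# def _history_path(model_id):
#     # One history file per model, with '/' made filesystem-safe,
#     # e.g. conversations/mistralai__Mistral-7B-Instruct-v0.2.json
#     return Path('conversations') / (model_id.replace('/', '__') + '.json')
#
# def load_conversation(model_id):
#     path = _history_path(model_id)
#     return json.loads(path.read_text()) if path.exists() else []
#
# def save_conversation(model_id, messages):
#     path = _history_path(model_id)
#     path.parent.mkdir(parents=True, exist_ok=True)
#     path.write_text(json.dumps(messages, indent=2))
#
# def prompt(label):
#     # Thin wrapper around input(); the real helper may add readline
#     # history or multi-line editing.
#     return input(label)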