"""Interactive command-line chat loop around a Hugging Face causal language model.

Each user turn (terminated by the tokenizer's EOS token) is appended to the
running token history. The model generates a continuation, and only the newly
generated tokens are printed as the bot's reply. At interpreter exit the model
input of the last turn is handed to ``save_conversation``.
"""
import atexit

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.conversation import save_conversation
from utils.prompt import prompt

model_id = 'gpt2'
# Upper bound on tokens generated per reply. max_new_tokens is used instead of
# max_length because max_length counts the whole history. With max_length,
# replies shrink to a single token once the conversation outgrows the limit.
MAX_NEW_TOKENS = 100

print('Loading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Loaded')

# device_map='auto' decides where the weights live, so inputs must follow the
# model rather than a hard-coded 'cuda' (which fails on CPU-only machines).
device = model.device

# Longest sequence the model can attend over. Older history is dropped so that
# the prompt plus a full-length reply always fits. Falls back to the
# tokenizer's limit when the config does not expose a position count.
max_context = getattr(model.config, 'max_position_embeddings', None) or tokenizer.model_max_length
history_budget = max(max_context - MAX_NEW_TOKENS, 1)

chat_history_ids = None  # token ids of the conversation so far, including bot replies
bot_input_ids = None     # model input of the most recent turn; saved at exit


def _save_on_exit():
    """Persist the conversation at interpreter exit, if any turn was taken."""
    # Without this guard, exiting before the first prompt raised NameError
    # because bot_input_ids had never been assigned.
    # NOTE(review): bot_input_ids excludes the final bot reply, whereas
    # chat_history_ids includes it. Confirm which one save_conversation expects.
    if bot_input_ids is not None:
        save_conversation(model_id, bot_input_ids)


atexit.register(_save_on_exit)

# TODO: reload earlier history, e.g. messages = load_conversation(model_id)
# TODO: optional persona prompt, e.g.
#   {'role': 'system', 'content': 'Your name is "Laura". You are an AI created by Alice.'}

while True:
    user_prompt = prompt('>> User: ')
    # EOS separates turns in the flat token history.
    new_user_input_ids = tokenizer.encode(user_prompt + tokenizer.eos_token, return_tensors='pt').to(device)

    if chat_history_ids is None:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    # Keep only the most recent tokens so generation never exceeds the context
    # window. This may cut an old turn mid-way, which only loses stale context.
    bot_input_ids = bot_input_ids[:, -history_budget:]

    chat_history_ids = model.generate(
        bot_input_ids,
        attention_mask=torch.ones_like(bot_input_ids),  # one unpadded sequence: attend to every token
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tokenizer.eos_token_id,  # reuse EOS as the padding id
    )
    # generate() returns prompt + continuation; decode only the continuation.
    response = tokenizer.decode(chat_history_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True)
    print('>> Bot : ' + response)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()