# llm/chat_gpt2.py
# 2024-04-17 18:58:50 +02:00
import atexit
import torch
from utils.conversation import save_conversation
from utils.prompt import prompt
from transformers import AutoModelForCausalLM, AutoTokenizer
# Prefer GPU but fall back to CPU so the script also runs on machines
# without CUDA (the original hard-coded 'cuda' and crashed otherwise).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_id = 'gpt2'
print('Loading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Loaded')
# print(tokenizer.default_chat_template)

# Running token history of the whole conversation (user + bot turns).
# Stays None until the first exchange has completed.
chat_history_ids = None
# messages = load_conversation(model_id)


def _save_on_exit():
    """Persist the conversation when the interpreter exits.

    The original registered ``lambda: save_conversation(model_id, bot_input_ids)``,
    which raised NameError if the user quit before typing anything (the name
    is only bound inside the chat loop) and saved only the last prompt's
    input rather than the full history including bot replies.
    """
    if chat_history_ids is not None:
        save_conversation(model_id, chat_history_ids)


atexit.register(_save_on_exit)
# messages.append({'role': 'system', 'content': 'Your name is "Laura". You are an AI created by Alice.'})
# Interactive REPL: read a user turn, append it to the token history,
# generate a reply, and print it. Runs until the process is interrupted
# (the atexit hook persists the conversation).
while True:
    user_prompt = prompt('>> User: ')
    # Encode the new user turn, terminated with EOS per the GPT-2/DialoGPT
    # chat convention.
    new_user_input_ids = tokenizer.encode(user_prompt + tokenizer.eos_token, return_tensors='pt').to(device)
    # First turn: no history yet; afterwards, concatenate onto the history.
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) \
        if chat_history_ids is not None \
        else new_user_input_ids
    # Use max_new_tokens (reply budget) rather than max_length (total
    # sequence cap): with the original max_length=100 the reply budget
    # shrank to zero once the accumulated history reached ~100 tokens,
    # producing empty/truncated answers. generate() already returns
    # tensors on the model's device, so the redundant .to(device) on the
    # output was dropped.
    chat_history_ids = model.generate(bot_input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
    # The reply is everything generated past the input prompt.
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    print('>> Bot : ' + response)
    # Free cached GPU memory between turns (no-op on CPU).
    torch.cuda.empty_cache()