initial commit
chat_gpt2.py (new file, 34 lines added)
@@ -0,0 +1,34 @@
import atexit
import torch
from utils.conversation import save_conversation
from utils.prompt import prompt
from transformers import AutoModelForCausalLM, AutoTokenizer

device = 'cuda'  # the device to load the model onto
model_id = 'gpt2'

print('Loading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Loaded')
# print(tokenizer.default_chat_template)

# read and save conversation
chat_history_ids = None
# messages = load_conversation(model_id)
# Save the conversation on exit; chat_history_ids is used (rather than
# bot_input_ids) because it is always defined and, after a turn, also
# contains the bot's reply.
atexit.register(lambda: save_conversation(model_id, chat_history_ids))

# messages.append({'role': 'system', 'content': 'Your name is "Laura". You are an AI created by Alice.'})
while True:
    user_prompt = prompt('>> User: ')
    # Encode the new user turn, terminated by the EOS token.
    new_user_input_ids = tokenizer.encode(user_prompt + tokenizer.eos_token, return_tensors='pt').to(device)

    # Append the new turn to the running conversation, or start a new one.
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) \
        if chat_history_ids is not None \
        else new_user_input_ids
    # Note: max_length caps prompt + reply together, so the chat is cut off at 100 tokens total.
    chat_history_ids = model.generate(bot_input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id).to(device)
    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    print('>> Bot : ' + response)
    torch.cuda.empty_cache()
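Note: the script imports prompt and save_conversation from a local utils package that is not included in this commit. The following is a minimal sketch of what those helpers might look like, assuming prompt is just a labelled input() wrapper and save_conversation dumps the token ids to a JSON file; both are assumptions, not code from this repository.

# utils/prompt.py (hypothetical sketch)
def prompt(label):
    # Thin wrapper around input() so the call site stays a one-liner.
    return input(label)


# utils/conversation.py (hypothetical sketch)
import json

def save_conversation(model_id, history_ids):
    # Persist the accumulated token ids; skip if no turn was ever generated.
    if history_ids is None:
        return
    with open('conversation_' + model_id + '.json', 'w') as f:
        json.dump(history_ids.tolist(), f)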