import datetime
import gc
import json
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class AiBase:
    model_id_or_path: str = None
    model = None
    tokenizer = None

    def __init__(self, model_id_or_path: str = None):
        self.model_id_or_path = model_id_or_path

    def ensure_loaded(self):
        # Lazily load the model and tokenizer on first use.
        if self.model is not None:
            return

        print('Loading ' + self.model_id_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id_or_path, torch_dtype='auto', device_map='auto'
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
        print('Loaded')

    def ensure_unload(self):
        # Release the model and tokenizer and free cached GPU memory.
        if self.model is None:
            return

        self.model = None
        self.tokenizer = None
        gc.collect()
        torch.cuda.empty_cache()

    def generate(self, messages):
        # Subclasses override this to run inference and return the
        # generated messages; the base class produces none.
        return []

    def record_conversation(self, messages, response):
        # Append the response and persist the full conversation as JSON,
        # grouped per model under the conversations folder.
        messages = messages + [response]

        this_dir = os.path.dirname(os.path.abspath(__file__))
        conversations_folder = os.path.join(this_dir, '..', '..', '..', 'conversations')
        folder = os.path.join(conversations_folder, self.model_id_or_path.replace('/', '_'))
        self._mkdir(conversations_folder)
        self._mkdir(folder)

        timestamp = datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%d%H%M%S')
        json_filename = os.path.join(folder, timestamp + '.json')
        with open(json_filename, 'w') as file:
            json.dump(messages, file)

    def _mkdir(self, path):
        if not os.path.isdir(path):
            os.mkdir(path)

    def __del__(self):
        # Drop references and release cached GPU memory when the wrapper
        # is garbage-collected.
        self.model = None
        self.tokenizer = None
        gc.collect()
        torch.cuda.empty_cache()
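

# Illustrative sketch (not part of the original module): a minimal example of
# how a concrete subclass might implement generate() with the standard
# transformers chat-template API. The class name ExampleChatAi, the model id,
# and the generation settings are hypothetical placeholders; adapt them to the
# models this repo actually wraps.
class ExampleChatAi(AiBase):

    def generate(self, messages):
        self.ensure_loaded()
        # Build prompt token ids from the chat history and move them to the
        # model's device.
        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors='pt'
        ).to(self.model.device)
        output_ids = self.model.generate(input_ids, max_new_tokens=256)
        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
        return [{'role': 'assistant', 'content': text}]


if __name__ == '__main__':
    # Example flow: load lazily, generate one reply, record it, then unload.
    # 'Qwen/Qwen2.5-0.5B-Instruct' is just an example chat model id.
    ai = ExampleChatAi('Qwen/Qwen2.5-0.5B-Instruct')
    chat = [{'role': 'user', 'content': 'Hello!'}]
    reply = ai.generate(chat)[0]
    ai.record_conversation(chat, reply)
    ai.ensure_unload()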