Files
llm/ai/AiBase.py
2024-05-27 18:59:58 +02:00

64 lines
1.8 KiB
Python

import datetime
import json
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class AiBase:
    """Base wrapper around a Hugging Face causal language model.

    The model and tokenizer are loaded lazily by :meth:`ensure_loaded` and
    released by :meth:`ensure_unload`. :meth:`generate` is a placeholder
    that returns an empty list in this base class.
    """

    # Class-level defaults so the attributes exist even before __init__ /
    # ensure_loaded have run.
    model_id_or_path: str = None
    model = None
    tokenizer = None

    def __init__(self, model_id_or_path: str = None):
        """Remember which model to load; nothing is loaded yet."""
        self.model_id_or_path = model_id_or_path

    def ensure_loaded(self):
        """Load the model and tokenizer unless a model is already present."""
        if self.model is not None:
            return
        print('Loading ' + self.model_id_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id_or_path, torch_dtype='auto', device_map='auto'
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
        print('Loaded')

    def ensure_unload(self):
        """Drop the model and tokenizer references and free cached memory.

        Does nothing when neither is set, so it is safe to call repeatedly.
        The instance stays usable: a later :meth:`ensure_loaded` reloads.
        """
        if self.model is None and self.tokenizer is None:
            return
        # Rebind to None instead of `del`: the names also exist as class
        # attributes, and `del` on a never-assigned instance attribute
        # raises AttributeError.
        self.model = None
        self.tokenizer = None
        import gc  # local import: gc is not among the file-level imports

        gc.collect()
        torch.cuda.empty_cache()

    def generate(self, messages):
        """Return the generated output for ``messages``; empty list here."""
        return []

    def record_conversation(self, messages, response, conversations_folder: str = None):
        """Write ``messages`` plus ``response`` to a timestamped JSON file.

        The file is stored in a per-model sub-folder (``/`` in the model id
        replaced by ``_``) and named after the current UTC time with
        one-second resolution (``YYYYmmddHHMMSS.json``).

        Args:
            messages: list of JSON-serialisable conversation entries.
            response: the entry to append after ``messages``.
            conversations_folder: root folder for recordings. Defaults to
                ``conversations`` three levels above this file's directory.
        """
        transcript = messages + [response]
        if conversations_folder is None:
            this_dir = os.path.dirname(os.path.abspath(__file__))
            conversations_folder = os.path.join(
                this_dir, '..', '..', '..', 'conversations'
            )
        folder = os.path.join(
            conversations_folder, self.model_id_or_path.replace('/', '_')
        )
        self._mkdir(folder)
        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the formatted string is identical.
        timestamp = datetime.datetime.now(datetime.timezone.utc).strftime(
            '%Y%m%d%H%M%S'
        )
        json_filename = os.path.join(folder, timestamp + '.json')
        with open(json_filename, 'w', encoding='utf-8') as file:
            json.dump(transcript, file)

    def _mkdir(self, path):
        """Create ``path`` (and missing parents) if it does not exist yet."""
        os.makedirs(path, exist_ok=True)

    def __del__(self):
        """Release the model when the instance is garbage-collected."""
        try:
            self.ensure_unload()
        except Exception:
            # Best effort only: during interpreter shutdown the modules
            # used for cleanup may already be torn down.
            pass