import datetime
import gc
import json
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class AiBase:
    """Base wrapper around a Hugging Face causal LM.

    Provides lazy model loading, explicit unloading (freeing GPU memory),
    a generation stub for subclasses, and JSON conversation logging.
    """

    model_id_or_path: str = None  # HF hub id or local path of the model
    model = None                  # AutoModelForCausalLM, loaded lazily
    tokenizer = None              # AutoTokenizer, loaded lazily

    def __init__(self, model_id_or_path: str = None):
        self.model_id_or_path = model_id_or_path

    def ensure_loaded(self):
        """Load the model and tokenizer once; subsequent calls are no-ops."""
        if self.model is not None:
            return
        print('Loading ' + self.model_id_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id_or_path, torch_dtype='auto', device_map='auto')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
        print('Loaded')

    def ensure_unload(self):
        """Release the model/tokenizer and free cached GPU memory.

        BUG FIX: the original tested ``if self.model is None`` (inverted
        condition) and only ran ``del self``, which merely unbinds the
        local name and frees nothing.
        """
        if self.model is None:
            return
        self.model = None
        self.tokenizer = None
        gc.collect()
        torch.cuda.empty_cache()

    def generate(self, messages):
        """Generation stub; subclasses override. Returns an empty list."""
        return []

    def record_conversation(self, messages, response):
        """Append *response* to *messages* and persist the transcript as
        JSON under conversations/<model-id>/<UTC-timestamp>.json."""
        messages = messages + [response]
        this_dir = os.path.dirname(os.path.abspath(__file__))
        conversations_folder = this_dir + '/../../../conversations'
        folder = conversations_folder + '/' + self.model_id_or_path.replace('/', '_')
        self._mkdir(conversations_folder)
        self._mkdir(folder)
        # Timezone-aware clock: datetime.utcnow() is deprecated since 3.12
        # and returns a naive datetime; the formatted string is identical.
        timestamp = datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%d%H%M%S')
        # File is JSON, not pickle — the original name was misleading.
        json_filename = folder + '/' + timestamp + '.json'
        with open(json_filename, 'w') as file:
            json.dump(messages, file)

    def _mkdir(self, path):
        # exist_ok avoids the isdir/mkdir check-then-act race of the
        # original and also creates intermediate directories if needed.
        os.makedirs(path, exist_ok=True)

    def __del__(self):
        # Original raised and printed a throwaway Exception (debug
        # leftover) and used `del self.model`, which raises AttributeError
        # when __init__ never set an instance attribute. Rebinding to None
        # drops the references just as well and always succeeds.
        self.model = None
        self.tokenizer = None
        gc.collect()
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best-effort: torch internals may already be torn down if the
            # interpreter is shutting down when __del__ fires.
            pass