Replace Python Django backend with Node.js backend

wea_ondara committed 2024-05-27 18:59:58 +02:00
parent ebd0748894
commit 8b60d023e8
123 changed files with 15193 additions and 88 deletions

ai/AiBase.py (new file, 63 lines)

@@ -0,0 +1,63 @@
import datetime
import gc
import json
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class AiBase:
    model_id_or_path: str = None
    model = None
    tokenizer = None

    def __init__(self, model_id_or_path: str = None):
        self.model_id_or_path = model_id_or_path

    def ensure_loaded(self):
        if self.model is not None:
            return
        print('Loading ' + self.model_id_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id_or_path, torch_dtype='auto', device_map='auto')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path)
        print('Loaded')

    def ensure_unloaded(self):
        # Release the model and tokenizer and reclaim GPU memory.
        if self.model is None:
            return
        del self.model
        del self.tokenizer
        self.model = None
        self.tokenizer = None
        gc.collect()
        torch.cuda.empty_cache()

    def generate(self, messages):
        # Subclasses override this to produce a model response.
        return []

    def record_conversation(self, messages, response):
        messages = messages + [response]
        this_dir = os.path.dirname(os.path.abspath(__file__))
        conversations_folder = this_dir + '/../../../conversations'
        folder = conversations_folder + '/' + self.model_id_or_path.replace('/', '_')
        self._mkdir(conversations_folder)
        self._mkdir(folder)
        timestamp = datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S')
        json_filename = folder + '/' + timestamp + '.json'
        with open(json_filename, 'w') as file:
            json.dump(messages, file)

    def _mkdir(self, path):
        if not os.path.isdir(path):
            os.mkdir(path)

    def __del__(self):
        # Best-effort cleanup when the bot is garbage collected.
        self.ensure_unloaded()
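For reference, a concrete bot only needs to override generate(); loading, unloading, and conversation logging all come from the base class. A minimal sketch, where EchoAi is a hypothetical subclass that is not part of this commit:

from AiBase import AiBase


class EchoAi(AiBase):
    # Hypothetical example subclass: echoes the last user message back.
    def generate(self, messages):
        last = messages[-1]['content'] if messages else ''
        response_entry = {'role': 'assistant', 'name': 'assistant', 'content': last}
        self.record_conversation(messages, response_entry)
        return messages + [response_entry]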

ai/QwenAi.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import torch

from AiBase import AiBase


class QwenAi(AiBase):
    default_device = 'cuda'  # the device to load the model onto
    default_model_id = 'Qwen/Qwen1.5-1.8B-Chat'
    default_instruction = {'role': 'system',
                           'name': 'system',
                           'content': 'Your name is "Laura". You are an AI created by Alice.'}

    def __init__(self, model_id_or_path=default_model_id):
        super().__init__(model_id_or_path)

    def generate(self, messages):
        try:
            # drop client-supplied system prompts and prepend our own instruction
            messages = [m for m in messages if m['role'] != 'system']
            input_messages = [self.default_instruction] + messages
            # render the chat template, tokenize, and generate
            text = self.tokenizer.apply_chat_template(input_messages, tokenize=False, add_generation_prompt=True)
            model_inputs = self.tokenizer([text], return_tensors='pt').to(self.default_device)
            generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=300)
            # strip the prompt tokens so only the newly generated ones remain
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            # add response and save conversation
            response_entry = {'role': 'assistant', 'name': 'assistant', 'content': response}
            messages.append(response_entry)
            self.record_conversation(input_messages, response_entry)
            return messages
        finally:
            torch.cuda.empty_cache()  # free cached GPU memory after each request
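A quick usage sketch (assumes a CUDA-capable GPU, since default_device is 'cuda', and that the Qwen weights can be downloaded):

from QwenAi import QwenAi

bot = QwenAi()  # defaults to 'Qwen/Qwen1.5-1.8B-Chat'
bot.ensure_loaded()
history = bot.generate([{'role': 'user', 'name': 'user', 'content': 'Hello!'}])
print(history[-1]['content'])  # the assistant's reply
bot.ensure_unloaded()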

ai/main.py (new file, 17 lines)

@@ -0,0 +1,17 @@
import os
import sys


def main(*args):
    from server import Server

    args = list(args)
    # parse an optional '--port <number>' command line argument
    port_index = args.index('--port') if '--port' in args else -1
    port = int(args[port_index + 1]) if 0 <= port_index < len(args) - 1 else None
    Server().serve(port=port)


if __name__ == '__main__':
    print(os.getcwd())
    main(*sys.argv)
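The entry point is started as, for example, python ai/main.py --port 8901; when --port is omitted, Server.serve() falls back to port 8900.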

ai/server.py (new file, 63 lines)

@@ -0,0 +1,63 @@
import http.server
import json
import socketserver
import sys
from functools import partial

from AiBase import AiBase
from QwenAi import QwenAi


class Server:
    def __init__(self):
        self._bots = {}

    def get_bot(self, model_id: str) -> AiBase:
        # Create each bot lazily and cache it per model id.
        if model_id not in self._bots:
            self._bots[model_id] = QwenAi(model_id)
        return self._bots[model_id]

    class HTTPServer(socketserver.TCPServer):
        # Avoid "address already in use" errors when frequently restarting the script
        allow_reuse_address = True

    class Handler(http.server.BaseHTTPRequestHandler):
        def __init__(self, server, *args, **kwargs):
            self._server = server
            super().__init__(*args, **kwargs)

        def do_POST(self):
            try:
                content_len = int(self.headers.get('Content-Length'))
                post_body = self.rfile.read(content_len)
                json_body = json.loads(post_body)
                if json_body['command'] == 'chat':
                    bot = self._server.get_bot(json_body['modelId'])
                    bot.ensure_loaded()
                    response = bot.generate(json_body['messages'])
                elif json_body['command'] == 'shutdown':
                    bot = self._server.get_bot(json_body['modelId'])
                    bot.ensure_unloaded()
                    response = None
                else:
                    self.send_response(400)
                    self.end_headers()
                    return
                self.send_response(200)
                self.end_headers()
                self.wfile.write(json.dumps(response).encode('utf-8'))
            except Exception as e:
                print(e)
                self.send_response(400)
                self.end_headers()

    def serve(self, port=None):
        handler = partial(self.Handler, self)
        port = port or 8900
        with self.HTTPServer(('127.0.0.1', port), handler) as httpd:
            print('serving at http://127.0.0.1:' + str(port))
            sys.stdout.flush()
            httpd.serve_forever()
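The wire protocol is a single JSON object per POST request, with a 'command' of either 'chat' or 'shutdown'. A client sketch using only the standard library; the model id and message content are placeholders:

import json
import urllib.request

payload = {
    'command': 'chat',
    'modelId': 'Qwen/Qwen1.5-1.8B-Chat',
    'messages': [{'role': 'user', 'name': 'user', 'content': 'Hi Laura!'}],
}
request = urllib.request.Request(
    'http://127.0.0.1:8900',
    data=json.dumps(payload).encode('utf-8'),
    headers={'Content-Type': 'application/json'},
)
with urllib.request.urlopen(request) as response:
    print(json.loads(response.read()))  # the full message history, reply last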