64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
import http.server
|
|
import json
|
|
import socketserver
|
|
from functools import partial
|
|
|
|
from AiBase import AiBase
|
|
from QwenAi import QwenAi
|
|
|
|
|
|
class Server:
    """Local HTTP front end that routes JSON commands to cached model bots.

    Bots are created lazily, one per ``modelId``, and kept in ``self._bots``
    for the lifetime of the ``Server`` instance.
    """

    def __init__(self):
        # model_id -> bot instance, populated on first use by get_bot().
        self._bots = {}

    def get_bot(self, model_id: str) -> "AiBase":
        """Return the bot for *model_id*, constructing a ``QwenAi`` on first use.

        Later calls with the same id return the same cached instance.
        """
        if model_id not in self._bots:
            self._bots[model_id] = QwenAi(model_id)
        return self._bots[model_id]

    class HTTPServer(socketserver.TCPServer):
        # Avoid "address already used" error when frequently restarting the script
        allow_reuse_address = True

    class Handler(http.server.BaseHTTPRequestHandler):
        """Request handler for the JSON command protocol (only POST is implemented).

        The request body is a JSON object with a ``command`` key:

        * ``"chat"``: requires ``modelId``, ``messages``, ``tokenCount`` and
          ``temperature``; replies 200 with the JSON-encoded result of
          ``bot.generate(...)`` after ``bot.ensure_loaded()``.
        * ``"shutdown"``: requires ``modelId``; calls ``bot.ensure_unloaded()``
          and replies 200 with ``null``.

        Malformed requests (missing/invalid Content-Length, invalid JSON,
        missing keys, unknown command) get 400 with an empty body. Failures
        while running the bot or encoding its result get 500.
        """

        def __init__(self, server, *args, **kwargs):
            # ``server`` is the owning Server application object (bound with
            # functools.partial in serve()). It is distinct from ``self.server``,
            # which BaseHTTPRequestHandler sets to the socket server.
            self._server = server
            super().__init__(*args, **kwargs)

        def _read_json_body(self):
            """Read exactly Content-Length bytes of the request and decode them as JSON.

            Raises:
                ValueError: if Content-Length is missing, not an integer or
                    negative, or if the body is not valid JSON
                    (``json.JSONDecodeError`` subclasses ``ValueError``).
            """
            raw_len = self.headers.get('Content-Length')
            if raw_len is None:
                raise ValueError("missing Content-Length header")
            content_len = int(raw_len)
            if content_len < 0:
                # rfile.read(-1) would block until the client closes the socket.
                raise ValueError(f"negative Content-Length: {content_len}")
            return json.loads(self.rfile.read(content_len))

        def _send_empty(self, status):
            """Reply with *status* and an empty body."""
            self.send_response(status)
            self.send_header("Content-Length", "0")
            self.end_headers()

        def do_POST(self):
            """Parse the JSON command, dispatch it to the matching bot and reply."""
            # Phase 1: validate the request. Anything wrong here is the client's fault.
            try:
                request = self._read_json_body()
                command = request['command']
                model_id = request['modelId']
                if command == 'chat':
                    messages = request['messages']
                    token_count = request['tokenCount']
                    temperature = request['temperature']
                elif command != 'shutdown':
                    raise ValueError(f"unknown command {command!r}")
            except (KeyError, TypeError, ValueError) as e:
                # TypeError covers a JSON body that is not an object (list, str, ...).
                self.log_error("Bad request: %r", e)
                self._send_empty(400)
                return

            # Phase 2: run the command. Failures here are server-side.
            try:
                bot = self._server.get_bot(model_id)
                if command == 'chat':
                    bot.ensure_loaded()
                    response = bot.generate(messages,
                                            token_count=token_count,
                                            temperature=temperature)
                else:
                    bot.ensure_unloaded()
                    response = None
                # Encode before any status line is written, so an
                # unserializable result still yields a clean error response.
                payload = json.dumps(response).encode("utf-8")
            except Exception as e:
                self.log_error("Error handling %r for model %r: %r",
                               command, model_id, e)
                self._send_empty(500)
                return

            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(payload)))
            self.end_headers()
            self.wfile.write(payload)

    def serve(self, port=None):
        """Serve requests on 127.0.0.1 until the process is interrupted.

        Args:
            port: TCP port to bind. Defaults to 8900 when ``None``. ``0`` lets
                the OS choose a free port; the port actually bound is printed.
        """
        handler = partial(self.Handler, self)
        if port is None:
            port = 8900
        with self.HTTPServer(("127.0.0.1", port), handler) as httpd:
            bound_port = httpd.server_address[1]
            print("serving at http://127.0.0.1:" + str(bound_port))
            httpd.serve_forever()
|