added temperature and token count to ai config

This commit is contained in:
wea_ondara
2024-05-29 18:36:17 +02:00
parent 3a78660883
commit bc0712ba11
14 changed files with 197 additions and 15 deletions

View File

@@ -30,7 +30,7 @@ class AiBase:
if self.model is None:
del self
def generate(self, messages):
def generate(self, messages, token_count=100, temperature=0.1):
return []
def record_conversation(self, messages, response):

View File

@@ -13,7 +13,7 @@ class QwenAi(AiBase):
def __init__(self, model_id_or_path=default_model_id):
super().__init__(model_id_or_path)
def generate(self, messages):
def generate(self, messages, token_count=100, **kwargs):
try:
# prepare
messages = [m for m in messages if m['role'] != 'system']
@@ -22,7 +22,7 @@ class QwenAi(AiBase):
# generate
text = self.tokenizer.apply_chat_template(input_messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.tokenizer([text], return_tensors='pt').to(self.default_device)
generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=300)
generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=token_count)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

View File

@@ -35,7 +35,9 @@ class Server:
bot = self._server.get_bot(json_body['modelId'])
print(bot)
bot.ensure_loaded()
response = bot.generate(json_body['messages'])
response = bot.generate(json_body['messages'],
token_count=json_body['tokenCount'],
temperature=json_body['temperature'])
elif json_body['command'] == 'shutdown':
bot = self._server.get_bot(json_body['modelId'])
bot.ensure_unloaded()