From bc0712ba112daad0073bdc5e183753b57feae762 Mon Sep 17 00:00:00 2001
From: wea_ondara
Date: Wed, 29 May 2024 18:36:17 +0200
Subject: [PATCH] added temperature and token count to ai config

---
 ai/AiBase.py                                       |  2 +-
 ai/QwenAi.py                                       |  4 +-
 ai/server.py                                       |  4 +-
 .../api/v1/AiConfigurationController.ts            | 28 +++++-
 ...240529151651_aiTokenCountAndTemperature.ts      | 21 ++++
 backend/src/migrations/_migrations.ts              |  2 +
 backend/src/models/business/AiInstance.ts          |  2 +-
 backend/src/models/db/AiConfiguration.ts           | 10 +-
 backend/src/services/AiPythonConnector.ts          | 10 +-
 backend/src/services/AiService.ts                  |  7 ++
 backend/src/tsoa.gen/routes.ts                     |  2 +
 .../dashboard/AiInstanceComponent.vue              |  6 +-
 .../src/components/dashboard/Settings.vue          | 98 +++++++++++++++++++
 swagger.json                                       | 16 ++-
 14 files changed, 197 insertions(+), 15 deletions(-)
 create mode 100644 backend/src/migrations/20240529151651_aiTokenCountAndTemperature.ts
 create mode 100644 frontend/src/components/dashboard/Settings.vue

diff --git a/ai/AiBase.py b/ai/AiBase.py
index 532653a..30061f2 100644
--- a/ai/AiBase.py
+++ b/ai/AiBase.py
@@ -30,7 +30,7 @@ class AiBase:
         if self.model is None:
             del self
 
-    def generate(self, messages):
+    def generate(self, messages, token_count=100, temperature=0.1):
         return []
 
     def record_conversation(self, messages, response):
diff --git a/ai/QwenAi.py b/ai/QwenAi.py
index 58af596..4273f01 100644
--- a/ai/QwenAi.py
+++ b/ai/QwenAi.py
@@ -13,7 +13,7 @@ class QwenAi(AiBase):
     def __init__(self, model_id_or_path=default_model_id):
         super().__init__(model_id_or_path)
 
-    def generate(self, messages):
+    def generate(self, messages, token_count=100, **kwargs):
         try:
             # prepare
             messages = [m for m in messages if m['role'] != 'system']
@@ -22,7 +22,7 @@ class QwenAi(AiBase):
             # generate
             text = self.tokenizer.apply_chat_template(input_messages, tokenize=False, add_generation_prompt=True)
             model_inputs = self.tokenizer([text], return_tensors='pt').to(self.default_device)
-            generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=300)
+            generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=token_count)
             generated_ids = [
                 output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
             ]
diff --git a/ai/server.py b/ai/server.py
index 3a2750d..9d26c8b 100644
--- a/ai/server.py
+++ b/ai/server.py
@@ -35,7 +35,9 @@ class Server:
                 bot = self._server.get_bot(json_body['modelId'])
                 print(bot)
                 bot.ensure_loaded()
-                response = bot.generate(json_body['messages'])
+                response = bot.generate(json_body['messages'],
+                                        token_count=json_body['tokenCount'],
+                                        temperature=json_body['temperature'])
             elif json_body['command'] == 'shutdown':
                 bot = self._server.get_bot(json_body['modelId'])
                 bot.ensure_unloaded()
diff --git a/backend/src/controllers/api/v1/AiConfigurationController.ts b/backend/src/controllers/api/v1/AiConfigurationController.ts
index 94b5f80..0e39537 100644
--- a/backend/src/controllers/api/v1/AiConfigurationController.ts
+++ b/backend/src/controllers/api/v1/AiConfigurationController.ts
@@ -21,6 +21,8 @@ import {isAuthenticatedMiddleware} from '../../../middleware/auth';
 import {UUID} from '../../../models/api/uuid';
 import ApiBaseModelCreatedUpdated from '../../../models/api/ApiBaseModelCreatedUpdated';
 import {UUID as nodeUUID} from 'node:crypto';
+import AiService from '../../../services/AiService';
+import {Inject} from 'typescript-ioc';
 
 export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
   /**
@@ -35,6 +37,17 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
    * @maxLength 255
    */
   discordToken: string;
+  /**
+   * @isInt
+   * @minimum 0
+   * @maximum 1000
+   */
+  tokenCount: number;
+  /**
+   * @minimum 0
+   * @maximum 1
+   */
+  temperature: number,
 }
 
 @Route('api/v1/ai/configurations')
@@ -42,6 +55,12 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
 @Tags('aiConfiguration')
 export class AiConfigurationController extends Controller {
   private repo = new AiConfigurationRepository();
+  private readonly service: AiService;
+
+  constructor(@Inject service: AiService) {
+    super();
+    this.service = service;
+  }
 
   @Get()
   @SuccessResponse(200, 'Ok')
@@ -53,7 +72,7 @@ export class AiConfigurationController extends Controller {
   @SuccessResponse(200, 'Ok')
   @Response(404, 'Not Found')
   async getById(@Path() id: UUID, @Request() req: Req): Promise<AiConfigurationVmV1> {
-    const aiConfiguration = await this.repo.getById( id as nodeUUID);
+    const aiConfiguration = await this.repo.getById(id as nodeUUID);
     if (aiConfiguration !== undefined) {
       return aiConfiguration;
     } else {
@@ -83,7 +102,7 @@ export class AiConfigurationController extends Controller {
     const aiConfiguration = AiConfiguration.fromJson(body);
 
     //get from db
-    const dbaiConfiguration = await this.repo.getById(aiConfiguration.id, req.session.user!.id);
+    const dbaiConfiguration = await this.repo.getById(aiConfiguration.id);
     if (!dbaiConfiguration) {
       this.setStatus(404);
       return undefined as any;
@@ -93,11 +112,14 @@ export class AiConfigurationController extends Controller {
     dbaiConfiguration.name = aiConfiguration.name;
     dbaiConfiguration.modelIdOrPath = aiConfiguration.modelIdOrPath;
     dbaiConfiguration.discordToken = aiConfiguration.discordToken;
+    dbaiConfiguration.tokenCount = aiConfiguration.tokenCount;
+    dbaiConfiguration.temperature = aiConfiguration.temperature;
     dbaiConfiguration.updatedBy = req.session.user!.id;
 
     //save
     const updated = await this.repo.update(dbaiConfiguration);
     if (updated) {
+      this.service.updateConfiguration(dbaiConfiguration);
       return dbaiConfiguration;
     } else {
       this.setStatus(404);
@@ -109,7 +131,7 @@ export class AiConfigurationController extends Controller {
   @SuccessResponse(204, 'Deleted')
   @Response(404, 'Not Found')
   async remove(@Path() id: UUID, @Request() req: Req): Promise<void> {
-    const deleted = await this.repo.remove(id as nodeUUID, req.session.user!.id);
+    const deleted = await this.repo.remove(id as nodeUUID);
     if (deleted) {
       this.setStatus(204);
     } else {
diff --git a/backend/src/migrations/20240529151651_aiTokenCountAndTemperature.ts b/backend/src/migrations/20240529151651_aiTokenCountAndTemperature.ts
new file mode 100644
index 0000000..78f3559
--- /dev/null
+++ b/backend/src/migrations/20240529151651_aiTokenCountAndTemperature.ts
@@ -0,0 +1,21 @@
+import type {Knex} from 'knex';
+
+export async function up(knex: Knex): Promise<void> {
+  console.log('Running migration 20240529151651_aiTokenCountAndTemperature');
+
+  await knex.transaction(async trx => {
+    await knex.schema.alterTable('aiConfigurations', table => {
+      table.integer('tokenCount').notNullable().defaultTo(100).after('discordToken');
+      table.double('temperature', 6).notNullable().defaultTo(0.1).after('tokenCount');
+    }).transacting(trx);
+  });
+}
+
+export async function down(knex: Knex): Promise<void> {
+  await knex.transaction(async trx => {
+    await knex.schema.alterTable('aiConfigurations', table => {
+      table.dropColumn('tokenCount');
+      table.dropColumn('temperature');
+    }).transacting(trx);
+  });
+}
diff --git a/backend/src/migrations/_migrations.ts b/backend/src/migrations/_migrations.ts
index f644f96..a5928ea 100644
--- a/backend/src/migrations/_migrations.ts
+++ b/backend/src/migrations/_migrations.ts
@@ -2,6 +2,7 @@ import type {Knex} from 'knex';
 import * as M20230114134301_user from './20230114134301_user';
 import * as M20230610151046_userEmailAndDisplayName from './20230610151046_userEmailAndDisplayName';
 import * as M20240511125408_aiConfigurations from './20240511125408_aiConfigurations';
+import * as M20240529151651_aiTokenCountAndTemperature from './20240529151651_aiTokenCountAndTemperature';
 
 export type Migration = {
   name: string,
@@ -12,5 +13,6 @@ export const Migrations: Migration[] = [
   {name: '20230114134301_user', migration: M20230114134301_user},
   {name: '20230610151046_userEmailAndDisplayName', migration: M20230610151046_userEmailAndDisplayName},
   {name: '20240511125408_aiConfigurations', migration: M20240511125408_aiConfigurations},
+  {name: '20240529151651_aiTokenCountAndTemperature', migration: M20240529151651_aiTokenCountAndTemperature},
 ];
 //TODO use glob import
\ No newline at end of file
diff --git a/backend/src/models/business/AiInstance.ts b/backend/src/models/business/AiInstance.ts
index 742d897..c039b28 100644
--- a/backend/src/models/business/AiInstance.ts
+++ b/backend/src/models/business/AiInstance.ts
@@ -29,7 +29,7 @@ export default class AiInstance {
       await this._messagesSemaphore.acquire();
       this.messages.push({'role': 'user', 'name': user, 'content': text});
       getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
-      this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages);
+      this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages, this.configuration.tokenCount, this.configuration.temperature);
       getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
       return this.messages[this.messages.length - 1]!;
     } finally {
diff --git a/backend/src/models/db/AiConfiguration.ts b/backend/src/models/db/AiConfiguration.ts
index eac03ab..728a5c7 100644
--- a/backend/src/models/db/AiConfiguration.ts
+++ b/backend/src/models/db/AiConfiguration.ts
@@ -7,13 +7,17 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
   name!: string; // max length 255
   modelIdOrPath!: string; // max length 255
   discordToken!: string; // max length 255
+  tokenCount!: number; // int
+  temperature!: number; // double 0-1
 
-  static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, createdBy: UUID): AiConfiguration {
+  static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, tokenCount: number, temperature: number, createdBy: UUID): AiConfiguration {
     const ret = new AiConfiguration();
     ret.id = id;
     ret.name = name;
     ret.modelIdOrPath = modelIdOrPath;
     ret.discordToken = discordToken;
+    ret.tokenCount = tokenCount;
+    ret.temperature = temperature;
     ret.createdBy = createdBy;
     ret.updatedBy = createdBy;
     return ret;
@@ -26,12 +30,14 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
   static override get jsonSchemaWithReferences(): JSONSchema {
     return mergeDeep({}, super.jsonSchemaWithReferences, {
       $id: 'AiConfiguration',
-      required: ['date', 'name', 'lab', 'comment'],
+      required: ['name', 'modelIdOrPath', 'discordToken', 'tokenCount', 'temperature'],
       properties: {
         name: {type: 'string', maxLength: 255},
         modelIdOrPath: {type: 'string', maxLength: 255},
         discordToken: {type: 'string', maxLength: 255},
+        tokenCount: {type: 'integer', minimum: 0, maximum: 1000},
+        temperature: {type: 'number', minimum: 0, maximum: 1},
       },
     });
   }
diff --git a/backend/src/services/AiPythonConnector.ts b/backend/src/services/AiPythonConnector.ts
index 7422c57..4b496ed 100644
--- a/backend/src/services/AiPythonConnector.ts
+++ b/backend/src/services/AiPythonConnector.ts
@@ -15,14 +15,20 @@ export default class AiPythonConnector {
     return !!this.process;
   }
 
-  async chat(modelId: string, conversation: ChatMessage[]): Promise<ChatMessage[]> {
+  async chat(modelId: string, conversation: ChatMessage[], tokenCount: number = 300, temperature: number = 0.1): Promise<ChatMessage[]> {
     await this.ensureStarted();
     const port = await this.port();
 
     const response = await fetch('http://localhost:' + port, {
       method: 'POST',
      headers: {'Content-Type': 'application/json'},
-      body: JSON.stringify({command: 'chat', modelId: modelId, messages: conversation}),
+      body: JSON.stringify({
+        command: 'chat',
+        modelId: modelId,
+        messages: conversation,
+        tokenCount: tokenCount,
+        temperature: temperature,
+      }),
     });
     if (response.status === 200) {
       return await response.json();
diff --git a/backend/src/services/AiService.ts b/backend/src/services/AiService.ts
index 49bd95e..4e019cd 100644
--- a/backend/src/services/AiService.ts
+++ b/backend/src/services/AiService.ts
@@ -32,6 +32,13 @@ export default class AiService {
     return this.instances.find(e => e.configuration.id === id);
   }
 
+  updateConfiguration(configuration: AiConfiguration) {
+    const aiInstance = this.getInstanceById(configuration.id);
+    if (aiInstance) {
+      aiInstance.configuration = configuration;
+    }
+  }
+
   private async ensureInstancesInited(): Promise<void> {
     for (let config of await this.getConfigurations()) {
       if (!this.getInstanceByName(config.name)) {
diff --git a/backend/src/tsoa.gen/routes.ts b/backend/src/tsoa.gen/routes.ts
index 1c7d85d..70003d0 100644
--- a/backend/src/tsoa.gen/routes.ts
+++ b/backend/src/tsoa.gen/routes.ts
@@ -129,6 +129,8 @@ const models: TsoaRoute.Models = {
             "name": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
             "modelIdOrPath": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
             "discordToken": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
+            "tokenCount": {"dataType":"integer","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1000}}},
+            "temperature": {"dataType":"double","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1}}},
         },
         "additionalProperties": false,
     },
diff --git a/frontend/src/components/dashboard/AiInstanceComponent.vue b/frontend/src/components/dashboard/AiInstanceComponent.vue
index f484866..ccd6e7a 100644
--- a/frontend/src/components/dashboard/AiInstanceComponent.vue
+++ b/frontend/src/components/dashboard/AiInstanceComponent.vue
@@ -4,10 +4,11 @@ import {Prop} from 'vue-property-decorator';
 import {AiInstanceVmV1} from 'ai-oas';
 import Discord from '@/components/dashboard/Discord.vue';
 import Chat from '@/components/dashboard/Chat.vue';
+import Settings from '@/components/dashboard/Settings.vue';
 
 @Options({
   name: 'AiInstanceComponent',
-  components: {Chat, Discord},
+  components: {Chat, Discord, Settings},
 })
 export default class AiInstanceComponent extends Vue {
   @Prop({required: true})
@@ -23,7 +24,8 @@ export default class AiInstanceComponent extends Vue {
-
+
+
diff --git a/frontend/src/components/dashboard/Settings.vue b/frontend/src/components/dashboard/Settings.vue
new file mode 100644
index 0000000..2e0c7b9
--- /dev/null
+++ b/frontend/src/components/dashboard/Settings.vue
@@ -0,0 +1,98 @@
+
+
+
diff --git a/swagger.json b/swagger.json
index 2c0edbd..7697667 100644
--- a/swagger.json
+++ b/swagger.json
@@ -232,6 +232,18 @@
         "discordToken": {
           "type": "string",
           "maxLength": 255
+        },
+        "tokenCount": {
+          "type": "integer",
+          "format": "int32",
+          "minimum": 0,
+          "maximum": 1000
+        },
+        "temperature": {
+          "type": "number",
+          "format": "double",
+          "minimum": 0,
+          "maximum": 1
         }
       },
       "required": [
@@ -242,7 +254,9 @@
         "updatedBy",
         "name",
         "modelIdOrPath",
-        "discordToken"
+        "discordToken",
+        "tokenCount",
+        "temperature"
       ],
       "type": "object",
       "additionalProperties": false