added temperature and token count to ai config

This commit is contained in:
wea_ondara
2024-05-29 18:36:17 +02:00
parent 3a78660883
commit bc0712ba11
14 changed files with 197 additions and 15 deletions

View File

@@ -21,6 +21,8 @@ import {isAuthenticatedMiddleware} from '../../../middleware/auth';
import {UUID} from '../../../models/api/uuid';
import ApiBaseModelCreatedUpdated from '../../../models/api/ApiBaseModelCreatedUpdated';
import {UUID as nodeUUID} from 'node:crypto';
import AiService from '../../../services/AiService';
import {Inject} from 'typescript-ioc';
export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
/**
@@ -35,6 +37,17 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
* @maxLength 255
*/
discordToken: string;
/**
* @isInt
* @minimum 0
* @maximum 1000
*/
tokenCount: number;
/**
* @minimum 0
* @maximum 1
*/
temperature: number,
}
@Route('api/v1/ai/configurations')
@@ -42,6 +55,12 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
@Tags('aiConfiguration')
export class AiConfigurationController extends Controller {
private repo = new AiConfigurationRepository();
private readonly service: AiService;
constructor(@Inject service: AiService) {
super();
this.service = service;
}
@Get()
@SuccessResponse(200, 'Ok')
@@ -53,7 +72,7 @@ export class AiConfigurationController extends Controller {
@SuccessResponse(200, 'Ok')
@Response(404, 'Not Found')
async getById(@Path() id: UUID, @Request() req: Req): Promise<AiConfigurationVmV1> {
const aiConfiguration = await this.repo.getById( id as nodeUUID);
const aiConfiguration = await this.repo.getById(id as nodeUUID);
if (aiConfiguration !== undefined) {
return aiConfiguration;
} else {
@@ -83,7 +102,7 @@ export class AiConfigurationController extends Controller {
const aiConfiguration = AiConfiguration.fromJson(body);
//get from db
const dbaiConfiguration = await this.repo.getById(aiConfiguration.id, req.session.user!.id);
const dbaiConfiguration = await this.repo.getById(aiConfiguration.id);
if (!dbaiConfiguration) {
this.setStatus(404);
return undefined as any;
@@ -93,11 +112,14 @@ export class AiConfigurationController extends Controller {
dbaiConfiguration.name = aiConfiguration.name;
dbaiConfiguration.modelIdOrPath = aiConfiguration.modelIdOrPath;
dbaiConfiguration.discordToken = aiConfiguration.discordToken;
dbaiConfiguration.tokenCount = aiConfiguration.tokenCount;
dbaiConfiguration.temperature = aiConfiguration.temperature;
dbaiConfiguration.updatedBy = req.session.user!.id;
//save
const updated = await this.repo.update(dbaiConfiguration);
if (updated) {
this.service.updateConfiguration(dbaiConfiguration);
return dbaiConfiguration;
} else {
this.setStatus(404);
@@ -109,7 +131,7 @@ export class AiConfigurationController extends Controller {
@SuccessResponse(204, 'Deleted')
@Response(404, 'Not Found')
async remove(@Path() id: UUID, @Request() req: Req): Promise<void> {
const deleted = await this.repo.remove(id as nodeUUID, req.session.user!.id);
const deleted = await this.repo.remove(id as nodeUUID);
if (deleted) {
this.setStatus(204);
} else {

View File

@@ -0,0 +1,21 @@
import type {Knex} from 'knex';
export async function up(knex: Knex): Promise<void> {
console.log('Running migration 20240529151651_aiTokenCountAndTemperature');
await knex.transaction(async trx => {
await knex.schema.alterTable('aiConfigurations', table => {
table.integer('tokenCount').notNullable().defaultTo(100).after('discordToken');
table.double('temperature', 6).notNullable().defaultTo(0.1).after('tokenCount');
}).transacting(trx);
});
}
export async function down(knex: Knex): Promise<void> {
await knex.transaction(async trx => {
await knex.schema.alterTable('aiConfigurations', table => {
table.dropColumn('tokenCount');
table.dropColumn('temperature');
}).transacting(trx);
});
}

View File

@@ -2,6 +2,7 @@ import type {Knex} from 'knex';
import * as M20230114134301_user from './20230114134301_user';
import * as M20230610151046_userEmailAndDisplayName from './20230610151046_userEmailAndDisplayName';
import * as M20240511125408_aiConfigurations from './20240511125408_aiConfigurations';
import * as M20240529151651_aiTokenCountAndTemperature from './20240529151651_aiTokenCountAndTemperature';
export type Migration = {
name: string,
@@ -12,5 +13,6 @@ export const Migrations: Migration[] = [
{name: '20230114134301_user', migration: M20230114134301_user},
{name: '20230610151046_userEmailAndDisplayName', migration: M20230610151046_userEmailAndDisplayName},
{name: '20240511125408_aiConfigurations', migration: M20240511125408_aiConfigurations},
{name: '20240529151651_aiTokenCountAndTemperature', migration: M20240529151651_aiTokenCountAndTemperature},
];
//TODO use glob import

View File

@@ -29,7 +29,7 @@ export default class AiInstance {
await this._messagesSemaphore.acquire();
this.messages.push({'role': 'user', 'name': user, 'content': text});
getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages);
this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages, this.configuration.tokenCount, this.configuration.temperature);
getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
return this.messages[this.messages.length - 1]!;
} finally {

View File

@@ -7,13 +7,17 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
name!: string; // max length 255
modelIdOrPath!: string; // max length 255
discordToken!: string; // max length 255
tokenCount!: number; // int
temperature!: number; // double 0-1
static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, createdBy: UUID): AiConfiguration {
static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, tokenCount: number, temperature: number, createdBy: UUID): AiConfiguration {
const ret = new AiConfiguration();
ret.id = id;
ret.name = name;
ret.modelIdOrPath = modelIdOrPath;
ret.discordToken = discordToken;
ret.tokenCount = tokenCount;
ret.temperature = temperature;
ret.createdBy = createdBy;
ret.updatedBy = createdBy;
return ret;
@@ -26,12 +30,14 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
static override get jsonSchemaWithReferences(): JSONSchema {
return mergeDeep({}, super.jsonSchemaWithReferences, {
$id: 'AiConfiguration',
required: ['date', 'name', 'lab', 'comment'],
required: ['name', 'modelIdOrPath', 'discordToken', 'tokenCount', 'temperature'],
properties: {
name: {type: 'string', maxLength: 255},
modelIdOrPath: {type: 'string', maxLength: 255},
discordToken: {type: 'string', maxLength: 255},
tokenCount: {type: 'integer', minimum: 0, maximum: 1000},
temperature: {type: 'number', minimum: 0, maximum: 1},
},
});
}

View File

@@ -15,14 +15,20 @@ export default class AiPythonConnector {
return !!this.process;
}
async chat(modelId: string, conversation: ChatMessage[]): Promise<ChatMessage[]> {
async chat(modelId: string, conversation: ChatMessage[], tokenCount: number = 300, temperature: number = 0.1): Promise<ChatMessage[]> {
await this.ensureStarted();
const port = await this.port();
const response = await fetch('http://localhost:' + port, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({command: 'chat', modelId: modelId, messages: conversation}),
body: JSON.stringify({
command: 'chat',
modelId: modelId,
messages: conversation,
tokenCount: tokenCount,
temperature: temperature,
}),
});
if (response.status === 200) {
return await response.json();

View File

@@ -32,6 +32,13 @@ export default class AiService {
return this.instances.find(e => e.configuration.id === id);
}
updateConfiguration(configuration: AiConfiguration) {
const aiInstance = this.getInstanceById(configuration.id);
if (aiInstance) {
aiInstance.configuration = configuration;
}
}
private async ensureInstancesInited(): Promise<void> {
for (let config of await this.getConfigurations()) {
if (!this.getInstanceByName(config.name)) {

View File

@@ -129,6 +129,8 @@ const models: TsoaRoute.Models = {
"name": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
"modelIdOrPath": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
"discordToken": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
"tokenCount": {"dataType":"integer","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1000}}},
"temperature": {"dataType":"double","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1}}},
},
"additionalProperties": false,
},