Added temperature and token count to AI config

This commit is contained in:
wea_ondara
2024-05-29 18:36:17 +02:00
parent 3a78660883
commit bc0712ba11
14 changed files with 197 additions and 15 deletions

View File

@@ -30,7 +30,7 @@ class AiBase:
if self.model is None:
del self
def generate(self, messages):
def generate(self, messages, token_count=100, temperature=0.1):
return []
def record_conversation(self, messages, response):

View File

@@ -13,7 +13,7 @@ class QwenAi(AiBase):
def __init__(self, model_id_or_path=default_model_id):
super().__init__(model_id_or_path)
def generate(self, messages):
def generate(self, messages, token_count=100, **kwargs):
try:
# prepare
messages = [m for m in messages if m['role'] != 'system']
@@ -22,7 +22,7 @@ class QwenAi(AiBase):
# generate
text = self.tokenizer.apply_chat_template(input_messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.tokenizer([text], return_tensors='pt').to(self.default_device)
generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=300)
generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=token_count)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

View File

@@ -35,7 +35,9 @@ class Server:
bot = self._server.get_bot(json_body['modelId'])
print(bot)
bot.ensure_loaded()
response = bot.generate(json_body['messages'])
response = bot.generate(json_body['messages'],
token_count=json_body['tokenCount'],
temperature=json_body['temperature'])
elif json_body['command'] == 'shutdown':
bot = self._server.get_bot(json_body['modelId'])
bot.ensure_unloaded()

View File

@@ -21,6 +21,8 @@ import {isAuthenticatedMiddleware} from '../../../middleware/auth';
import {UUID} from '../../../models/api/uuid';
import ApiBaseModelCreatedUpdated from '../../../models/api/ApiBaseModelCreatedUpdated';
import {UUID as nodeUUID} from 'node:crypto';
import AiService from '../../../services/AiService';
import {Inject} from 'typescript-ioc';
export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
/**
@@ -35,6 +37,17 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
* @maxLength 255
*/
discordToken: string;
/**
* @isInt
* @minimum 0
* @maximum 1000
*/
tokenCount: number;
/**
* @minimum 0
* @maximum 1
*/
temperature: number,
}
@Route('api/v1/ai/configurations')
@@ -42,6 +55,12 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
@Tags('aiConfiguration')
export class AiConfigurationController extends Controller {
private repo = new AiConfigurationRepository();
private readonly service: AiService;
constructor(@Inject service: AiService) {
super();
this.service = service;
}
@Get()
@SuccessResponse(200, 'Ok')
@@ -83,7 +102,7 @@ export class AiConfigurationController extends Controller {
const aiConfiguration = AiConfiguration.fromJson(body);
//get from db
const dbaiConfiguration = await this.repo.getById(aiConfiguration.id, req.session.user!.id);
const dbaiConfiguration = await this.repo.getById(aiConfiguration.id);
if (!dbaiConfiguration) {
this.setStatus(404);
return undefined as any;
@@ -93,11 +112,14 @@ export class AiConfigurationController extends Controller {
dbaiConfiguration.name = aiConfiguration.name;
dbaiConfiguration.modelIdOrPath = aiConfiguration.modelIdOrPath;
dbaiConfiguration.discordToken = aiConfiguration.discordToken;
dbaiConfiguration.tokenCount = aiConfiguration.tokenCount;
dbaiConfiguration.temperature = aiConfiguration.temperature;
dbaiConfiguration.updatedBy = req.session.user!.id;
//save
const updated = await this.repo.update(dbaiConfiguration);
if (updated) {
this.service.updateConfiguration(dbaiConfiguration);
return dbaiConfiguration;
} else {
this.setStatus(404);
@@ -109,7 +131,7 @@ export class AiConfigurationController extends Controller {
@SuccessResponse(204, 'Deleted')
@Response(404, 'Not Found')
async remove(@Path() id: UUID, @Request() req: Req): Promise<void> {
const deleted = await this.repo.remove(id as nodeUUID, req.session.user!.id);
const deleted = await this.repo.remove(id as nodeUUID);
if (deleted) {
this.setStatus(204);
} else {

View File

@@ -0,0 +1,21 @@
import type {Knex} from 'knex';
/**
 * Migration 20240529151651: add per-configuration generation settings.
 *
 * Adds two columns to `aiConfigurations`:
 *  - `tokenCount`   (integer, NOT NULL, default 100) placed after `discordToken`
 *  - `temperature`  (double,  NOT NULL, default 0.1) placed after `tokenCount`
 *
 * The alteration is bound to an explicit transaction, matching the
 * convention of the sibling migrations.
 */
export async function up(knex: Knex): Promise<void> {
  console.log('Running migration 20240529151651_aiTokenCountAndTemperature');
  await knex.transaction(async transaction => {
    const alteration = knex.schema.alterTable('aiConfigurations', columns => {
      columns.integer('tokenCount').notNullable().defaultTo(100).after('discordToken');
      columns.double('temperature', 6).notNullable().defaultTo(0.1).after('tokenCount');
    });
    await alteration.transacting(transaction);
  });
}
/**
 * Reverts migration 20240529151651: drops the `tokenCount` and
 * `temperature` columns from `aiConfigurations` again.
 */
export async function down(knex: Knex): Promise<void> {
  await knex.transaction(async transaction => {
    const reversal = knex.schema.alterTable('aiConfigurations', columns => {
      columns.dropColumn('tokenCount');
      columns.dropColumn('temperature');
    });
    await reversal.transacting(transaction);
  });
}

View File

@@ -2,6 +2,7 @@ import type {Knex} from 'knex';
import * as M20230114134301_user from './20230114134301_user';
import * as M20230610151046_userEmailAndDisplayName from './20230610151046_userEmailAndDisplayName';
import * as M20240511125408_aiConfigurations from './20240511125408_aiConfigurations';
import * as M20240529151651_aiTokenCountAndTemperature from './20240529151651_aiTokenCountAndTemperature';
export type Migration = {
name: string,
@@ -12,5 +13,6 @@ export const Migrations: Migration[] = [
{name: '20230114134301_user', migration: M20230114134301_user},
{name: '20230610151046_userEmailAndDisplayName', migration: M20230610151046_userEmailAndDisplayName},
{name: '20240511125408_aiConfigurations', migration: M20240511125408_aiConfigurations},
{name: '20240529151651_aiTokenCountAndTemperature', migration: M20240529151651_aiTokenCountAndTemperature},
];
//TODO use glob import

View File

@@ -29,7 +29,7 @@ export default class AiInstance {
await this._messagesSemaphore.acquire();
this.messages.push({'role': 'user', 'name': user, 'content': text});
getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages);
this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages, this.configuration.tokenCount, this.configuration.temperature);
getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
return this.messages[this.messages.length - 1]!;
} finally {

View File

@@ -7,13 +7,17 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
name!: string; // max length 255
modelIdOrPath!: string; // max length 255
discordToken!: string; // max length 255
tokenCount!: number; // int
temperature!: number; // double 0-1
static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, createdBy: UUID): AiConfiguration {
static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, tokenCount: number, temperature: number, createdBy: UUID): AiConfiguration {
const ret = new AiConfiguration();
ret.id = id;
ret.name = name;
ret.modelIdOrPath = modelIdOrPath;
ret.discordToken = discordToken;
ret.tokenCount = tokenCount;
ret.temperature = temperature;
ret.createdBy = createdBy;
ret.updatedBy = createdBy;
return ret;
@@ -26,12 +30,14 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
static override get jsonSchemaWithReferences(): JSONSchema {
return mergeDeep({}, super.jsonSchemaWithReferences, {
$id: 'AiConfiguration',
required: ['date', 'name', 'lab', 'comment'],
required: ['name', 'modelIdOrPath', 'discordToken', 'tokenCount', 'temperature'],
properties: {
name: {type: 'string', maxLength: 255},
modelIdOrPath: {type: 'string', maxLength: 255},
discordToken: {type: 'string', maxLength: 255},
tokenCount: {type: 'integer', minimum: 0, maximum: 1000},
temperature: {type: 'number', minimum: 0, maximum: 1},
},
});
}

View File

@@ -15,14 +15,20 @@ export default class AiPythonConnector {
return !!this.process;
}
async chat(modelId: string, conversation: ChatMessage[]): Promise<ChatMessage[]> {
async chat(modelId: string, conversation: ChatMessage[], tokenCount: number = 300, temperature: number = 0.1): Promise<ChatMessage[]> {
await this.ensureStarted();
const port = await this.port();
const response = await fetch('http://localhost:' + port, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({command: 'chat', modelId: modelId, messages: conversation}),
body: JSON.stringify({
command: 'chat',
modelId: modelId,
messages: conversation,
tokenCount: tokenCount,
temperature: temperature,
}),
});
if (response.status === 200) {
return await response.json();

View File

@@ -32,6 +32,13 @@ export default class AiService {
return this.instances.find(e => e.configuration.id === id);
}
updateConfiguration(configuration: AiConfiguration) {
const aiInstance = this.getInstanceById(configuration.id);
if (aiInstance) {
aiInstance.configuration = configuration;
}
}
private async ensureInstancesInited(): Promise<void> {
for (let config of await this.getConfigurations()) {
if (!this.getInstanceByName(config.name)) {

View File

@@ -129,6 +129,8 @@ const models: TsoaRoute.Models = {
"name": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
"modelIdOrPath": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
"discordToken": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
"tokenCount": {"dataType":"integer","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1000}}},
"temperature": {"dataType":"double","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1}}},
},
"additionalProperties": false,
},

View File

@@ -4,10 +4,11 @@ import {Prop} from 'vue-property-decorator';
import {AiInstanceVmV1} from 'ai-oas';
import Discord from '@/components/dashboard/Discord.vue';
import Chat from '@/components/dashboard/Chat.vue';
import Settings from '@/components/dashboard/Settings.vue';
@Options({
name: 'AiInstanceComponent',
components: {Chat, Discord},
components: {Chat, Discord, Settings},
})
export default class AiInstanceComponent extends Vue {
@Prop({required: true})
@@ -23,7 +24,8 @@ export default class AiInstanceComponent extends Vue {
<Chat :ai-instance="aiInstance"/>
</div>
<div class="flex-grow-1" style="width: 33.33%">
<Discord :ai-instance="aiInstance"/>
<Discord :ai-instance="aiInstance" class="mb-2"/>
<Settings :ai-instance="aiInstance"/>
</div>
</div>
</div>

View File

@@ -0,0 +1,98 @@
<template>
  <div class="card" style="min-width: 15em">
    <div class="card-header">
      <h5 class="card-title mb-0">Settings</h5>
    </div>
    <div class="card-body">
      <div class="form-group">
        <label class="form-label" :for="uid + '_temperature'">Temperature ({{ temperature }})</label>
        <!-- .number modifier is required: a plain v-model on <input type="range">
             delivers string values, which would corrupt the number-typed config
             fields and fail the API's integer/number validation. -->
        <input class="form-range" type="range" :id="uid + '_temperature'" v-model.number="temperature" min="0" max="1"
               step="0.001">
      </div>
      <div class="form-group">
        <label class="form-label" :for="uid + '_tokenCount'">Token count ({{ tokenCount }})</label>
        <input class="form-range" type="range" :id="uid + '_tokenCount'" v-model.number="tokenCount" min="0" max="500">
      </div>
    </div>
  </div>
</template>

<script lang="ts">
import {Options, Vue} from 'vue-class-component';
import {Prop} from 'vue-property-decorator';
import {getCurrentInstance} from 'vue';
import {toast} from 'vue3-toastify';
import {AiInstanceVmV1} from 'ai-oas';
import {ApiStore} from '@/stores/ApiStore';

/**
 * Settings card for a single AI instance: exposes the generation
 * temperature (0-1) and token count (0-500) as range sliders.
 *
 * Slider edits update the bound configuration immediately (optimistic UI)
 * and are debounced for 500 ms before being persisted via the
 * configuration API. If the save fails, the value is rolled back to what
 * it was before the current editing burst started.
 */
@Options({
  name: 'Settings',
  methods: {getCurrentInstance},
  components: {},
})
export default class Settings extends Vue {
  @Prop({required: true})
  readonly aiInstance!: AiInstanceVmV1;

  readonly apiStore = new ApiStore();

  // Debounce handles for pending saves; undefined means no save is queued.
  private temperatureSaveTimeout: number | undefined = undefined;
  private tokenCountSaveTimeout: number | undefined = undefined;
  // Rollback values, captured at the start of each editing burst.
  private oldTemperature: number = 0;
  private oldTokenCount: number = 0;

  get temperature(): number {
    return this.aiInstance.configuration.temperature;
  }

  set temperature(val: number) {
    if (this.temperatureSaveTimeout !== undefined) {
      // Still inside an editing burst: push the pending save further out.
      clearTimeout(this.temperatureSaveTimeout);
    } else {
      // First edit of a new burst: remember the current value for rollback.
      this.oldTemperature = this.temperature;
    }
    this.aiInstance.configuration.temperature = val;
    this.temperatureSaveTimeout = window.setTimeout(() => {
      // Reset the handle BEFORE saving so the next edit starts a fresh burst
      // and re-captures the rollback value. (Bug fix: the handle was never
      // cleared after firing, so the rollback value was captured only once,
      // ever — a later failed save could revert to an arbitrarily old value.)
      this.temperatureSaveTimeout = undefined;
      this.apiStore.aiConfigurationApi
        .update(this.aiInstance.configuration)
        .catch(e => {
          toast.error('Error while setting temperature: ' + JSON.stringify(e));
          this.aiInstance.configuration.temperature = this.oldTemperature;
        });
    }, 500);
  }

  get tokenCount(): number {
    return this.aiInstance.configuration.tokenCount;
  }

  set tokenCount(val: number) {
    if (this.tokenCountSaveTimeout !== undefined) {
      clearTimeout(this.tokenCountSaveTimeout);
    } else {
      this.oldTokenCount = this.tokenCount;
    }
    this.aiInstance.configuration.tokenCount = val;
    this.tokenCountSaveTimeout = window.setTimeout(() => {
      this.tokenCountSaveTimeout = undefined; // see temperature setter for why
      this.apiStore.aiConfigurationApi
        .update(this.aiInstance.configuration)
        .catch(e => {
          toast.error('Error while setting tokenCount: ' + JSON.stringify(e));
          this.aiInstance.configuration.tokenCount = this.oldTokenCount;
        });
    }, 500);
  }

  // Unique id prefix so multiple Settings cards don't share <label for> targets.
  get uid(): string {
    return '' + getCurrentInstance()?.uid!;
  }
}
</script>

View File

@@ -232,6 +232,18 @@
"discordToken": {
"type": "string",
"maxLength": 255
},
"tokenCount": {
"type": "integer",
"format": "int32",
"minimum": 0,
"maximum": 1000
},
"temperature": {
"type": "number",
"format": "double",
"minimum": 0,
"maximum": 1
}
},
"required": [
@@ -242,7 +254,9 @@
"updatedBy",
"name",
"modelIdOrPath",
"discordToken"
"discordToken",
"tokenCount",
"temperature"
],
"type": "object",
"additionalProperties": false