added temperature and token count to ai config
This commit is contained in:
@@ -30,7 +30,7 @@ class AiBase:
|
||||
if self.model is None:
|
||||
del self
|
||||
|
||||
def generate(self, messages):
|
||||
def generate(self, messages, token_count=100, temperature=0.1):
|
||||
return []
|
||||
|
||||
def record_conversation(self, messages, response):
|
||||
|
||||
@@ -13,7 +13,7 @@ class QwenAi(AiBase):
|
||||
def __init__(self, model_id_or_path=default_model_id):
|
||||
super().__init__(model_id_or_path)
|
||||
|
||||
def generate(self, messages):
|
||||
def generate(self, messages, token_count=100, **kwargs):
|
||||
try:
|
||||
# prepare
|
||||
messages = [m for m in messages if m['role'] != 'system']
|
||||
@@ -22,7 +22,7 @@ class QwenAi(AiBase):
|
||||
# generate
|
||||
text = self.tokenizer.apply_chat_template(input_messages, tokenize=False, add_generation_prompt=True)
|
||||
model_inputs = self.tokenizer([text], return_tensors='pt').to(self.default_device)
|
||||
generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=300)
|
||||
generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=token_count)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
@@ -35,7 +35,9 @@ class Server:
|
||||
bot = self._server.get_bot(json_body['modelId'])
|
||||
print(bot)
|
||||
bot.ensure_loaded()
|
||||
response = bot.generate(json_body['messages'])
|
||||
response = bot.generate(json_body['messages'],
|
||||
token_count=json_body['tokenCount'],
|
||||
temperature=json_body['temperature'])
|
||||
elif json_body['command'] == 'shutdown':
|
||||
bot = self._server.get_bot(json_body['modelId'])
|
||||
bot.ensure_unloaded()
|
||||
|
||||
@@ -21,6 +21,8 @@ import {isAuthenticatedMiddleware} from '../../../middleware/auth';
|
||||
import {UUID} from '../../../models/api/uuid';
|
||||
import ApiBaseModelCreatedUpdated from '../../../models/api/ApiBaseModelCreatedUpdated';
|
||||
import {UUID as nodeUUID} from 'node:crypto';
|
||||
import AiService from '../../../services/AiService';
|
||||
import {Inject} from 'typescript-ioc';
|
||||
|
||||
export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
|
||||
/**
|
||||
@@ -35,6 +37,17 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
|
||||
* @maxLength 255
|
||||
*/
|
||||
discordToken: string;
|
||||
/**
|
||||
* @isInt
|
||||
* @minimum 0
|
||||
* @maximum 1000
|
||||
*/
|
||||
tokenCount: number;
|
||||
/**
|
||||
* @minimum 0
|
||||
* @maximum 1
|
||||
*/
|
||||
temperature: number,
|
||||
}
|
||||
|
||||
@Route('api/v1/ai/configurations')
|
||||
@@ -42,6 +55,12 @@ export interface AiConfigurationVmV1 extends ApiBaseModelCreatedUpdated {
|
||||
@Tags('aiConfiguration')
|
||||
export class AiConfigurationController extends Controller {
|
||||
private repo = new AiConfigurationRepository();
|
||||
private readonly service: AiService;
|
||||
|
||||
constructor(@Inject service: AiService) {
|
||||
super();
|
||||
this.service = service;
|
||||
}
|
||||
|
||||
@Get()
|
||||
@SuccessResponse(200, 'Ok')
|
||||
@@ -53,7 +72,7 @@ export class AiConfigurationController extends Controller {
|
||||
@SuccessResponse(200, 'Ok')
|
||||
@Response(404, 'Not Found')
|
||||
async getById(@Path() id: UUID, @Request() req: Req): Promise<AiConfigurationVmV1> {
|
||||
const aiConfiguration = await this.repo.getById( id as nodeUUID);
|
||||
const aiConfiguration = await this.repo.getById(id as nodeUUID);
|
||||
if (aiConfiguration !== undefined) {
|
||||
return aiConfiguration;
|
||||
} else {
|
||||
@@ -83,7 +102,7 @@ export class AiConfigurationController extends Controller {
|
||||
const aiConfiguration = AiConfiguration.fromJson(body);
|
||||
|
||||
//get from db
|
||||
const dbaiConfiguration = await this.repo.getById(aiConfiguration.id, req.session.user!.id);
|
||||
const dbaiConfiguration = await this.repo.getById(aiConfiguration.id);
|
||||
if (!dbaiConfiguration) {
|
||||
this.setStatus(404);
|
||||
return undefined as any;
|
||||
@@ -93,11 +112,14 @@ export class AiConfigurationController extends Controller {
|
||||
dbaiConfiguration.name = aiConfiguration.name;
|
||||
dbaiConfiguration.modelIdOrPath = aiConfiguration.modelIdOrPath;
|
||||
dbaiConfiguration.discordToken = aiConfiguration.discordToken;
|
||||
dbaiConfiguration.tokenCount = aiConfiguration.tokenCount;
|
||||
dbaiConfiguration.temperature = aiConfiguration.temperature;
|
||||
dbaiConfiguration.updatedBy = req.session.user!.id;
|
||||
|
||||
//save
|
||||
const updated = await this.repo.update(dbaiConfiguration);
|
||||
if (updated) {
|
||||
this.service.updateConfiguration(dbaiConfiguration);
|
||||
return dbaiConfiguration;
|
||||
} else {
|
||||
this.setStatus(404);
|
||||
@@ -109,7 +131,7 @@ export class AiConfigurationController extends Controller {
|
||||
@SuccessResponse(204, 'Deleted')
|
||||
@Response(404, 'Not Found')
|
||||
async remove(@Path() id: UUID, @Request() req: Req): Promise<void> {
|
||||
const deleted = await this.repo.remove(id as nodeUUID, req.session.user!.id);
|
||||
const deleted = await this.repo.remove(id as nodeUUID);
|
||||
if (deleted) {
|
||||
this.setStatus(204);
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import type {Knex} from 'knex';
|
||||
|
||||
export async function up(knex: Knex): Promise<void> {
|
||||
console.log('Running migration 20240529151651_aiTokenCountAndTemperature');
|
||||
|
||||
await knex.transaction(async trx => {
|
||||
await knex.schema.alterTable('aiConfigurations', table => {
|
||||
table.integer('tokenCount').notNullable().defaultTo(100).after('discordToken');
|
||||
table.double('temperature', 6).notNullable().defaultTo(0.1).after('tokenCount');
|
||||
}).transacting(trx);
|
||||
});
|
||||
}
|
||||
|
||||
export async function down(knex: Knex): Promise<void> {
|
||||
await knex.transaction(async trx => {
|
||||
await knex.schema.alterTable('aiConfigurations', table => {
|
||||
table.dropColumn('tokenCount');
|
||||
table.dropColumn('temperature');
|
||||
}).transacting(trx);
|
||||
});
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import type {Knex} from 'knex';
|
||||
import * as M20230114134301_user from './20230114134301_user';
|
||||
import * as M20230610151046_userEmailAndDisplayName from './20230610151046_userEmailAndDisplayName';
|
||||
import * as M20240511125408_aiConfigurations from './20240511125408_aiConfigurations';
|
||||
import * as M20240529151651_aiTokenCountAndTemperature from './20240529151651_aiTokenCountAndTemperature';
|
||||
|
||||
export type Migration = {
|
||||
name: string,
|
||||
@@ -12,5 +13,6 @@ export const Migrations: Migration[] = [
|
||||
{name: '20230114134301_user', migration: M20230114134301_user},
|
||||
{name: '20230610151046_userEmailAndDisplayName', migration: M20230610151046_userEmailAndDisplayName},
|
||||
{name: '20240511125408_aiConfigurations', migration: M20240511125408_aiConfigurations},
|
||||
{name: '20240529151651_aiTokenCountAndTemperature', migration: M20240529151651_aiTokenCountAndTemperature},
|
||||
];
|
||||
//TODO use glob import
|
||||
@@ -29,7 +29,7 @@ export default class AiInstance {
|
||||
await this._messagesSemaphore.acquire();
|
||||
this.messages.push({'role': 'user', 'name': user, 'content': text});
|
||||
getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
|
||||
this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages);
|
||||
this.messages = await this.aiPythonConnector.chat(this.configuration.modelIdOrPath, this.messages, this.configuration.tokenCount, this.configuration.temperature);
|
||||
getEmitter().emit('chatText', {aiInstance: this, ...this.messages[this.messages.length - 1]!});
|
||||
return this.messages[this.messages.length - 1]!;
|
||||
} finally {
|
||||
|
||||
@@ -7,13 +7,17 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
|
||||
name!: string; // max length 255
|
||||
modelIdOrPath!: string; // max length 255
|
||||
discordToken!: string; // max length 255
|
||||
tokenCount!: number; // int
|
||||
temperature!: number; // double 0-1
|
||||
|
||||
static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, createdBy: UUID): AiConfiguration {
|
||||
static new(id: UUID, name: string, modelIdOrPath: string, discordToken: string, tokenCount: number, temperature: number, createdBy: UUID): AiConfiguration {
|
||||
const ret = new AiConfiguration();
|
||||
ret.id = id;
|
||||
ret.name = name;
|
||||
ret.modelIdOrPath = modelIdOrPath;
|
||||
ret.discordToken = discordToken;
|
||||
ret.tokenCount = tokenCount;
|
||||
ret.temperature = temperature;
|
||||
ret.createdBy = createdBy;
|
||||
ret.updatedBy = createdBy;
|
||||
return ret;
|
||||
@@ -26,12 +30,14 @@ export default class AiConfiguration extends BaseModelCreatedUpdated {
|
||||
static override get jsonSchemaWithReferences(): JSONSchema {
|
||||
return mergeDeep({}, super.jsonSchemaWithReferences, {
|
||||
$id: 'AiConfiguration',
|
||||
required: ['date', 'name', 'lab', 'comment'],
|
||||
required: ['name', 'modelIdOrPath', 'discordToken', 'tokenCount', 'temperature'],
|
||||
|
||||
properties: {
|
||||
name: {type: 'string', maxLength: 255},
|
||||
modelIdOrPath: {type: 'string', maxLength: 255},
|
||||
discordToken: {type: 'string', maxLength: 255},
|
||||
tokenCount: {type: 'integer', minimum: 0, maximum: 1000},
|
||||
temperature: {type: 'number', minimum: 0, maximum: 1},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -15,14 +15,20 @@ export default class AiPythonConnector {
|
||||
return !!this.process;
|
||||
}
|
||||
|
||||
async chat(modelId: string, conversation: ChatMessage[]): Promise<ChatMessage[]> {
|
||||
async chat(modelId: string, conversation: ChatMessage[], tokenCount: number = 300, temperature: number = 0.1): Promise<ChatMessage[]> {
|
||||
await this.ensureStarted();
|
||||
|
||||
const port = await this.port();
|
||||
const response = await fetch('http://localhost:' + port, {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({command: 'chat', modelId: modelId, messages: conversation}),
|
||||
body: JSON.stringify({
|
||||
command: 'chat',
|
||||
modelId: modelId,
|
||||
messages: conversation,
|
||||
tokenCount: tokenCount,
|
||||
temperature: temperature,
|
||||
}),
|
||||
});
|
||||
if (response.status === 200) {
|
||||
return await response.json();
|
||||
|
||||
@@ -32,6 +32,13 @@ export default class AiService {
|
||||
return this.instances.find(e => e.configuration.id === id);
|
||||
}
|
||||
|
||||
updateConfiguration(configuration: AiConfiguration) {
|
||||
const aiInstance = this.getInstanceById(configuration.id);
|
||||
if (aiInstance) {
|
||||
aiInstance.configuration = configuration;
|
||||
}
|
||||
}
|
||||
|
||||
private async ensureInstancesInited(): Promise<void> {
|
||||
for (let config of await this.getConfigurations()) {
|
||||
if (!this.getInstanceByName(config.name)) {
|
||||
|
||||
@@ -129,6 +129,8 @@ const models: TsoaRoute.Models = {
|
||||
"name": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
|
||||
"modelIdOrPath": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
|
||||
"discordToken": {"dataType":"string","required":true,"validators":{"maxLength":{"value":255}}},
|
||||
"tokenCount": {"dataType":"integer","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1000}}},
|
||||
"temperature": {"dataType":"double","required":true,"validators":{"minimum":{"value":0},"maximum":{"value":1}}},
|
||||
},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
|
||||
@@ -4,10 +4,11 @@ import {Prop} from 'vue-property-decorator';
|
||||
import {AiInstanceVmV1} from 'ai-oas';
|
||||
import Discord from '@/components/dashboard/Discord.vue';
|
||||
import Chat from '@/components/dashboard/Chat.vue';
|
||||
import Settings from '@/components/dashboard/Settings.vue';
|
||||
|
||||
@Options({
|
||||
name: 'AiInstanceComponent',
|
||||
components: {Chat, Discord},
|
||||
components: {Chat, Discord, Settings},
|
||||
})
|
||||
export default class AiInstanceComponent extends Vue {
|
||||
@Prop({required: true})
|
||||
@@ -23,7 +24,8 @@ export default class AiInstanceComponent extends Vue {
|
||||
<Chat :ai-instance="aiInstance"/>
|
||||
</div>
|
||||
<div class="flex-grow-1" style="width: 33.33%">
|
||||
<Discord :ai-instance="aiInstance"/>
|
||||
<Discord :ai-instance="aiInstance" class="mb-2"/>
|
||||
<Settings :ai-instance="aiInstance"/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
98
frontend/src/components/dashboard/Settings.vue
Normal file
98
frontend/src/components/dashboard/Settings.vue
Normal file
@@ -0,0 +1,98 @@
|
||||
<template>
|
||||
<div class="card" style="min-width: 15em">
|
||||
<div class="card-header">
|
||||
<h5 class="card-title mb-0">Settings</h5>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="form-group">
|
||||
<label class="form-label" :for="uid + '_temperature'">Temperature ({{ temperature }})</label>
|
||||
<input class="form-range" type="range" :id="uid + '_temperature'" v-model="temperature" min="0" max="1"
|
||||
step="0.001">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label class="form-label" :for="uid + '_tokenCount'">Token count ({{ tokenCount }})</label>
|
||||
<input class="form-range" type="range" :id="uid + '_tokenCount'" v-model="tokenCount" min="0" max="500">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script lang="ts">
|
||||
import {Options, Vue} from 'vue-class-component';
|
||||
import {Prop} from 'vue-property-decorator';
|
||||
import {getCurrentInstance} from 'vue';
|
||||
import {toast} from 'vue3-toastify';
|
||||
import {AiInstanceVmV1} from 'ai-oas';
|
||||
import {ApiStore} from '@/stores/ApiStore';
|
||||
|
||||
@Options({
|
||||
name: 'Settings',
|
||||
methods: {getCurrentInstance},
|
||||
components: {},
|
||||
})
|
||||
export default class Settings extends Vue {
|
||||
@Prop({required: true})
|
||||
readonly aiInstance!: AiInstanceVmV1;
|
||||
readonly apiStore = new ApiStore();
|
||||
|
||||
private temperatureSaveTimeout: number | undefined = undefined;
|
||||
private tokenCountSaveTimeout: number | undefined = undefined;
|
||||
|
||||
private oldTemperature: number = 0;
|
||||
private oldTokenCount: number = 0;
|
||||
|
||||
get temperature(): number {
|
||||
return this.aiInstance.configuration.temperature;
|
||||
}
|
||||
|
||||
set temperature(val: number) {
|
||||
if (this.temperatureSaveTimeout) {
|
||||
clearTimeout(this.temperatureSaveTimeout);
|
||||
this.temperatureSaveTimeout = undefined;
|
||||
} else {
|
||||
this.oldTemperature = this.temperature;
|
||||
}
|
||||
this.aiInstance.configuration.temperature = val;
|
||||
|
||||
this.temperatureSaveTimeout = setTimeout(() => {
|
||||
this.aiInstance.configuration.temperature = val;
|
||||
this.apiStore.aiConfigurationApi
|
||||
.update(this.aiInstance.configuration)
|
||||
.then(_ => this.aiInstance.configuration.temperature = val)
|
||||
.catch(e => {
|
||||
toast.error('Error while setting temperature: ' + JSON.stringify(e));
|
||||
this.aiInstance.configuration.temperature = this.oldTemperature;
|
||||
});
|
||||
}, 500);
|
||||
}
|
||||
|
||||
get tokenCount(): number {
|
||||
return this.aiInstance.configuration.tokenCount;
|
||||
}
|
||||
|
||||
set tokenCount(val: number) {
|
||||
if (this.tokenCountSaveTimeout) {
|
||||
clearTimeout(this.tokenCountSaveTimeout);
|
||||
this.tokenCountSaveTimeout = undefined;
|
||||
} else {
|
||||
this.oldTokenCount = this.tokenCount;
|
||||
}
|
||||
this.aiInstance.configuration.tokenCount = val;
|
||||
|
||||
this.tokenCountSaveTimeout = setTimeout(() => {
|
||||
this.aiInstance.configuration.tokenCount = val;
|
||||
this.apiStore.aiConfigurationApi
|
||||
.update(this.aiInstance.configuration)
|
||||
.then(_ => this.aiInstance.configuration.tokenCount = val)
|
||||
.catch(e => {
|
||||
toast.error('Error while setting tokenCount: ' + JSON.stringify(e));
|
||||
this.aiInstance.configuration.tokenCount = this.oldTokenCount;
|
||||
});
|
||||
}, 500);
|
||||
}
|
||||
|
||||
get uid(): string {
|
||||
return '' + getCurrentInstance()?.uid!;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
16
swagger.json
16
swagger.json
@@ -232,6 +232,18 @@
|
||||
"discordToken": {
|
||||
"type": "string",
|
||||
"maxLength": 255
|
||||
},
|
||||
"tokenCount": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"minimum": 0,
|
||||
"maximum": 1000
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"format": "double",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
@@ -242,7 +254,9 @@
|
||||
"updatedBy",
|
||||
"name",
|
||||
"modelIdOrPath",
|
||||
"discordToken"
|
||||
"discordToken",
|
||||
"tokenCount",
|
||||
"temperature"
|
||||
],
|
||||
"type": "object",
|
||||
"additionalProperties": false
|
||||
|
||||
Reference in New Issue
Block a user