Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ai #1117

Open
wants to merge 22 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions apps/nestjs-backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@
"webpack": "5.91.0"
},
"dependencies": {
"@ai-sdk/anthropic": "1.0.6",
"@ai-sdk/azure": "1.0.13",
"@ai-sdk/cohere": "1.0.6",
"@ai-sdk/google": "1.0.12",
"@ai-sdk/mistral": "1.0.6",
"@ai-sdk/openai": "1.0.11",
"@aws-sdk/client-s3": "3.609.0",
"@aws-sdk/s3-request-presigner": "3.609.0",
"@keyv/redis": "2.8.4",
Expand Down Expand Up @@ -144,6 +150,7 @@
"@teable/openapi": "workspace:^",
"@teamwork/websocket-json-stream": "2.0.0",
"@types/papaparse": "5.3.14",
"ai": "3.4.33",
"ajv": "8.12.0",
"axios": "1.7.7",
"bcrypt": "5.1.1",
Expand Down
2 changes: 2 additions & 0 deletions apps/nestjs-backend/src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { ICacheConfig } from './configs/cache.config';
import { ConfigModule } from './configs/config.module';
import { AccessTokenModule } from './features/access-token/access-token.module';
import { AggregationOpenApiModule } from './features/aggregation/open-api/aggregation-open-api.module';
import { AiModule } from './features/ai/ai.module';
import { AttachmentsModule } from './features/attachments/attachments.module';
import { AuthModule } from './features/auth/auth.module';
import { BaseModule } from './features/base/base.module';
Expand Down Expand Up @@ -68,6 +69,7 @@ export const appModules = {
DashboardModule,
CommentOpenApiModule,
OrganizationModule,
AiModule,
],
providers: [InitBootstrapProvider],
};
Expand Down
20 changes: 20 additions & 0 deletions apps/nestjs-backend/src/features/ai/ai.controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { Body, Controller, Post, Res } from '@nestjs/common';
import { aiGenerateRoSchema, IAiGenerateRo } from '@teable/openapi';
import { Response } from 'express';
import { ZodValidationPipe } from '../../zod.validation.pipe';
import { TablePipe } from '../table/open-api/table.pipe';
import { AiService } from './ai.service';

@Controller('api/ai')
export class AiController {
constructor(private readonly aiService: AiService) {}

@Post('/generate-stream')
async generateStream(
@Body(new ZodValidationPipe(aiGenerateRoSchema), TablePipe) aiGenerateRo: IAiGenerateRo,
@Res() res: Response
) {
const result = await this.aiService.generateStream(aiGenerateRo);
result.pipeTextStreamToResponse(res);
}
}
12 changes: 12 additions & 0 deletions apps/nestjs-backend/src/features/ai/ai.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { Module } from '@nestjs/common';
import { SettingModule } from '../setting/setting.module';
import { AiController } from './ai.controller';
import { AiService } from './ai.service';

@Module({
imports: [SettingModule],
controllers: [AiController],
providers: [AiService],
exports: [AiService],
})
export class AiModule {}
108 changes: 108 additions & 0 deletions apps/nestjs-backend/src/features/ai/ai.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { createAnthropic } from '@ai-sdk/anthropic';
import { createAzure } from '@ai-sdk/azure';
import { createCohere } from '@ai-sdk/cohere';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createMistral } from '@ai-sdk/mistral';
import { createOpenAI } from '@ai-sdk/openai';
import { Injectable } from '@nestjs/common';
import type { IAiGenerateRo, LLMProvider } from '@teable/openapi';
import { LLMProviderType, Task } from '@teable/openapi';
import { streamText } from 'ai';
import { SettingService } from '../setting/setting.service';
import { TASK_MODEL_MAP } from './constant';

@Injectable()
export class AiService {
constructor(private readonly settingService: SettingService) {}

readonly modelProviders = {
[LLMProviderType.OPENAI]: createOpenAI,
[LLMProviderType.ANTHROPIC]: createAnthropic,
[LLMProviderType.GOOGLE]: createGoogleGenerativeAI,
[LLMProviderType.AZURE]: createAzure,
[LLMProviderType.COHERE]: createCohere,
[LLMProviderType.MISTRAL]: createMistral,
} as const;

public parseModelKey(modelKey: string) {
const [type, model, provider] = modelKey.split('@');
return { type, model, provider };
}

// modelKey-> type@model@provider
async getModelConfig(modelKey: string, llmProviders: LLMProvider[] = []) {
const { type, model, provider } = this.parseModelKey(modelKey);

const providerConfig = llmProviders.find(
(p) =>
p.name.toLowerCase() === provider.toLowerCase() &&
p.type.toLowerCase() === type.toLowerCase()
);

if (!providerConfig) {
throw new Error('AI provider configuration is not set');
}

const { baseUrl, apiKey } = providerConfig;

return {
type,
model,
baseUrl,
apiKey,
};
}

async getModelInstance(
modelKey: string,
llmProviders: LLMProvider[] = []
): Promise<
ReturnType<ReturnType<(typeof this.modelProviders)[keyof typeof this.modelProviders]>>
> {
const { type, model, baseUrl, apiKey } = await this.getModelConfig(modelKey, llmProviders);

if (!baseUrl || !apiKey) {
throw new Error('AI configuration is not set');
}

const provider = Object.entries(this.modelProviders).find(([key]) =>
type.toLowerCase().includes(key.toLowerCase())
)?.[1];

if (!provider) {
throw new Error(`Unsupported AI provider: ${type}`);
}

return provider({
baseURL: baseUrl,
apiKey,
})(model);
}

async getAIConfig() {
const { aiConfig } = await this.settingService.getSetting();

if (!aiConfig) {
throw new Error('AI configuration is not set');
}

if (!aiConfig.enable) {
throw new Error('AI is not enabled');
}

return aiConfig;
}

async generateStream(aiGenerateRo: IAiGenerateRo) {
const { prompt, task = Task.Coding } = aiGenerateRo;
const config = await this.getAIConfig();
const currentTaskModel = TASK_MODEL_MAP[task];
const modelKey = config[currentTaskModel as keyof typeof config] as string;
const modelInstance = await this.getModelInstance(modelKey, config.llmProviders);

return await streamText({
model: modelInstance,
prompt: prompt,
});
}
}
8 changes: 8 additions & 0 deletions apps/nestjs-backend/src/features/ai/constant.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/* eslint-disable @typescript-eslint/naming-convention */
import { Task } from '@teable/openapi';

export const TASK_MODEL_MAP = {
[Task.Coding]: 'codingModel',
[Task.Embedding]: 'embeddingModel',
[Task.Translation]: 'translationModel',
};
35 changes: 32 additions & 3 deletions apps/nestjs-backend/src/features/setting/setting.controller.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Body, Controller, Get, Patch } from '@nestjs/common';
import type { IPublicSettingVo, ISettingVo } from '@teable/openapi';
import { IUpdateSettingRo, updateSettingRoSchema } from '@teable/openapi';
import type { ISettingVo } from '@teable/openapi';
import { ZodValidationPipe } from '../../zod.validation.pipe';
import { Permissions } from '../auth/decorators/permissions.decorator';
import { Public } from '../auth/decorators/public.decorator';
Expand All @@ -10,18 +10,47 @@ import { SettingService } from './setting.service';
export class SettingController {
constructor(private readonly settingService: SettingService) {}

@Public()
/**
* Get the instance settings, now we have config for AI, there are some sensitive fields, we need check the permission before return.
*/
@Permissions('instance|read')
@Get()
async getSetting(): Promise<ISettingVo> {
return await this.settingService.getSetting();
}

/**
* Public endpoint for getting public settings without authentication
*/
@Public()
@Get('public')
async getPublicSetting(): Promise<IPublicSettingVo> {
const setting = await this.settingService.getSetting();
const { aiConfig, ...rest } = setting;
return {
...rest,
aiConfig: {
enable: aiConfig?.enable ?? false,
llmProviders:
aiConfig?.llmProviders?.map((provider) => ({
type: provider.type,
name: provider.name,
models: provider.models,
})) ?? [],
},
};
}

@Patch()
@Permissions('instance|update')
async updateSetting(
@Body(new ZodValidationPipe(updateSettingRoSchema))
updateSettingRo: IUpdateSettingRo
): Promise<ISettingVo> {
return await this.settingService.updateSetting(updateSettingRo);
const res = await this.settingService.updateSetting(updateSettingRo);
return {
...res,
aiConfig: res.aiConfig ? JSON.parse(res.aiConfig) : null,
};
}
}
10 changes: 9 additions & 1 deletion apps/nestjs-backend/src/features/setting/setting.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ export class SettingService {
disallowSignUp: true,
disallowSpaceCreation: true,
disallowSpaceInvitation: true,
aiConfig: true,
},
})
.then((setting) => ({
...setting,
aiConfig: setting.aiConfig ? JSON.parse(setting.aiConfig as string) : null,
}))
.catch(() => {
throw new NotFoundException('Setting not found');
});
Expand All @@ -25,7 +30,10 @@ export class SettingService {
const setting = await this.getSetting();
return await this.prismaService.setting.update({
where: { instanceId: setting.instanceId },
data: updateSettingRo,
data: {
...updateSettingRo,
aiConfig: updateSettingRo.aiConfig ? JSON.stringify(updateSettingRo.aiConfig) : null,
},
});
}
}
6 changes: 6 additions & 0 deletions apps/nextjs-app/src/backend/api/rest/table.ssr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import type {
IGroupPointsRo,
IGroupPointsVo,
ListSpaceCollaboratorRo,
IPublicSettingVo,
} from '@teable/openapi';
import {
ACCEPT_INVITATION_LINK,
Expand All @@ -27,6 +28,7 @@ import {
GET_DEFAULT_VIEW_ID,
GET_FIELD_LIST,
GET_GROUP_POINTS,
GET_PUBLIC_SETTING,
GET_RECORDS_URL,
GET_RECORD_URL,
GET_SETTING,
Expand Down Expand Up @@ -168,6 +170,10 @@ export class SsrApi {
return this.axios.get<ISettingVo>(GET_SETTING).then(({ data }) => data);
}

async getPublicSetting() {
return this.axios.get<IPublicSettingVo>(GET_PUBLIC_SETTING).then(({ data }) => data);
}

async getUserMe() {
return this.axios.get<IUserMeVo>(USER_ME).then(({ data }) => data);
}
Expand Down
Loading
Loading