Skip to content

Commit

Permalink
feat: allow custom seed (#6709)
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed Apr 26, 2024
1 parent 5d114ea commit b639e52
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 79 deletions.
12 changes: 12 additions & 0 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ export class CopilotController {
return controller.signal;
}

private parseNumber(value: string | string[] | undefined) {
if (!value) {
return undefined;
}
const num = Number.parseInt(Array.isArray(value) ? value[0] : value, 10);
if (Number.isNaN(num)) {
return undefined;
}
return num;
}

private handleError(err: any) {
if (err instanceof Error) {
const ret = {
Expand Down Expand Up @@ -256,6 +267,7 @@ export class CopilotController {

return from(
provider.generateImagesStream(session.finish(params), session.model, {
seed: this.parseNumber(params.seed),
signal: this.getSignal(req),
user: user.id,
})
Expand Down
13 changes: 4 additions & 9 deletions packages/backend/server/src/plugins/copilot/providers/fal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import assert from 'node:assert';

import {
CopilotCapability,
CopilotImageOptions,
CopilotImageToImageProvider,
CopilotProviderType,
CopilotTextToImageProvider,
Expand Down Expand Up @@ -57,10 +58,7 @@ export class FalProvider
async generateImages(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotImageOptions = {}
): Promise<Array<string>> {
const { content, attachments } = messages.pop() || {};
if (!this.availableModels.includes(model)) {
Expand All @@ -82,7 +80,7 @@ export class FalProvider
image_url: attachments?.[0],
prompt: content,
sync_mode: true,
seed: 42,
seed: options.seed || 42,
enable_safety_checks: false,
}),
signal: options.signal,
Expand All @@ -100,10 +98,7 @@ export class FalProvider
async *generateImagesStream(
messages: PromptMessage[],
model: string = this.availableModels[0],
options: {
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotImageOptions = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
Expand Down
33 changes: 8 additions & 25 deletions packages/backend/server/src/plugins/copilot/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import { ClientOptions, OpenAI } from 'openai';
import {
ChatMessageRole,
CopilotCapability,
CopilotChatOptions,
CopilotEmbeddingOptions,
CopilotImageOptions,
CopilotImageToTextProvider,
CopilotProviderType,
CopilotTextToEmbeddingProvider,
Expand Down Expand Up @@ -147,12 +150,7 @@ export class OpenAIProvider
async generateText(
messages: PromptMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotChatOptions = {}
): Promise<string> {
this.checkParams({ messages, model });
const result = await this.instance.chat.completions.create(
Expand All @@ -175,12 +173,7 @@ export class OpenAIProvider
async *generateTextStream(
messages: PromptMessage[],
model: string = 'gpt-3.5-turbo',
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotChatOptions = {}
): AsyncIterable<string> {
this.checkParams({ messages, model });
const result = await this.instance.chat.completions.create(
Expand Down Expand Up @@ -214,11 +207,7 @@ export class OpenAIProvider
async generateEmbedding(
messages: string | string[],
model: string,
options: {
dimensions: number;
signal?: AbortSignal;
user?: string;
} = { dimensions: DEFAULT_DIMENSIONS }
options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS }
): Promise<number[][]> {
messages = Array.isArray(messages) ? messages : [messages];
this.checkParams({ embeddings: messages, model });
Expand All @@ -236,10 +225,7 @@ export class OpenAIProvider
async generateImages(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotImageOptions = {}
): Promise<Array<string>> {
const { content: prompt } = messages.pop() || {};
if (!prompt) {
Expand All @@ -261,10 +247,7 @@ export class OpenAIProvider
async *generateImagesStream(
messages: PromptMessage[],
model: string = 'dall-e-3',
options: {
signal?: AbortSignal;
user?: string;
} = {}
options: CopilotImageOptions = {}
): AsyncIterable<string> {
const ret = await this.generateImages(messages, model, options);
for (const url of ret) {
Expand Down
80 changes: 35 additions & 45 deletions packages/backend/server/src/plugins/copilot/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,32 @@ export enum CopilotCapability {
ImageToText = 'image-to-text',
}

const CopilotProviderOptionsSchema = z.object({
signal: z.instanceof(AbortSignal).optional(),
user: z.string().optional(),
});

const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.extend({
temperature: z.number().optional(),
maxTokens: z.number().optional(),
}).optional();

export type CopilotChatOptions = z.infer<typeof CopilotChatOptionsSchema>;

const CopilotEmbeddingOptionsSchema = CopilotProviderOptionsSchema.extend({
dimensions: z.number(),
}).optional();

export type CopilotEmbeddingOptions = z.infer<
typeof CopilotEmbeddingOptionsSchema
>;

const CopilotImageOptionsSchema = CopilotProviderOptionsSchema.extend({
seed: z.number().optional(),
}).optional();

export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;

export interface CopilotProvider {
readonly type: CopilotProviderType;
getCapabilities(): CopilotCapability[];
Expand All @@ -153,95 +179,59 @@ export interface CopilotTextToTextProvider extends CopilotProvider {
generateText(
messages: PromptMessage[],
model?: string,
options?: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
options?: CopilotChatOptions
): Promise<string>;
generateTextStream(
messages: PromptMessage[],
model?: string,
options?: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
options?: CopilotChatOptions
): AsyncIterable<string>;
}

export interface CopilotTextToEmbeddingProvider extends CopilotProvider {
generateEmbedding(
messages: string[] | string,
model: string,
options: {
dimensions: number;
signal?: AbortSignal;
user?: string;
}
options?: CopilotEmbeddingOptions
): Promise<number[][]>;
}

export interface CopilotTextToImageProvider extends CopilotProvider {
generateImages(
messages: PromptMessage[],
model: string,
options: {
signal?: AbortSignal;
user?: string;
}
options?: CopilotImageOptions
): Promise<Array<string>>;
generateImagesStream(
messages: PromptMessage[],
model?: string,
options?: {
signal?: AbortSignal;
user?: string;
}
options?: CopilotImageOptions
): AsyncIterable<string>;
}

export interface CopilotImageToTextProvider extends CopilotProvider {
generateText(
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
options?: CopilotChatOptions
): Promise<string>;
generateTextStream(
messages: PromptMessage[],
model: string,
options: {
temperature?: number;
maxTokens?: number;
signal?: AbortSignal;
user?: string;
}
options?: CopilotChatOptions
): AsyncIterable<string>;
}

export interface CopilotImageToImageProvider extends CopilotProvider {
generateImages(
messages: PromptMessage[],
model: string,
options: {
signal?: AbortSignal;
user?: string;
}
options?: CopilotImageOptions
): Promise<Array<string>>;
generateImagesStream(
messages: PromptMessage[],
model?: string,
options?: {
signal?: AbortSignal;
user?: string;
}
options?: CopilotImageOptions
): AsyncIterable<string>;
}

Expand Down

0 comments on commit b639e52

Please sign in to comment.