diff --git a/src/index.ts b/src/index.ts index 8fb9838..f5b4cab 100644 --- a/src/index.ts +++ b/src/index.ts @@ -20,6 +20,7 @@ import { formatChatHistory, formatFacts, sanitizeQuestion } from "./utils"; const SIMILARITY_THRESHOLD = 0.5; type CustomInputValues = { chat_history?: BaseMessage[]; question: string; context: string }; +type ChatOptions = { stream: boolean; sessionId: string; includeHistory?: number }; type RAGChatConfigCommon = { model: BaseLanguageModelInterface; @@ -42,11 +43,15 @@ export type RAGChatConfig = ( ) & RAGChatConfigCommon; +type PrepareChatResult = { + question: string; + facts: string; +}; + export class RAGChat { private sdkClient: Upstash; private config?: RAGChatConfig; - //CLIENTS private vectorClient?: Index; private redisClient?: Redis; private ratelimiterClient?: Ratelimit; @@ -54,6 +59,10 @@ export class RAGChat { constructor(email: string, token: string, config?: RAGChatConfig) { this.sdkClient = new Upstash({ email, token, ...config?.umbrellaConfig }); this.config = config; + + this.initializeClients().catch((error: unknown) => { + console.error("Failed to initialize clients:", error); + }); } private async getFactsFromVector( @@ -61,7 +70,7 @@ export class RAGChat { similarityThreshold = SIMILARITY_THRESHOLD ): Promise { if (!this.vectorClient) - throw new InternalUpstashError("vectorClient is missing in getFactsFromVector"); + throw new InternalUpstashError("Vector client is missing in getFactsFromVector"); const index = this.vectorClient; const result = await index.query<{ value: string }>({ @@ -85,32 +94,23 @@ export class RAGChat { return formatFacts(facts); } - chat = async ( - input: string, - chatOptions: { stream: boolean; sessionId: string; includeHistory?: number } - ) => { - await this.initializeClients(); - + private async prepareChat(input: string): Promise { const question = sanitizeQuestion(input); const facts = await this.getFactsFromVector(question); + return { question, facts }; + } - const { stream, sessionId, includeHistory } = chatOptions; - - if (stream) { - return this.chainCallStreaming(question, facts, sessionId, includeHistory); - } + async chat(input: string, options: ChatOptions) { + const { question, facts } = await this.prepareChat(input); - return this.chainCall({ sessionId, includeHistory }, question, facts); - }; + return options.stream + ? this.streamingChainCall(question, facts, options) + : this.chainCall(options, question, facts); + } - private chainCallStreaming = ( - question: string, - facts: string, - sessionId: string, - includeHistory?: number - ) => { + private streamingChainCall = (question: string, facts: string, chatOptions: ChatOptions) => { const { stream, handlers } = LangChainStream(); - void this.chainCall({ sessionId, includeHistory }, question, facts, [handlers]); + void this.chainCall(chatOptions, question, facts, [handlers]); return new StreamingTextResponse(stream, {}); };