diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 95efeda..cee5643 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -10,6 +10,8 @@ on: env: UPSTASH_VECTOR_REST_URL: ${{ secrets.UPSTASH_VECTOR_REST_URL }} UPSTASH_VECTOR_REST_TOKEN: ${{ secrets.UPSTASH_VECTOR_REST_TOKEN }} + NON_EMBEDDING_UPSTASH_VECTOR_REST_TOKEN: ${{ secrets.NON_EMBEDDING_UPSTASH_VECTOR_REST_TOKEN }} + NON_EMBEDDING_UPSTASH_VECTOR_REST_URL: ${{ secrets.NON_EMBEDDING_UPSTASH_VECTOR_REST_URL }} UPSTASH_REDIS_REST_URL: ${{ secrets.UPSTASH_REDIS_REST_URL }} UPSTASH_REDIS_REST_TOKEN: ${{ secrets.UPSTASH_REDIS_REST_TOKEN }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/src/context-service/index.ts b/src/context-service/index.ts index e0a59bd..3d2e673 100644 --- a/src/context-service/index.ts +++ b/src/context-service/index.ts @@ -20,13 +20,13 @@ export class ContextService { * @example * ```typescript * await addDataToVectorDb({ - * dataType: "pdf", + * type: "pdf", * fileSource: "./data/the_wonderful_wizard_of_oz.pdf", * opts: { chunkSize: 500, chunkOverlap: 50 }, * }); * // OR * await addDataToVectorDb({ - * dataType: "text", + * type: "text", * data: "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", * }); * ``` @@ -44,12 +44,11 @@ export class ContextService { return await this.#vectorService.save(args); } - // eslint-disable-next-line @typescript-eslint/require-await async addMany(args: AddContextPayload[] | string[]) { - return args.map((data) => this.add(data)); + return Promise.all(args.map((data) => this.add(data))); } - async deleteEntireContext(options?: ResetOptions ) { + async deleteEntireContext(options?: ResetOptions) { await this.#vectorService.reset( options?.namespace ? { namespace: options.namespace } : undefined ); @@ -62,13 +61,14 @@ export class ContextService { /** This is internal usage only. */ _getContext( optionsWithDefault: ModifiedChatOptions, - input: string, + input: string | number[], debug?: ChatLogger ) { return traceable( async (sessionId: string) => { // Log the input, which will be captured by the outer traceable - await debug?.logSendPrompt(input); + // eslint-disable-next-line @typescript-eslint/no-magic-numbers + await debug?.logSendPrompt(typeof input === "string" ? input : `${input.slice(0, 3)}...`); debug?.startRetrieveContext(); if (optionsWithDefault.disableRAG) return { formattedContext: "", metadata: [] }; diff --git a/src/database.ts b/src/database.ts index 4901735..9af4df8 100644 --- a/src/database.ts +++ b/src/database.ts @@ -59,11 +59,17 @@ export type DatasWithFileSource = export type AddContextPayload = | { type: "text"; data: string; options?: AddContextOptions; id?: string | number } - | { type: "embedding"; data: number[]; options?: AddContextOptions; id?: string | number } + | { + type: "embedding"; + data: number[]; + text?: string; + options?: AddContextOptions; + id?: string | number; + } | DatasWithFileSource; export type VectorPayload = { - question: string; + question: string | number[]; similarityThreshold?: number; topK?: number; namespace?: string; @@ -104,7 +110,7 @@ export class Database { const index = this.index; const result = await index.query>( { - data: question, + ...(typeof question === "string" ? { data: question } : { vector: question }), topK, includeData: true, includeMetadata: true, @@ -162,6 +168,7 @@ export class Database { const vectorId = await this.index.upsert( { vector: input.data, + data: input.text, id: input.id ?? nanoid(), metadata: input.options?.metadata, }, diff --git a/src/rag-chat.test.ts b/src/rag-chat.test.ts index f6cec6e..f7485c1 100644 --- a/src/rag-chat.test.ts +++ b/src/rag-chat.test.ts @@ -20,6 +20,7 @@ import { RatelimitUpstashError } from "./error"; import { custom, upstash, openai as upstashOpenai } from "./models"; import { RAGChat } from "./rag-chat"; import { awaitUntilIndexed } from "./test-utils"; +import type { PrepareChatResult } from "./types"; async function checkStream( stream: ReadableStream, @@ -933,3 +934,83 @@ describe("RAG Chat with disableHistory option", () => { expect(getMessagesSpy).toHaveBeenCalled(); }); }); + +describe("RAG Chat with non-embedding db", () => { + const namespace = "non-embedding"; + const vector = new Index({ + token: process.env.NON_EMBEDDING_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.NON_EMBEDDING_UPSTASH_VECTOR_REST_URL!, + }); + + const redis = new Redis({ + token: process.env.UPSTASH_REDIS_REST_TOKEN!, + url: process.env.UPSTASH_REDIS_REST_URL!, + }); + + const ragChat = new RAGChat({ + model: new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: false, + verbose: false, + temperature: 0, + apiKey: process.env.OPENAI_API_KEY, + configuration: { + organization: process.env.OPENAI_ORGANIZATION, + }, + }), + vector, + redis, + namespace, + }); + + beforeAll(async () => { + await vector.reset({ namespace }); + // eslint-disable-next-line @typescript-eslint/no-magic-numbers + await new Promise((r) => setTimeout(r, 1000)); + }); + + test("should upsert embedding and query it", async () => { + await ragChat.context.addMany([ + { + id: 1, + type: "embedding", + data: [1, 1, 0], + text: "first embedding", + options: { namespace }, + }, + { + id: 2, + type: "embedding", + data: [1, 0, 1], + text: "second embedding", + options: { namespace }, + }, + ]); + + await awaitUntilIndexed(vector); + + let called = false; + const onContextFetched = (context: PrepareChatResult) => { + // eslint-disable-next-line @typescript-eslint/no-magic-numbers + expect(context.length).toBe(2); + + expect(context[0].data).toBe("second embedding"); + expect(context[0].id).toBe("2"); + + expect(context[1].data).toBe("first embedding"); + expect(context[1].id).toBe("1"); + + called = true; + return context; + }; + + await ragChat.chat("hello world!", { + // eslint-disable-next-line @typescript-eslint/no-magic-numbers + embedding: [0, 0, 0.5], + onContextFetched, + namespace, + }); + + expect(called).toBeTrue(); + }); +}); diff --git a/src/rag-chat.ts b/src/rag-chat.ts index ab31674..f50721f 100644 --- a/src/rag-chat.ts +++ b/src/rag-chat.ts @@ -114,7 +114,7 @@ export class RAGChat { rawContext, } = await this.context._getContext( optionsWithDefault, - input, + options?.embedding ?? input, this.debug )(optionsWithDefault.sessionId); diff --git a/src/types.ts b/src/types.ts index af57b33..4b1bbf5 100644 --- a/src/types.ts +++ b/src/types.ts @@ -78,6 +78,13 @@ export type ChatOptions = { * @default false */ disableHistory?: boolean; + + /** + * Embedding to use when fetching context. + * + * Must be provided if the Vector Database doesn't have default embeddings. + */ + embedding?: number[]; } & CommonChatAndRAGOptions; export type PrepareChatResult = { data: string; id: string; metadata: unknown }[];