diff --git a/.husky/pre-commit b/.husky/pre-commit index 1e0f6f3..4e52d65 100755 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1,4 +1,4 @@ #!/bin/sh . "$(dirname "$0")/_/husky.sh" -bun run lint && bun run fmt && bun test +bun run lint && bun run fmt && bun test \ No newline at end of file diff --git a/bun.lockb b/bun.lockb index 5fa88c7..38f9dd8 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/package.json b/package.json index 6795117..48fcf93 100644 --- a/package.json +++ b/package.json @@ -35,13 +35,12 @@ "devDependencies": { "@commitlint/cli": "^19.2.2", "@commitlint/config-conventional": "^19.2.2", - "@types/node": "^20.12.2", "@typescript-eslint/eslint-plugin": "^7.0.1", "@typescript-eslint/parser": "^7.0.1", - "bun-types": "latest", - "eslint": "^8", "eslint-plugin-unicorn": "^51.0.1", + "bun-types": "latest", "husky": "^9.0.10", + "eslint": "^8", "prettier": "^3.2.5", "tsup": "latest", "typescript": "^5.4.5", diff --git a/src/constants.ts b/src/constants.ts index cd9c7b9..7b2180e 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -8,3 +8,12 @@ export const RATELIMIT_ERROR_MESSAGE = "ERR:USER_RATELIMITED"; export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector"; export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis"; export const PREFERRED_REGION: PreferredRegions = "us-east-1"; + +//Retrieval related default options +export const DEFAULT_SIMILARITY_THRESHOLD = 0.5; +export const DEFAULT_TOP_K = 5; +export const DEFAULT_METADATA_KEY = "text"; + +//History related default options +export const DEFAULT_HISTORY_TTL = 86_400; +export const DEFAULT_HISTORY_LENGTH = 5; diff --git a/src/rag-chat-base.ts b/src/rag-chat-base.ts index 7870114..0b83e9f 100644 --- a/src/rag-chat-base.ts +++ b/src/rag-chat-base.ts @@ -35,11 +35,13 @@ export class RAGChatBase { question: input, similarityThreshold, topK, + metadataKey, }: RetrievePayload): Promise<PrepareChatResult> { const question = sanitizeQuestion(input); const facts = await this.retrievalService.retrieveFromVectorDb({ question, similarityThreshold, + metadataKey, topK, }); return { question, facts }; @@ -76,7 +78,7 @@ export class RAGChatBase { getMessageHistory: (sessionId: string) => this.historyService.getMessageHistory({ sessionId, - length: chatOptions.includeHistory, + length: chatOptions.historyLength, }), inputMessagesKey: "question", historyMessagesKey: "chat_history", diff --git a/src/rag-chat.test.ts b/src/rag-chat.test.ts index f614530..e23cb3e 100644 --- a/src/rag-chat.test.ts +++ b/src/rag-chat.test.ts @@ -9,8 +9,13 @@ import type { StreamingTextResponse } from "ai"; import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./constants"; import { RatelimitUpstashError } from "./error"; import { PromptTemplate } from "@langchain/core/prompts"; +import { sleep } from "bun"; describe("RAG Chat with advance configs and direct instances", async () => { + const vector = new Index({ + token: process.env.UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.UPSTASH_VECTOR_REST_URL!, + }); const ragChat = await RAGChat.initialize({ email: process.env.UPSTASH_EMAIL!, token: process.env.UPSTASH_TOKEN!, @@ -21,10 +26,7 @@ describe("RAG Chat with advance configs and direct instances", async () => { temperature: 0, apiKey: process.env.OPENAI_API_KEY, }), - vector: new Index({ - token: process.env.UPSTASH_VECTOR_REST_TOKEN!, - url: process.env.UPSTASH_VECTOR_REST_URL!, - }), + vector, redis: new Redis({ token: process.env.UPSTASH_REDIS_REST_TOKEN!, url: process.env.UPSTASH_REDIS_REST_URL!, @@ -33,10 +35,15 @@ describe("RAG Chat with advance configs and direct instances", async () => { beforeAll(async () => { await ragChat.addContext( - "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall." + "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", + "text" ); + //eslint-disable-next-line @typescript-eslint/no-magic-numbers + await sleep(3000); }); + afterAll(async () => await vector.reset()); + test("should get result without streaming", async () => { const result = (await ragChat.chat( "What year was the construction of the Eiffel Tower completed, and what is its height?", @@ -104,6 +111,11 @@ describe("RAG Chat with ratelimit", async () => { token: process.env.UPSTASH_REDIS_REST_TOKEN!, url: process.env.UPSTASH_REDIS_REST_URL!, }); + const vector = new Index({ + token: process.env.UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.UPSTASH_VECTOR_REST_URL!, + }); + const ragChat = await RAGChat.initialize({ email: process.env.UPSTASH_EMAIL!, token: process.env.UPSTASH_TOKEN!, @@ -114,10 +126,7 @@ describe("RAG Chat with ratelimit", async () => { temperature: 0, apiKey: process.env.OPENAI_API_KEY, }), - vector: new Index({ - token: process.env.UPSTASH_VECTOR_REST_TOKEN!, - url: process.env.UPSTASH_VECTOR_REST_URL!, - }), + vector, redis, ratelimit: new Ratelimit({ redis, @@ -128,20 +137,32 @@ describe("RAG Chat with ratelimit", async () => { afterAll(async () => { await redis.flushdb(); + await vector.reset(); }); - test("should throw ratelimit error", async () => { - await ragChat.chat( - "What year was the construction of the Eiffel Tower completed, and what is its height?", - { stream: false } - ); + test( + "should throw ratelimit error", + async () => { + await ragChat.addContext( + "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", + "text" + ); + //eslint-disable-next-line @typescript-eslint/no-magic-numbers + await sleep(3000); - const throwable = async () => { - await ragChat.chat("You shall not pass", { stream: false }); - }; + await ragChat.chat( + "What year was the construction of the Eiffel Tower completed, and what is its height?", + { stream: false, metadataKey: "text" } + ); - expect(throwable).toThrowError(RatelimitUpstashError); - }); + const throwable = async () => { + await ragChat.chat("You shall not pass", { stream: false }); + }; + + expect(throwable).toThrowError(RatelimitUpstashError); + }, + { timeout: 10_000 } + ); }); describe("RAG Chat with instance names", async () => { diff --git a/src/rag-chat.ts b/src/rag-chat.ts index 893fc8a..74a64d9 100644 --- a/src/rag-chat.ts +++ b/src/rag-chat.ts @@ -5,6 +5,7 @@ import type { StreamingTextResponse } from "ai"; import { HistoryService } from "./services/history"; import { RateLimitService } from "./services/ratelimit"; +import type { AddContextPayload } from "./services/retrieval"; import { RetrievalService } from "./services/retrieval"; import { QA_TEMPLATE } from "./prompts"; @@ -50,8 +51,9 @@ export class RAGChat extends RAGChatBase { //Sanitizes the given input by stripping all the newline chars then queries vector db with sanitized question. const { question, facts } = await this.prepareChat({ question: input, - similarityThreshold: options.similarityThreshold, - topK: options.topK, + similarityThreshold: options_.similarityThreshold, + metadataKey: options_.metadataKey, + topK: options_.topK, }); return options.stream @@ -60,10 +62,12 @@ export class RAGChat extends RAGChatBase { } /** Context can be either plain text or embeddings */ - async addContext(context: string | number[]) { - const retrievalService = await this.retrievalService.addEmbeddingOrTextToVectorDb(context); - if (retrievalService === "Success") return "OK"; - return "NOT-OK"; + async addContext(context: AddContextPayload[] | string, metadataKey = "text") { + const retrievalServiceStatus = await this.retrievalService.addEmbeddingOrTextToVectorDb( + context, + metadataKey + ); + return retrievalServiceStatus === "Success" ? "OK" : "NOT-OK"; } /** diff --git a/src/services/history.ts b/src/services/history.ts index d876eab..ace3876 100644 --- a/src/services/history.ts +++ b/src/services/history.ts @@ -4,10 +4,7 @@ import { Config } from "../config"; import { ClientFactory } from "../client-factory"; import type { RAGChatConfig } from "../types"; -const DAY_IN_SECONDS = 86_400; -const TOP_6 = 5; - -type GetHistory = { sessionId: string; length?: number }; +type GetHistory = { sessionId: string; length?: number; sessionTTL?: number }; type HistoryInit = Omit<RAGChatConfig, "model" | "template" | "vector"> & { email: string; token: string; @@ -19,10 +16,10 @@ export class HistoryService { this.redis = redis; } - getMessageHistory({ length = TOP_6, sessionId }: GetHistory) { + getMessageHistory({ length, sessionId, sessionTTL }: GetHistory) { return new CustomUpstashRedisChatMessageHistory({ sessionId, - sessionTTL: DAY_IN_SECONDS, + sessionTTL, topLevelChatHistoryLength: length, client: this.redis, }); diff --git a/src/services/retrieval.ts b/src/services/retrieval.ts index d6b860e..4e03e88 100644 --- a/src/services/retrieval.ts +++ b/src/services/retrieval.ts @@ -4,9 +4,9 @@ import type { RAGChatConfig } from "../types"; import { ClientFactory } from "../client-factory"; import { Config } from "../config"; import { nanoid } from "nanoid"; +import { DEFAULT_METADATA_KEY, DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "../constants"; -const SIMILARITY_THRESHOLD = 0.5; -const TOP_K = 5; +export type AddContextPayload = { input: string | number[]; id?: string; metadata?: string }; type RetrievalInit = Omit<RAGChatConfig, "model" | "template" | "vector"> & { email: string; @@ -15,8 +15,9 @@ type RetrievalInit = Omit<RAGChatConfig, "model" | "template" | "vector"> & { export type RetrievePayload = { question: string; - similarityThreshold?: number; - topK?: number; + similarityThreshold: number; + metadataKey: string; + topK: number; }; export class RetrievalService { @@ -27,36 +28,64 @@ export class RetrievalService { async retrieveFromVectorDb({ question, - similarityThreshold = SIMILARITY_THRESHOLD, - topK = TOP_K, + similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD, + metadataKey = DEFAULT_METADATA_KEY, + topK = DEFAULT_TOP_K, }: RetrievePayload): Promise<string> { const index = this.index; - const result = await index.query<{ value: string }>({ + const result = await index.query<Record<string, string>>({ data: question, topK, includeMetadata: true, includeVectors: false, }); - const allValuesUndefined = result.every((embedding) => embedding.metadata?.value === undefined); + const allValuesUndefined = result.every( + (embedding) => embedding.metadata?.[metadataKey] === undefined + ); + if (allValuesUndefined) { throw new TypeError(` - Query to the vector store returned ${result.length} vectors but none had "value" field in their metadata. - Text of your vectors should be in the "value" field in the metadata for the RAG Chat. + Query to the vector store returned ${result.length} vectors but none had "${metadataKey}" field in their metadata. + Text of your vectors should be in the "${metadataKey}" field in the metadata for the RAG Chat. `); } const facts = result .filter((x) => x.score >= similarityThreshold) - .map((embedding, index) => `- Context Item ${index}: ${embedding.metadata?.value ?? ""}`); + .map( + (embedding, index) => `- Context Item ${index}: ${embedding.metadata?.[metadataKey] ?? ""}` + ); return formatFacts(facts); } - async addEmbeddingOrTextToVectorDb(input: string | number[]) { + async addEmbeddingOrTextToVectorDb( + input: AddContextPayload[] | string, + metadataKey = "text" + ): Promise<string> { if (typeof input === "string") { - return this.index.upsert({ data: input, id: nanoid(), metadata: { value: input } }); + return this.index.upsert({ + data: input, + id: nanoid(), + metadata: { [metadataKey]: input }, + }); } - return this.index.upsert({ vector: input, id: nanoid(), metadata: { value: input } }); + const items = input.map((context) => { + const isText = typeof context.input === "string"; + const metadata = context.metadata + ? { [metadataKey]: context.metadata } + : isText + ? { [metadataKey]: context.input } + : {}; + + return { + [isText ? "data" : "vector"]: context.input, + id: context.id ?? nanoid(), + metadata, + }; + }); + + return this.index.upsert(items as Parameters<Index["upsert"]>[number]); } public static async init(config: RetrievalInit) { diff --git a/src/types.ts b/src/types.ts index 58823b1..c152b5c 100644 --- a/src/types.ts +++ b/src/types.ts @@ -17,7 +17,12 @@ export type ChatOptions = { /** Length of the conversation history to include in your LLM query. Increasing this may lead to hallucinations. Retrieves the last N messages. * @default 5 */ - includeHistory?: number; + historyLength?: number; + + /** Configuration to retain chat history. After the specified time, the history will be automatically cleared. + * @default 86_400 // 1 day in seconds + */ + historyTTL?: number; /** Configuration to adjust the accuracy of results. * @default 0.5 @@ -33,6 +38,13 @@ export type ChatOptions = { * @default 5 */ topK?: number; + + /** Key of metadata that we use to store additional content . + * @default "text" + * @example {text: "Capital of France is Paris"} + * + */ + metadataKey?: string; }; export type PrepareChatResult = { diff --git a/src/utils.ts b/src/utils.ts index dd0e6b3..00ebb33 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,6 +1,14 @@ import type { BaseMessage } from "@langchain/core/messages"; import type { ChatOptions } from "./types"; -import { DEFAULT_CHAT_SESSION_ID, DEFAULT_CHAT_RATELIMIT_SESSION_ID } from "./constants"; +import { + DEFAULT_CHAT_SESSION_ID, + DEFAULT_CHAT_RATELIMIT_SESSION_ID, + DEFAULT_METADATA_KEY, + DEFAULT_SIMILARITY_THRESHOLD, + DEFAULT_TOP_K, + DEFAULT_HISTORY_LENGTH, + DEFAULT_HISTORY_TTL, +} from "./constants"; export const sanitizeQuestion = (question: string) => { return question.trim().replaceAll("\n", " "); @@ -24,8 +32,13 @@ export function appendDefaultsIfNeeded(options: ChatOptions) { return { ...options, sessionId: options.sessionId ?? DEFAULT_CHAT_SESSION_ID, + metadataKey: options.metadataKey ?? DEFAULT_METADATA_KEY, ratelimitSessionId: options.ratelimitSessionId ?? DEFAULT_CHAT_RATELIMIT_SESSION_ID, - } satisfies ChatOptions; + similarityThreshold: options.similarityThreshold ?? DEFAULT_SIMILARITY_THRESHOLD, + topK: options.topK ?? DEFAULT_TOP_K, + historyLength: options.historyLength ?? DEFAULT_HISTORY_LENGTH, + historyTTL: options.historyLength ?? DEFAULT_HISTORY_TTL, + }; } const DEFAULT_DELAY = 20_000;