Skip to content

Commit

Permalink
feat: allow custom metadata key
Browse files Browse the repository at this point in the history
  • Loading branch information
ogzhanolguncu committed May 13, 2024
1 parent dd58b7c commit 26e4cf8
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .husky/pre-commit
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/sh
. "$(dirname "$0")/_/husky.sh"

bun run lint && bun run fmt && bun test
bun run lint && bun run fmt && bun test
Binary file modified bun.lockb
Binary file not shown.
5 changes: 2 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@
"devDependencies": {
"@commitlint/cli": "^19.2.2",
"@commitlint/config-conventional": "^19.2.2",
"@types/node": "^20.12.2",
"@typescript-eslint/eslint-plugin": "^7.0.1",
"@typescript-eslint/parser": "^7.0.1",
"bun-types": "latest",
"eslint": "^8",
"eslint-plugin-unicorn": "^51.0.1",
"bun-types": "latest",
"husky": "^9.0.10",
"eslint": "^8",
"prettier": "^3.2.5",
"tsup": "latest",
"typescript": "^5.4.5",
Expand Down
9 changes: 9 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@ export const RATELIMIT_ERROR_MESSAGE = "ERR:USER_RATELIMITED";
export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector";
export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis";
export const PREFERRED_REGION: PreferredRegions = "us-east-1";

//Retrieval related default options
export const DEFAULT_SIMILARITY_THRESHOLD = 0.5;
export const DEFAULT_TOP_K = 5;
export const DEFAULT_METADATA_KEY = "text";

//History related default options
export const DEFAULT_HISTORY_TTL = 86_400;
export const DEFAULT_HISTORY_LENGTH = 5;
4 changes: 3 additions & 1 deletion src/rag-chat-base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ export class RAGChatBase {
question: input,
similarityThreshold,
topK,
metadataKey,
}: RetrievePayload): Promise<PrepareChatResult> {
const question = sanitizeQuestion(input);
const facts = await this.retrievalService.retrieveFromVectorDb({
question,
similarityThreshold,
metadataKey,
topK,
});
return { question, facts };
Expand Down Expand Up @@ -76,7 +78,7 @@ export class RAGChatBase {
getMessageHistory: (sessionId: string) =>
this.historyService.getMessageHistory({
sessionId,
length: chatOptions.includeHistory,
length: chatOptions.historyLength,
}),
inputMessagesKey: "question",
historyMessagesKey: "chat_history",
Expand Down
59 changes: 40 additions & 19 deletions src/rag-chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ import type { StreamingTextResponse } from "ai";
import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./constants";
import { RatelimitUpstashError } from "./error";
import { PromptTemplate } from "@langchain/core/prompts";
import { sleep } from "bun";

describe("RAG Chat with advance configs and direct instances", async () => {
const vector = new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
});
const ragChat = await RAGChat.initialize({
email: process.env.UPSTASH_EMAIL!,
token: process.env.UPSTASH_TOKEN!,
Expand All @@ -21,10 +26,7 @@ describe("RAG Chat with advance configs and direct instances", async () => {
temperature: 0,
apiKey: process.env.OPENAI_API_KEY,
}),
vector: new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
}),
vector,
redis: new Redis({
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
url: process.env.UPSTASH_REDIS_REST_URL!,
Expand All @@ -33,10 +35,15 @@ describe("RAG Chat with advance configs and direct instances", async () => {

beforeAll(async () => {
await ragChat.addContext(
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall."
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
"text"
);
//eslint-disable-next-line @typescript-eslint/no-magic-numbers
await sleep(3000);
});

afterAll(async () => await vector.reset());

test("should get result without streaming", async () => {
const result = (await ragChat.chat(
"What year was the construction of the Eiffel Tower completed, and what is its height?",
Expand Down Expand Up @@ -104,6 +111,11 @@ describe("RAG Chat with ratelimit", async () => {
token: process.env.UPSTASH_REDIS_REST_TOKEN!,
url: process.env.UPSTASH_REDIS_REST_URL!,
});
const vector = new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
});

const ragChat = await RAGChat.initialize({
email: process.env.UPSTASH_EMAIL!,
token: process.env.UPSTASH_TOKEN!,
Expand All @@ -114,10 +126,7 @@ describe("RAG Chat with ratelimit", async () => {
temperature: 0,
apiKey: process.env.OPENAI_API_KEY,
}),
vector: new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
}),
vector,
redis,
ratelimit: new Ratelimit({
redis,
Expand All @@ -128,20 +137,32 @@ describe("RAG Chat with ratelimit", async () => {

afterAll(async () => {
await redis.flushdb();
await vector.reset();
});

test("should throw ratelimit error", async () => {
await ragChat.chat(
"What year was the construction of the Eiffel Tower completed, and what is its height?",
{ stream: false }
);
test(
"should throw ratelimit error",
async () => {
await ragChat.addContext(
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.",
"text"
);
//eslint-disable-next-line @typescript-eslint/no-magic-numbers
await sleep(3000);

const throwable = async () => {
await ragChat.chat("You shall not pass", { stream: false });
};
await ragChat.chat(
"What year was the construction of the Eiffel Tower completed, and what is its height?",
{ stream: false, metadataKey: "text" }
);

expect(throwable).toThrowError(RatelimitUpstashError);
});
const throwable = async () => {
await ragChat.chat("You shall not pass", { stream: false });
};

expect(throwable).toThrowError(RatelimitUpstashError);
},
{ timeout: 10_000 }
);
});

describe("RAG Chat with instance names", async () => {
Expand Down
16 changes: 10 additions & 6 deletions src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type { StreamingTextResponse } from "ai";

import { HistoryService } from "./services/history";
import { RateLimitService } from "./services/ratelimit";
import type { AddContextPayload } from "./services/retrieval";
import { RetrievalService } from "./services/retrieval";

import { QA_TEMPLATE } from "./prompts";
Expand Down Expand Up @@ -50,8 +51,9 @@ export class RAGChat extends RAGChatBase {
//Sanitizes the given input by stripping all the newline chars then queries vector db with sanitized question.
const { question, facts } = await this.prepareChat({
question: input,
similarityThreshold: options.similarityThreshold,
topK: options.topK,
similarityThreshold: options_.similarityThreshold,
metadataKey: options_.metadataKey,
topK: options_.topK,
});

return options.stream
Expand All @@ -60,10 +62,12 @@ export class RAGChat extends RAGChatBase {
}

/** Context can be either plain text or embeddings */
async addContext(context: string | number[]) {
const retrievalService = await this.retrievalService.addEmbeddingOrTextToVectorDb(context);
if (retrievalService === "Success") return "OK";
return "NOT-OK";
async addContext(context: AddContextPayload[] | string, metadataKey = "text") {
const retrievalServiceStatus = await this.retrievalService.addEmbeddingOrTextToVectorDb(
context,
metadataKey
);
return retrievalServiceStatus === "Success" ? "OK" : "NOT-OK";
}

/**
Expand Down
9 changes: 3 additions & 6 deletions src/services/history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import { Config } from "../config";
import { ClientFactory } from "../client-factory";
import type { RAGChatConfig } from "../types";

const DAY_IN_SECONDS = 86_400;
const TOP_6 = 5;

type GetHistory = { sessionId: string; length?: number };
type GetHistory = { sessionId: string; length?: number; sessionTTL?: number };
type HistoryInit = Omit<RAGChatConfig, "model" | "template" | "vector"> & {
email: string;
token: string;
Expand All @@ -19,10 +16,10 @@ export class HistoryService {
this.redis = redis;
}

getMessageHistory({ length = TOP_6, sessionId }: GetHistory) {
getMessageHistory({ length, sessionId, sessionTTL }: GetHistory) {
return new CustomUpstashRedisChatMessageHistory({
sessionId,
sessionTTL: DAY_IN_SECONDS,
sessionTTL,
topLevelChatHistoryLength: length,
client: this.redis,
});
Expand Down
57 changes: 43 additions & 14 deletions src/services/retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import type { RAGChatConfig } from "../types";
import { ClientFactory } from "../client-factory";
import { Config } from "../config";
import { nanoid } from "nanoid";
import { DEFAULT_METADATA_KEY, DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "../constants";

const SIMILARITY_THRESHOLD = 0.5;
const TOP_K = 5;
export type AddContextPayload = { input: string | number[]; id?: string; metadata?: string };

type RetrievalInit = Omit<RAGChatConfig, "model" | "template" | "vector"> & {
email: string;
Expand All @@ -15,8 +15,9 @@ type RetrievalInit = Omit<RAGChatConfig, "model" | "template" | "vector"> & {

export type RetrievePayload = {
question: string;
similarityThreshold?: number;
topK?: number;
similarityThreshold: number;
metadataKey: string;
topK: number;
};

export class RetrievalService {
Expand All @@ -27,36 +28,64 @@ export class RetrievalService {

async retrieveFromVectorDb({
question,
similarityThreshold = SIMILARITY_THRESHOLD,
topK = TOP_K,
similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD,
metadataKey = DEFAULT_METADATA_KEY,
topK = DEFAULT_TOP_K,
}: RetrievePayload): Promise<string> {
const index = this.index;
const result = await index.query<{ value: string }>({
const result = await index.query<Record<string, string>>({
data: question,
topK,
includeMetadata: true,
includeVectors: false,
});

const allValuesUndefined = result.every((embedding) => embedding.metadata?.value === undefined);
const allValuesUndefined = result.every(
(embedding) => embedding.metadata?.[metadataKey] === undefined
);

if (allValuesUndefined) {
throw new TypeError(`
Query to the vector store returned ${result.length} vectors but none had "value" field in their metadata.
Text of your vectors should be in the "value" field in the metadata for the RAG Chat.
Query to the vector store returned ${result.length} vectors but none had "${metadataKey}" field in their metadata.
Text of your vectors should be in the "${metadataKey}" field in the metadata for the RAG Chat.
`);
}

const facts = result
.filter((x) => x.score >= similarityThreshold)
.map((embedding, index) => `- Context Item ${index}: ${embedding.metadata?.value ?? ""}`);
.map(
(embedding, index) => `- Context Item ${index}: ${embedding.metadata?.[metadataKey] ?? ""}`
);
return formatFacts(facts);
}

async addEmbeddingOrTextToVectorDb(input: string | number[]) {
async addEmbeddingOrTextToVectorDb(
input: AddContextPayload[] | string,
metadataKey = "text"
): Promise<string> {
if (typeof input === "string") {
return this.index.upsert({ data: input, id: nanoid(), metadata: { value: input } });
return this.index.upsert({
data: input,
id: nanoid(),
metadata: { [metadataKey]: input },
});
}
return this.index.upsert({ vector: input, id: nanoid(), metadata: { value: input } });
const items = input.map((context) => {
const isText = typeof context.input === "string";
const metadata = context.metadata
? { [metadataKey]: context.metadata }
: isText
? { [metadataKey]: context.input }
: {};

return {
[isText ? "data" : "vector"]: context.input,
id: context.id ?? nanoid(),
metadata,
};
});

return this.index.upsert(items as Parameters<Index["upsert"]>[number]);
}

public static async init(config: RetrievalInit) {
Expand Down
14 changes: 13 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ export type ChatOptions = {
/** Length of the conversation history to include in your LLM query. Increasing this may lead to hallucinations. Retrieves the last N messages.
* @default 5
*/
includeHistory?: number;
historyLength?: number;

/** Configuration to retain chat history. After the specified time, the history will be automatically cleared.
* @default 86_400 // 1 day in seconds
*/
historyTTL?: number;

/** Configuration to adjust the accuracy of results.
* @default 0.5
Expand All @@ -33,6 +38,13 @@ export type ChatOptions = {
* @default 5
*/
topK?: number;

/** Key of metadata that we use to store additional content .
* @default "text"
* @example {text: "Capital of France is Paris"}
*
*/
metadataKey?: string;
};

export type PrepareChatResult = {
Expand Down
17 changes: 15 additions & 2 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import type { BaseMessage } from "@langchain/core/messages";
import type { ChatOptions } from "./types";
import { DEFAULT_CHAT_SESSION_ID, DEFAULT_CHAT_RATELIMIT_SESSION_ID } from "./constants";
import {
DEFAULT_CHAT_SESSION_ID,
DEFAULT_CHAT_RATELIMIT_SESSION_ID,
DEFAULT_METADATA_KEY,
DEFAULT_SIMILARITY_THRESHOLD,
DEFAULT_TOP_K,
DEFAULT_HISTORY_LENGTH,
DEFAULT_HISTORY_TTL,
} from "./constants";

export const sanitizeQuestion = (question: string) => {
return question.trim().replaceAll("\n", " ");
Expand All @@ -24,8 +32,13 @@ export function appendDefaultsIfNeeded(options: ChatOptions) {
return {
...options,
sessionId: options.sessionId ?? DEFAULT_CHAT_SESSION_ID,
metadataKey: options.metadataKey ?? DEFAULT_METADATA_KEY,
ratelimitSessionId: options.ratelimitSessionId ?? DEFAULT_CHAT_RATELIMIT_SESSION_ID,
} satisfies ChatOptions;
similarityThreshold: options.similarityThreshold ?? DEFAULT_SIMILARITY_THRESHOLD,
topK: options.topK ?? DEFAULT_TOP_K,
historyLength: options.historyLength ?? DEFAULT_HISTORY_LENGTH,
historyTTL: options.historyLength ?? DEFAULT_HISTORY_TTL,
};
}

const DEFAULT_DELAY = 20_000;
Expand Down

0 comments on commit 26e4cf8

Please sign in to comment.