Skip to content

Commit

Permalink
Add context filter to VectorPayload (#94)
Browse files Browse the repository at this point in the history
* feat(database): add context filter to VectorPayload

* fix: add context length check to context filtering test

* fix: pass contextFilter to retrieve

---------

Co-authored-by: Ronaldo Lima <[email protected]>
  • Loading branch information
CahidArda and ronal2do authored Nov 19, 2024
1 parent af2e30c commit 0a78037
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/context-service/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export class ContextService {
similarityThreshold: optionsWithDefault.similarityThreshold,
topK: optionsWithDefault.topK,
namespace: optionsWithDefault.namespace,
contextFilter: optionsWithDefault.contextFilter,
});

// Log the result, which will be captured by the outer traceable
Expand Down
3 changes: 3 additions & 0 deletions src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export type VectorPayload = {
similarityThreshold?: number;
topK?: number;
namespace?: string;
contextFilter?: string;
};

export type ResetOptions = {
Expand Down Expand Up @@ -106,6 +107,7 @@ export class Database {
similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD,
topK = DEFAULT_TOP_K,
namespace,
contextFilter,
}: VectorPayload): Promise<{ data: string; id: string; metadata: TMetadata }[]> {
const index = this.index;
const result = await index.query<Record<string, string>>(
Expand All @@ -114,6 +116,7 @@ export class Database {
topK,
includeData: true,
includeMetadata: true,
...(typeof contextFilter === "string" && { filter: contextFilter }),
},
{ namespace }
);
Expand Down
54 changes: 54 additions & 0 deletions src/rag-chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1014,3 +1014,57 @@ describe("RAG Chat with non-embedding db", () => {
expect(called).toBeTrue();
});
});

/**
 * Integration suite for the `contextFilter` chat option.
 *
 * Two documents carrying different `unit` metadata are added to a dedicated
 * namespace; a chat request filtered on `unit = 'Samurai'` must then surface
 * exactly one context entry and only that entry's metadata.
 *
 * Reads the vector credentials from `UPSTASH_VECTOR_REST_URL` and
 * `UPSTASH_VECTOR_REST_TOKEN`.
 */
describe("RAGChat - context filtering", () => {
  const namespace = "context-filtering";

  // Vector index client shared by the chat instance and the cleanup hook.
  const vector = new Index({
    url: process.env.UPSTASH_VECTOR_REST_URL!,
    token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
  });

  const ragChat = new RAGChat({
    model: upstash("meta-llama/Meta-Llama-3-8B-Instruct"),
    streaming: true,
    namespace,
    vector,
  });

  // [document text, value stored under the `unit` metadata key]
  const seedDocuments = [
    ["Tokyo is the Capital of Japan.", "Samurai"],
    ["Shakuhachi is a traditional wind instrument", "Shakuhachi"],
  ] as const;

  // Clear the namespace's vectors, then remove the namespace itself.
  afterAll(async () => {
    await vector.reset({ namespace });
    await vector.deleteNamespace(namespace);
  });

  test(
    "should return metadata",
    async () => {
      // Seed sequentially, in declaration order.
      for (const [data, unit] of seedDocuments) {
        await ragChat.context.add({
          type: "text",
          data,
          options: { namespace, metadata: { unit } },
        });
      }
      await awaitUntilIndexed(vector);

      const { metadata } = await ragChat.chat<{ unit: string }>(
        "Where is the capital of Japan?",
        {
          namespace,
          topK: 5,
          contextFilter: "unit = 'Samurai'",
          // Only one of the two seeded documents matches the filter.
          onContextFetched: (context) => {
            expect(context.length).toBe(1);
            return context;
          },
        }
      );

      expect(metadata).toEqual([{ unit: "Samurai" }]);
    },
    { timeout: 30_000 }
  );
});
1 change: 1 addition & 0 deletions src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ export class RAGChat {
promptFn: isRagDisabledAndPromptFunctionMissing
? DEFAULT_PROMPT_WITHOUT_RAG
: (options?.promptFn ?? this.config.prompt),
contextFilter: options?.contextFilter ?? undefined,
};
}
}
7 changes: 7 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ export type ChatOptions = {
* Must be provided if the Vector Database doesn't have default embeddings.
*/
embedding?: number[];

/**
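* Metadata filter expression applied when retrieving context from the vector database,
* so that only entries whose metadata matches the expression are returned.
* @example "population >= 1000000 AND geography.continent = 'Asia'"
* @see https://upstash.com/docs/vector/features/filtering#metadata-filtering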
*/
contextFilter?: string;
} & CommonChatAndRAGOptions;

export type PrepareChatResult = { data: string; id: string; metadata: unknown }[];
Expand Down

0 comments on commit 0a78037

Please sign in to comment.