Skip to content

Commit

Permalink
feat: allow rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
ogzhanolguncu committed May 3, 2024
1 parent 375e3ae commit 8532294
Show file tree
Hide file tree
Showing 19 changed files with 140 additions and 50 deletions.
Binary file modified bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from "./src/rag-chat";
export * from "./src/services/history";
export * from "./src/error";
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@upstash/rag-chat",
"version": "0.0.11-alpha",
"version": "0.0.14-alpha",
"main": "./dist/index.js",
"module": "./dist/index.mjs",
"types": "./dist/index.d.ts",
Expand Down Expand Up @@ -51,7 +51,7 @@
"@langchain/community": "^0.0.50",
"@langchain/core": "^0.1.58",
"@langchain/openai": "^0.0.28",
"@upstash/sdk": "0.0.25-alpha",
"@upstash/sdk": "0.0.26-alpha",
"ai": "^3.0.35"
}
}
3 changes: 2 additions & 1 deletion src/clients/redis/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { Upstash } from "@upstash/sdk";
import { describe, expect, test } from "bun:test";
import { DEFAULT_REDIS_CONFIG, DEFAULT_REDIS_DB_NAME, RedisClient } from ".";
import { DEFAULT_REDIS_CONFIG, RedisClient } from ".";
import { DEFAULT_REDIS_DB_NAME } from "../../constants";

const upstashSDK = new Upstash({
email: process.env.UPSTASH_EMAIL!,
Expand Down
3 changes: 1 addition & 2 deletions src/clients/redis/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ import type { CreateCommandPayload, Upstash } from "@upstash/sdk";

import { Redis } from "@upstash/sdk";
import type { PreferredRegions } from "../../types";

export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis";
import { DEFAULT_REDIS_DB_NAME } from "../../constants";

export const DEFAULT_REDIS_CONFIG: CreateCommandPayload = {
name: DEFAULT_REDIS_DB_NAME,
Expand Down
3 changes: 2 additions & 1 deletion src/clients/vector/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { Upstash } from "@upstash/sdk";
import { describe, expect, test } from "bun:test";
import { DEFAULT_VECTOR_DB_NAME, VectorClient, DEFAULT_VECTOR_CONFIG } from ".";
import { VectorClient, DEFAULT_VECTOR_CONFIG } from ".";
import { DEFAULT_VECTOR_DB_NAME } from "../../constants";

const upstashSDK = new Upstash({
email: process.env.UPSTASH_EMAIL!,
Expand Down
3 changes: 1 addition & 2 deletions src/clients/vector/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ import type { CreateIndexPayload, Upstash } from "@upstash/sdk";
import { Index } from "@upstash/sdk";

import type { PreferredRegions } from "../../types";

export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector";
import { DEFAULT_VECTOR_DB_NAME } from "../../constants";

export const DEFAULT_VECTOR_CONFIG: CreateIndexPayload = {
name: DEFAULT_VECTOR_DB_NAME,
Expand Down
3 changes: 2 additions & 1 deletion src/config.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import { PromptTemplate } from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import { Index, Redis } from "@upstash/sdk";
import { expect, test } from "bun:test";
import { Config, DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./config";
import { Config } from "./config";
import { DEFAULT_VECTOR_DB_NAME, DEFAULT_REDIS_DB_NAME } from "./constants";

const mockRedis = new Redis({
token: "hey",
Expand Down
24 changes: 7 additions & 17 deletions src/config.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,20 @@
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import type { PromptTemplate } from "@langchain/core/prompts";
import type { Ratelimit } from "@upstash/sdk";
import { Redis } from "@upstash/sdk";
import { Index } from "@upstash/sdk";
import type { PreferredRegions } from "./types";

type RAGChatConfigCommon = {
model?: BaseLanguageModelInterface;
template?: PromptTemplate;
region?: PreferredRegions;
};

const PREFERRED_REGION: PreferredRegions = "us-east-1";
export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector";
export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis";

export type RAGChatConfig = {
vector?: string | Index;
redis?: string | Redis;
} & RAGChatConfigCommon;
import type { PreferredRegions, RAGChatConfig } from "./types";
import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME, PREFERRED_REGION } from "./constants";

export class Config {
public readonly token: string;
public readonly email: string;

public readonly region: PreferredRegions;
public readonly vector?: string | Index;
public readonly redis?: string | Redis;
public readonly ratelimit?: Ratelimit;

public readonly region: PreferredRegions;
public readonly model?: BaseLanguageModelInterface;
public readonly template?: PromptTemplate;

Expand All @@ -45,6 +33,8 @@ export class Config {
? config.redis
: DEFAULT_REDIS_DB_NAME;

this.ratelimit = config?.ratelimit;

this.model = config?.model;
this.template = config?.template;
}
Expand Down
10 changes: 10 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import type { PreferredRegions } from "./types";

/** Fallback chat session id, applied when `ChatOptions.sessionId` is omitted. */
export const DEFAULT_CHAT_SESSION_ID = "upstash-rag-chat-session";
/** Fallback rate-limit key, applied when `ChatOptions.ratelimitSessionId` is omitted. */
export const DEFAULT_CHAT_RATELIMIT_SESSION_ID = "upstash-rag-chat-ratelimit-session";

/** Error code carried in the `error` field of a rate-limit error's `cause`. */
export const RATELIMIT_ERROR_MESSAGE = "ERR:USER_RATELIMITED";

/** Default `name` used by the vector index configuration. */
export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector";
/** Default `name` used by the Redis database configuration. */
export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis";
/** Default region. NOTE(review): presumably applied when no region is configured — confirm in Config. */
export const PREFERRED_REGION: PreferredRegions = "us-east-1";
1 change: 1 addition & 0 deletions src/error/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export * from "./ratelimit";
2 changes: 1 addition & 1 deletion src/error/internal.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
 * Error subclass whose `name` is reported as `"InternalError"`.
 */
export class InternalUpstashError extends Error {
  /**
   * @param message - Human-readable description of the failure.
   */
  constructor(message: string) {
    super(message);
    // Single assignment: a preceding `name` assignment was immediately
    // overwritten by this one and had no observable effect.
    this.name = "InternalError";
  }
}
2 changes: 1 addition & 1 deletion src/error/model.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
 * Error subclass whose `name` is reported as `"ModelError"`.
 *
 * Thrown by `RAGChat.initialize` when no model is supplied.
 */
export class UpstashModelError extends Error {
  /**
   * @param message - Human-readable description of the failure.
   */
  constructor(message: string) {
    super(message);
    // Single assignment: a preceding `name` assignment was immediately
    // overwritten by this one and had no observable effect.
    this.name = "ModelError";
  }
}
14 changes: 14 additions & 0 deletions src/error/ratelimit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import type { RATELIMIT_ERROR_MESSAGE } from "../constants";

/**
 * Structured details attached to a `RatelimitUpstashError` as its `cause`.
 */
type RatelimitResponse = {
  /** Machine-readable marker; typed as the `RATELIMIT_ERROR_MESSAGE` literal. */
  error: typeof RATELIMIT_ERROR_MESSAGE;
  /** Reset value reported by the rate limiter, if any (units are not visible here). */
  resetTime?: number;
};

/**
 * Error raised when a request is rejected because of rate limiting.
 *
 * Instances report `name === "RatelimitError"` and expose the structured
 * details on `cause`.
 */
export class RatelimitUpstashError extends Error {
  /**
   * @param message - Human-readable description of the rejection.
   * @param cause - Structured rate-limit details stored on `this.cause`.
   */
  constructor(message: string, cause: RatelimitResponse) {
    super(message);
    this.cause = cause;
    this.name = "RatelimitError";
  }
}
54 changes: 33 additions & 21 deletions src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { BaseMessage } from "@langchain/core/messages";
import { RunnableSequence, RunnableWithMessageHistory } from "@langchain/core/runnables";
import { LangChainStream, StreamingTextResponse } from "ai";

import { formatChatHistory, sanitizeQuestion } from "./utils";
import { appendDefaultsIfNeeded, formatChatHistory, sanitizeQuestion } from "./utils";

import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import type { PromptTemplate } from "@langchain/core/prompts";
Expand All @@ -13,35 +13,29 @@ import { HistoryService } from "./services/history";
import { RetrievalService } from "./services/retrieval";
import { QA_TEMPLATE } from "./prompts";
import { UpstashModelError } from "./error/model";
import { RateLimitService } from "./services/ratelimit";
import type { ChatOptions, PrepareChatResult, RAGChatConfig } from "./types";
import { RatelimitUpstashError } from "./error/ratelimit";

type CustomInputValues = { chat_history?: BaseMessage[]; question: string; context: string };

type ChatOptions = {
stream: boolean;
sessionId: string;
includeHistory?: number;
similarityThreshold?: number;
};

type PrepareChatResult = {
question: string;
facts: string;
};

export class RAGChat {
private retrievalService: RetrievalService;
private historyService: HistoryService;
private ratelimitService: RateLimitService;

private model: BaseLanguageModelInterface;
private template: PromptTemplate;

constructor(
retrievalService: RetrievalService,
historyService: HistoryService,
ratelimitService: RateLimitService,
config: { model: BaseLanguageModelInterface; template: PromptTemplate }
) {
this.retrievalService = retrievalService;
this.historyService = historyService;
this.ratelimitService = ratelimitService;

this.model = config.model;
this.template = config.template;
Expand All @@ -56,26 +50,41 @@ export class RAGChat {
return { question, facts };
}

async chat(input: string, options: ChatOptions) {
async chat(
input: string,
options: ChatOptions
): Promise<StreamingTextResponse | Record<string, unknown>> {
const options_ = appendDefaultsIfNeeded(options);
const { success, resetTime } = await this.ratelimitService.checkLimit(
options_.ratelimitSessionId
);

if (!success) {
throw new RatelimitUpstashError("Couldn't process chat due to ratelimit.", {
error: "ERR:USER_RATELIMITED",
resetTime: resetTime,
});
}

const { question, facts } = await this.prepareChat(input, options.similarityThreshold);

return options.stream
? this.streamingChainCall(question, facts, options)
: this.chainCall(options, question, facts);
? this.streamingChainCall(options_, question, facts)
: this.chainCall(options_, question, facts);
}

private streamingChainCall = (
chatOptions: ChatOptions,
question: string,
facts: string,
chatOptions: ChatOptions
facts: string
): StreamingTextResponse => {
const { stream, handlers } = LangChainStream();
void this.chainCall(chatOptions, question, facts, [handlers]);
return new StreamingTextResponse(stream, {});
};

private chainCall(
chatOptions: { sessionId: string; includeHistory?: number },
chatOptions: ChatOptions,
question: string,
facts: string,
handlers?: Callbacks
Expand Down Expand Up @@ -113,7 +122,9 @@ export class RAGChat {
);
}

static async initialize(config: Config): Promise<RAGChat> {
static async initialize(
config: RAGChatConfig & { email: string; token: string }
): Promise<RAGChat> {
const clientFactory = new ClientFactory(
new Config(config.email, config.token, {
redis: config.redis,
Expand All @@ -125,12 +136,13 @@ export class RAGChat {

const historyService = new HistoryService(redis);
const retrievalService = new RetrievalService(index);
const ratelimitService = new RateLimitService(config.ratelimit);

if (!config.model) {
throw new UpstashModelError("Model can not be undefined!");
}

return new RAGChat(retrievalService, historyService, {
return new RAGChat(retrievalService, historyService, ratelimitService, {
model: config.model,
template: config.template ?? QA_TEMPLATE,
});
Expand Down
2 changes: 1 addition & 1 deletion src/services/history.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type { Redis } from "@upstash/sdk";
import { CustomUpstashRedisChatMessageHistory } from "./redis-custom-history";
import type { RAGChatConfig } from "../config";
import { Config } from "../config";
import { ClientFactory } from "../client-factory";
import type { RAGChatConfig } from "../types";

const DAY_IN_SECONDS = 86_400;
const TOP_6 = 5;
Expand Down
22 changes: 22 additions & 0 deletions src/services/ratelimit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import type { Ratelimit } from "@upstash/sdk";

/**
 * Thin wrapper around an optional `Ratelimit` instance.
 *
 * When no limiter is configured, every check succeeds.
 */
export class RateLimitService {
  /**
   * @param ratelimit - Limiter to consult; omit to disable rate limiting.
   */
  constructor(private readonly ratelimit?: Ratelimit) {}

  /**
   * Checks whether the given session may proceed.
   *
   * @param sessionId - Identifier passed to the limiter's `limit` call.
   * @returns `{ success: true }` when allowed (or when no limiter is set);
   * otherwise `{ success: false, resetTime }` where `resetTime` is the
   * limiter's reported `reset` value.
   */
  async checkLimit(sessionId: string): Promise<{ success: boolean; resetTime?: number }> {
    const limiter = this.ratelimit;

    // No limiter configured: the operation is always allowed.
    if (!limiter) {
      return { success: true };
    }

    const { success, reset } = await limiter.limit(sessionId);
    return success ? { success: true } : { success: false, resetTime: reset };
  }
}
29 changes: 29 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1 +1,30 @@
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import type { PromptTemplate } from "@langchain/core/prompts";
import type { Index, Ratelimit, Redis } from "@upstash/sdk";

/** Region identifiers accepted by the configuration's `region` option. */
export type PreferredRegions = "eu-west-1" | "us-east-1";

/** Per-call options accepted by `RAGChat.chat`. */
export type ChatOptions = {
  /** When true, `chat` takes the streaming path; otherwise it performs a plain chain call. */
  stream: boolean;
  /** Chat session id; `appendDefaultsIfNeeded` substitutes `DEFAULT_CHAT_SESSION_ID` when omitted. */
  sessionId?: string;
  /** NOTE(review): presumably the number of past messages to include — usage is not visible here, confirm. */
  includeHistory?: number;
  /** Forwarded to `prepareChat`; NOTE(review): presumably a retrieval score cutoff — confirm. */
  similarityThreshold?: number;
  /** Key passed to the rate limiter; defaults to `DEFAULT_CHAT_RATELIMIT_SESSION_ID` when omitted. */
  ratelimitSessionId?: string;
};

/** Shape of the `{ question, facts }` pair that `chat` forwards to the chain call. */
export type PrepareChatResult = {
  /** Question text handed to the chain call. */
  question: string;
  /** Facts string handed to the chain call. NOTE(review): presumably retrieved context — confirm. */
  facts: string;
};

/** Options shared by every RAG chat configuration. */
type RAGChatConfigCommon = {
  /** Language model; `RAGChat.initialize` throws `UpstashModelError` when this is undefined. */
  model?: BaseLanguageModelInterface;
  /** Prompt template; `RAGChat.initialize` falls back to `QA_TEMPLATE` when omitted. */
  template?: PromptTemplate;
  /** Region selection. NOTE(review): how it is applied is not visible here — confirm. */
  region?: PreferredRegions;
  /** Optional rate limiter; when omitted, `RateLimitService.checkLimit` always succeeds. */
  ratelimit?: Ratelimit;
};

/** Full RAG chat configuration: storage targets plus the common options. */
export type RAGChatConfig = {
  /** Either a string or an existing `Index` instance. NOTE(review): the string is presumably an index name — confirm. */
  vector?: string | Index;
  /** Either a string or an existing `Redis` instance. NOTE(review): the string is presumably a database name — confirm. */
  redis?: string | Redis;
} & RAGChatConfigCommon;
10 changes: 10 additions & 0 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import type { BaseMessage } from "@langchain/core/messages";
import type { ChatOptions } from "./types";
import { DEFAULT_CHAT_SESSION_ID, DEFAULT_CHAT_RATELIMIT_SESSION_ID } from "./constants";

export const sanitizeQuestion = (question: string) => {
return question.trim().replaceAll("\n", " ");
Expand All @@ -17,3 +19,11 @@ export const formatChatHistory = (chatHistory: BaseMessage[]) => {

return formatFacts(formattedDialogueTurns);
};

/**
 * Returns a copy of `options` in which `sessionId` and `ratelimitSessionId`
 * are always set.
 *
 * Caller-supplied values are kept; a missing (`null`/`undefined`) value is
 * replaced by `DEFAULT_CHAT_SESSION_ID` or `DEFAULT_CHAT_RATELIMIT_SESSION_ID`
 * respectively. The input object is not mutated.
 *
 * @param options - Chat options as supplied by the caller.
 * @returns A new options object with both identifiers filled in.
 */
export function appendDefaultsIfNeeded(options: ChatOptions) {
  const sessionId = options.sessionId ?? DEFAULT_CHAT_SESSION_ID;
  const ratelimitSessionId = options.ratelimitSessionId ?? DEFAULT_CHAT_RATELIMIT_SESSION_ID;

  return { ...options, sessionId, ratelimitSessionId } satisfies ChatOptions;
}

0 comments on commit 8532294

Please sign in to comment.