Skip to content

Commit

Permalink
Add Redis cache implementation (langchain-ai#397)
Browse files Browse the repository at this point in the history
I'm using a hash function for the key just to make sure its length
doesn't get out of hand, otherwise the implementation is quite similar.
  • Loading branch information
sjwhitmore authored Dec 22, 2022
1 parent ff03242 commit 6bc8ae6
Show file tree
Hide file tree
Showing 4 changed files with 317 additions and 190 deletions.
106 changes: 96 additions & 10 deletions docs/examples/prompts/llm_caching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 30 ms, sys: 10.8 ms, total: 40.8 ms\n",
"Wall time: 983 ms\n"
"CPU times: user 30.6 ms, sys: 9.95 ms, total: 40.5 ms\n",
"Wall time: 730 ms\n"
]
},
{
Expand Down Expand Up @@ -91,8 +91,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 65 µs, sys: 1 µs, total: 66 µs\n",
"Wall time: 70.1 µs\n"
"CPU times: user 71 µs, sys: 3 µs, total: 74 µs\n",
"Wall time: 78.9 µs\n"
]
},
{
Expand Down Expand Up @@ -142,8 +142,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.76 ms, sys: 2.6 ms, total: 9.36 ms\n",
"Wall time: 7.86 ms\n"
"CPU times: user 5.27 ms, sys: 2.36 ms, total: 7.63 ms\n",
"Wall time: 6.68 ms\n"
]
},
{
Expand All @@ -167,14 +167,16 @@
"cell_type": "code",
"execution_count": 8,
"id": "5bf2f6fd",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.52 ms, sys: 1.47 ms, total: 3.99 ms\n",
"Wall time: 2.98 ms\n"
"CPU times: user 3.05 ms, sys: 1.1 ms, total: 4.16 ms\n",
"Wall time: 5.58 ms\n"
]
},
{
Expand All @@ -194,6 +196,90 @@
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "278ad7ae",
"metadata": {},
"source": [
"### Redis Cache"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "39f6eb0b",
"metadata": {},
"outputs": [],
"source": [
"# We can do the same thing with a Redis cache\n",
"# (make sure your local Redis instance is running first before running this example)\n",
"from redis import Redis\n",
"from langchain.cache import RedisCache\n",
"langchain.llm_cache = RedisCache(redis_=Redis())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "28920749",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.75 ms, sys: 3.14 ms, total: 9.89 ms\n",
"Wall time: 716 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "94bf9415",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.66 ms, sys: 1.92 ms, total: 3.57 ms\n",
"Wall time: 7.56 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "934943dc",
Expand Down Expand Up @@ -459,7 +545,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.4"
}
},
"nbformat": 4,
Expand Down
42 changes: 41 additions & 1 deletion langchain/cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Beta Feature: base interface for cache."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
Expand Down Expand Up @@ -94,3 +94,43 @@ def __init__(self, database_path: str = ".langchain.db"):
"""Initialize by creating the engine and all tables."""
engine = create_engine(f"sqlite:///{database_path}")
super().__init__(engine)


class RedisCache(BaseCache):
"""Cache that uses Redis as a backend."""

def __init__(self, redis_: Any):
"""Initialize by passing in Redis instance."""
try:
from redis import Redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass in Redis object.")
self.redis = redis_

def _key(self, prompt: str, llm_string: str, idx: int) -> str:
"""Compute key from prompt, llm_string, and idx."""
return str(hash(prompt + llm_string)) + "_" + str(idx)

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
idx = 0
generations = []
while self.redis.get(self._key(prompt, llm_string, idx)):
result = self.redis.get(self._key(prompt, llm_string, idx))
if not result:
break
elif isinstance(result, bytes):
result = result.decode()
generations.append(Generation(text=result))
idx += 1
return generations if generations else None

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for i, generation in enumerate(return_val):
self.redis.set(self._key(prompt, llm_string, i), generation.text)
Loading

0 comments on commit 6bc8ae6

Please sign in to comment.