Allow clearing cache and fix gptcache (langchain-ai#3493)
This PR:

* Adds a `clear` method to `BaseCache` and implements it for the various caches (see the usage sketch below)
* Adds the default `init_func=None` and fixes the gptcache integration test
* Since the integration tests are not currently run in CI, I've verified the changes by running
  `docs/modules/models/llms/examples/llm_caching.ipynb` (until a proper e2e integration test runs in CI)
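
A minimal usage sketch of the new `clear` API (not part of the commit itself; it assumes this branch is installed and uses `InMemoryCache` for brevity — the SQLAlchemy, Redis, and GPTCache backends expose the same method):

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.schema import Generation

langchain.llm_cache = InMemoryCache()

# Populate and read back a cached generation.
langchain.llm_cache.update("foo", "llm_string", [Generation(text="fizz")])
assert langchain.llm_cache.lookup("foo", "llm_string") == [Generation(text="fizz")]

# New in this PR: wipe the cache. Extra kwargs are forwarded to backends
# that support them, e.g. `asynchronous=True` for the Redis cache.
langchain.llm_cache.clear()
assert langchain.llm_cache.lookup("foo", "llm_string") is None
```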
ehsanmok authored Apr 27, 2023
1 parent 83e871f commit 4a246e2
Showing 5 changed files with 96 additions and 59 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -144,4 +144,8 @@ wandb/
/.ruff_cache/

*.pkl
*.bin
*.bin

# integration test artifacts
data_map*
\[('_type', 'fake'), ('stop', None)]
4 changes: 3 additions & 1 deletion docs/modules/models/llms/examples/llm_caching.ipynb
@@ -785,7 +785,9 @@
"id": "9df0dab8",
"metadata": {},
"outputs": [],
"source": []
"source": [
"!rm .langchain.db sqlite.db"
]
}
],
"metadata": {
79 changes: 61 additions & 18 deletions langchain/cache.py
@@ -1,7 +1,7 @@
"""Beta Feature: base interface for cache."""
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
@@ -28,6 +28,10 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""

@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""


class InMemoryCache(BaseCache):
"""Cache that stores things in memory."""
@@ -44,6 +48,10 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val

def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}


Base = declarative_base()

@@ -61,7 +69,7 @@ class FullLLMCache(Base): # type: ignore
class SQLAlchemyCache(BaseCache):
"""Cache that uses SQAlchemy as a backend."""

def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
"""Initialize by creating all tables."""
self.engine = engine
self.cache_schema = cache_schema
@@ -76,20 +84,26 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
.order_by(self.cache_schema.idx)
)
with Session(self.engine) as session:
generations = [Generation(text=row[0]) for row in session.execute(stmt)]
if len(generations) > 0:
return generations
rows = session.execute(stmt).fetchall()
if rows:
return [Generation(text=row[0]) for row in rows]
return None

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Look up based on prompt and llm_string."""
for i, generation in enumerate(return_val):
item = self.cache_schema(
prompt=prompt, llm=llm_string, response=generation.text, idx=i
)
with Session(self.engine) as session, session.begin():
"""Update based on prompt and llm_string."""
items = [
self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
for i, gen in enumerate(return_val)
]
with Session(self.engine) as session, session.begin():
for item in items:
session.merge(item)

def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
with Session(self.engine) as session:
session.execute(self.cache_schema.delete())


class SQLiteCache(SQLAlchemyCache):
"""Cache that uses SQLite as a backend."""
@@ -139,19 +153,26 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N
for i, generation in enumerate(return_val):
self.redis.set(self._key(prompt, llm_string, i), generation.text)

def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
asynchronous = kwargs.pop("asynchronous", False)
self.redis.flushdb(asynchronous=asynchronous, **kwargs)


class GPTCache(BaseCache):
"""Cache that uses GPTCache as a backend."""

def __init__(self, init_func: Callable[[Any], None]):
"""Initialize by passing in the `init` GPTCache func
def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
"""Initialize by passing in init function (default: `None`).
Args:
init_func (Callable[[Any], None]): init `GPTCache` function
init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
(default: `None`)
Example:
.. code-block:: python
# Initialize GPTCache with a custom init function
import gptcache
from gptcache.processor.pre import get_prompt
from gptcache.manager.factory import get_data_manager
@@ -180,7 +201,8 @@ def init_gptcache_map(cache_obj: gptcache.Cache):
"Could not import gptcache python package. "
"Please install it with `pip install gptcache`."
)
self.init_gptcache_func: Callable[[Any], None] = init_func

self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
self.gptcache_dict: Dict[str, Any] = {}

@staticmethod
@@ -205,11 +227,19 @@ def _get_gptcache(self, llm_string: str) -> Any:
When the corresponding llm model cache does not exist, it will be created."""
from gptcache import Cache
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt

_gptcache = self.gptcache_dict.get(llm_string, None)
if _gptcache is None:
_gptcache = Cache()
self.init_gptcache_func(_gptcache)
if self.init_gptcache_func is not None:
self.init_gptcache_func(_gptcache)
else:
_gptcache.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=llm_string),
)
self.gptcache_dict[llm_string] = _gptcache
return _gptcache

@@ -220,7 +250,7 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""
from gptcache.adapter.adapter import adapt

_gptcache = self.gptcache_dict.get(llm_string)
_gptcache = self.gptcache_dict.get(llm_string, None)
if _gptcache is None:
return None
res = adapt(
@@ -234,7 +264,10 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:

@staticmethod
def _update_cache_callback(
llm_data: RETURN_VAL_TYPE, update_cache_func: Callable[[Any], None]
llm_data: RETURN_VAL_TYPE,
update_cache_func: Callable[[Any], None],
*args: Any,
**kwargs: Any,
) -> None:
"""Save the `llm_data` to cache storage"""
handled_data = json.dumps([generation.dict() for generation in llm_data])
Expand All @@ -260,3 +293,13 @@ def llm_handle(*_: Any, **__: Any) -> RETURN_VAL_TYPE:
cache_skip=True,
prompt=prompt,
)

def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
from gptcache import Cache

for gptcache_instance in self.gptcache_dict.values():
gptcache_instance = cast(Cache, gptcache_instance)
gptcache_instance.flush()

self.gptcache_dict.clear()
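
For context, a hedged sketch (not part of the diff) of what the new `init_func=None` default enables, based on the `_get_gptcache` change above; it assumes `gptcache` is installed:

```python
import langchain
from langchain.cache import GPTCache

# With the new default, no init function is required: each llm_string gets its
# own gptcache.Cache initialized with `get_prompt` and a map data manager whose
# data_path is the llm_string.
langchain.llm_cache = GPTCache()

# The custom init-function path still works as before, e.g.:
# langchain.llm_cache = GPTCache(init_gptcache_map)

# clear() flushes every per-LLM gptcache instance and empties the internal dict.
langchain.llm_cache.clear()
```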
1 change: 1 addition & 0 deletions langchain/memory/entity.py
@@ -235,4 +235,5 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.entity_cache.clear()
self.entity_store.clear()
65 changes: 26 additions & 39 deletions tests/integration_tests/cache/test_gptcache.py
@@ -1,61 +1,48 @@
import os
from typing import Any, Callable, Optional

import pytest

import langchain
from langchain.cache import GPTCache
from langchain.schema import Generation, LLMResult
from langchain.schema import Generation
from tests.unit_tests.llms.fake_llm import FakeLLM

try:
import gptcache # noqa: F401
from gptcache import Cache # noqa: F401
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt

gptcache_installed = True
except ImportError:
gptcache_installed = False


@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
def test_gptcache_map_caching() -> None:
"""Test gptcache caching behavior."""

from gptcache import Cache
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt

i = 0
file_prefix = "data_map"

def init_gptcache_map(cache_obj: Cache) -> None:
nonlocal i
cache_path = f"{file_prefix}_{i}.txt"
if os.path.isfile(cache_path):
os.remove(cache_path)
cache_obj.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
)
i += 1
def init_gptcache_map(cache_obj: Cache) -> None:
i = getattr(init_gptcache_map, "_i", 0)
cache_path = f"data_map_{i}.txt"
if os.path.isfile(cache_path):
os.remove(cache_path)
cache_obj.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
)
init_gptcache_map._i = i + 1 # type: ignore

langchain.llm_cache = GPTCache(init_gptcache_map)

@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
"""Test gptcache default caching behavior."""
langchain.llm_cache = GPTCache(init_func)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo", "bar", "foo"])
expected_cache_output = [Generation(text="foo")]
cache_output = langchain.llm_cache.lookup("bar", llm_string)
assert cache_output == expected_cache_output
langchain.llm_cache = None
expected_generations = [
[Generation(text="fizz")],
[Generation(text="foo")],
[Generation(text="fizz")],
]
expected_output = LLMResult(
generations=expected_generations,
llm_output=None,
)
assert output == expected_output
_ = llm.generate(["foo", "bar", "foo"])
cache_output = langchain.llm_cache.lookup("foo", llm_string)
assert cache_output == [Generation(text="fizz")]

langchain.llm_cache.clear()
assert langchain.llm_cache.lookup("bar", llm_string) is None
