forked from langchain-ai/langchain
Commit
Allow clearing cache and fix gptcache (langchain-ai#3493)
This PR:
* Adds a `clear` method to `BaseCache` and implements it for the various caches.
* Adds the default `init_func=None` and fixes the gptcache integration test.
* Since the integration tests do not currently run in CI, I've verified the changes by running `docs/modules/models/llms/examples/llm_caching.ipynb` (until a proper e2e integration test runs in CI).
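As a rough illustration of the new cache-clearing surface, here is a minimal usage sketch. It is not code from this commit: only `lookup`/`update`/`clear` on the global `langchain.llm_cache` appear in the diff below, and `InMemoryCache` is assumed to be among the "various caches" that gain `clear`.

```python
import langchain
from langchain.cache import InMemoryCache
from langchain.schema import Generation

# Install a global LLM cache, the same way the test below installs GPTCache.
# InMemoryCache is an assumption here; the diff below only shows GPTCache.
langchain.llm_cache = InMemoryCache()

# Prime the cache for one (prompt, llm_string) pair and read it back.
llm_string = "fake-llm-params"  # normally derived from the LLM's sorted params
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
assert langchain.llm_cache.lookup("foo", llm_string) == [Generation(text="fizz")]

# New in this PR: drop every cached entry.
langchain.llm_cache.clear()
assert langchain.llm_cache.lookup("foo", llm_string) is None
```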
Showing 5 changed files with 96 additions and 59 deletions.
@@ -1,61 +1,48 @@
 import os
+from typing import Any, Callable, Optional
 
 import pytest
 
 import langchain
 from langchain.cache import GPTCache
-from langchain.schema import Generation, LLMResult
+from langchain.schema import Generation
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 try:
-    import gptcache  # noqa: F401
+    from gptcache import Cache  # noqa: F401
+    from gptcache.manager.factory import get_data_manager
+    from gptcache.processor.pre import get_prompt
 
     gptcache_installed = True
 except ImportError:
     gptcache_installed = False
 
 
-@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
-def test_gptcache_map_caching() -> None:
-    """Test gptcache caching behavior."""
-
-    from gptcache import Cache
-    from gptcache.manager.factory import get_data_manager
-    from gptcache.processor.pre import get_prompt
-
-    i = 0
-    file_prefix = "data_map"
-
-    def init_gptcache_map(cache_obj: Cache) -> None:
-        nonlocal i
-        cache_path = f"{file_prefix}_{i}.txt"
-        if os.path.isfile(cache_path):
-            os.remove(cache_path)
-        cache_obj.init(
-            pre_embedding_func=get_prompt,
-            data_manager=get_data_manager(data_path=cache_path),
-        )
-        i += 1
+def init_gptcache_map(cache_obj: Cache) -> None:
+    i = getattr(init_gptcache_map, "_i", 0)
+    cache_path = f"data_map_{i}.txt"
+    if os.path.isfile(cache_path):
+        os.remove(cache_path)
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=get_data_manager(data_path=cache_path),
+    )
+    init_gptcache_map._i = i + 1  # type: ignore
 
-    langchain.llm_cache = GPTCache(init_gptcache_map)
 
+@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
+@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
+def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
+    """Test gptcache default caching behavior."""
+    langchain.llm_cache = GPTCache(init_func)
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
-    output = llm.generate(["foo", "bar", "foo"])
-    expected_cache_output = [Generation(text="foo")]
-    cache_output = langchain.llm_cache.lookup("bar", llm_string)
-    assert cache_output == expected_cache_output
-    langchain.llm_cache = None
-    expected_generations = [
-        [Generation(text="fizz")],
-        [Generation(text="foo")],
-        [Generation(text="fizz")],
-    ]
-    expected_output = LLMResult(
-        generations=expected_generations,
-        llm_output=None,
-    )
-    assert output == expected_output
+    _ = llm.generate(["foo", "bar", "foo"])
+    cache_output = langchain.llm_cache.lookup("foo", llm_string)
+    assert cache_output == [Generation(text="fizz")]
+
+    langchain.llm_cache.clear()
+    assert langchain.llm_cache.lookup("bar", llm_string) is None
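One detail of the rewritten test worth noting: `init_gptcache_map` moves to module scope, presumably so it can be listed in `@pytest.mark.parametrize`, which means the old `nonlocal i` counter no longer works; the call count is kept on the function object instead. A generic sketch of that pattern (the names here are illustrative, not from the commit):

```python
# Counter-as-function-attribute, the same trick init_gptcache_map uses above:
# a module-level function has no enclosing scope for `nonlocal`, so per-call
# state is stored on the function object itself.
def next_cache_path() -> str:
    n = getattr(next_cache_path, "_n", 0)  # 0 on the first call
    next_cache_path._n = n + 1             # persist the counter for the next call
    return f"data_map_{n}.txt"

assert next_cache_path() == "data_map_0.txt"
assert next_cache_path() == "data_map_1.txt"
```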