Harrison/fix and test caching (langchain-ai#538)
hwchase17 authored Jan 5, 2023
1 parent 73f7ebd commit 1631981
Showing 2 changed files with 28 additions and 1 deletion.
2 changes: 1 addition & 1 deletion langchain/llms/base.py
@@ -96,7 +96,7 @@ def generate(
         new_results = self._generate(missing_prompts, stop=stop)
         self.callback_manager.on_llm_end(new_results)
         for i, result in enumerate(new_results.generations):
-            existing_prompts[i] = result
+            existing_prompts[missing_prompt_idxs[i]] = result
             prompt = prompts[i]
             langchain.llm_cache.update(prompt, llm_string, result)
         generations = [existing_prompts[i] for i in range(len(prompts))]
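
Why the one-line change matters: new_results.generations holds results only for the cache misses, so the loop index i counts misses, not positions in the original prompts list. Writing to existing_prompts[i] therefore put fresh results into the wrong slots whenever an earlier prompt was a cache hit. missing_prompt_idxs[i] maps the i-th miss back to its original position. Below is a minimal standalone sketch of that bookkeeping; the hard-coded cache and the fake_generate helper are hypothetical stand-ins for langchain.llm_cache and self._generate, not langchain code.

# Sketch: re-merging cache hits and fresh results by original index.
prompts = ["foo", "bar", "foo"]
cache = {"foo": ["fizz"]}  # pretend "foo" was cached earlier

existing_prompts = {i: cache[p] for i, p in enumerate(prompts) if p in cache}
missing_prompt_idxs = [i for i, p in enumerate(prompts) if p not in cache]
missing_prompts = [prompts[i] for i in missing_prompt_idxs]

def fake_generate(ps):
    # Hypothetical stand-in for self._generate: one result per prompt.
    return [[p] for p in ps]

new_results = fake_generate(missing_prompts)  # one entry per miss only
for i, result in enumerate(new_results):
    # The pre-fix code did existing_prompts[i] = result, which here
    # would overwrite the cache hit at index 0 and leave index 1 empty.
    existing_prompts[missing_prompt_idxs[i]] = result

generations = [existing_prompts[i] for i in range(len(prompts))]
assert generations == [["fizz"], ["bar"], ["fizz"]]

With the old indexing this example raises a KeyError when the final list is built, which is the failure shape the new unit test below pins down.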
27 changes: 27 additions & 0 deletions tests/unit_tests/llms/test_base.py
@@ -0,0 +1,27 @@
+"""Test base LLM functionality."""
+import langchain
+from langchain.cache import InMemoryCache
+from langchain.schema import Generation, LLMResult
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_caching() -> None:
+    """Test caching behavior."""
+    langchain.llm_cache = InMemoryCache()
+    llm = FakeLLM()
+    params = llm._llm_dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["foo", "bar", "foo"])
+    langchain.llm_cache = None
+    expected_generations = [
+        [Generation(text="fizz")],
+        [Generation(text="foo")],
+        [Generation(text="fizz")],
+    ]
+    expected_output = LLMResult(
+        expected_generations,
+        llm_output=None,
+    )
+    assert output == expected_output
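
What the test exercises: llm_string is the serialized, sorted parameter dict that forms the cache key together with the prompt, so generate finds the pre-seeded ("foo", llm_string) entry. FakeLLM returns "foo" for every prompt when stop is None, which makes hits and misses distinguishable: the seeded "fizz" must appear at positions 0 and 2 (cache hits) while "bar" gets a fresh "foo" at position 1, proving each result lands in its original slot. For context, InMemoryCache is essentially a dict keyed by (prompt, llm_string); the following is a simplified sketch of the lookup/update contract the test relies on, not the exact class body.

from typing import Dict, List, Optional, Tuple

class InMemoryCacheSketch:
    """Simplified stand-in for langchain.cache.InMemoryCache."""

    def __init__(self) -> None:
        # Keyed by (prompt, serialized LLM params) so the same prompt
        # sent to a differently-configured LLM is cached separately.
        self._cache: Dict[Tuple[str, str], List] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[List]:
        return self._cache.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: List) -> None:
        self._cache[(prompt, llm_string)] = return_val

The test runs in isolation with pytest tests/unit_tests/llms/test_base.py.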
