Skip to content

Commit

Permalink
cache replicate version (langchain-ai#10517)
Browse files Browse the repository at this point in the history
In a subsequent PR, `_call` will be updated to use `replicate.run` directly when
not streaming, so the version object isn't needed at all.

cc @cbh123 @tjaffri
  • Loading branch information
baskaryan authored Sep 14, 2023
1 parent 49b65a1 commit ccf71e2
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions libs/langchain/langchain/llms/replicate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

Expand All @@ -23,17 +25,26 @@ class Replicate(LLM):
.. code-block:: python
from langchain.llms import Replicate
replicate = Replicate(model="stability-ai/stable-diffusion: \
27b93a2413e7f36cd83da926f365628\
0b2931564ff050bf9575f1fdf9bcd7478",
input={"image_dimensions": "512x512"})
replicate = Replicate(
model=(
"stability-ai/stable-diffusion: "
"27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
),
input={"image_dimensions": "512x512"}
)
"""

model: str
input: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None
version_obj: Any = Field(default=None, exclude=True)
"""Optionally pass in the model version object during initialization to avoid
having to make an extra API call to retrieve it during streaming. NOTE: not
serializable, is excluded from serialization.
"""

streaming: bool = Field(default=False)
"""Whether to stream the results."""
Expand Down Expand Up @@ -111,14 +122,15 @@ def _call(
)

# get the model and version
model_str, version_str = self.model.split(":")
model = replicate_python.models.get(model_str)
version = model.versions.get(version_str)
if self.version_obj is None:
model_str, version_str = self.model.split(":")
model = replicate_python.models.get(model_str)
self.version_obj = model.versions.get(version_str)

if not self.prompt_key:
if self.prompt_key is None:
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
self.version_obj.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
Expand All @@ -129,7 +141,7 @@ def _call(
inputs: Dict = {self.prompt_key: prompt, **self.input}

prediction = replicate_python.predictions.create(
version=version, input={**inputs, **kwargs}
version=self.version_obj, input={**inputs, **kwargs}
)
current_completion: str = ""
stop_condition_reached = False
Expand Down

0 comments on commit ccf71e2

Please sign in to comment.