diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py index 4ce4621d1658c..94d8c34fea999 100644 --- a/libs/langchain/langchain/llms/replicate.py +++ b/libs/langchain/langchain/llms/replicate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import Any, Dict, List, Optional @@ -23,10 +25,14 @@ class Replicate(LLM): .. code-block:: python from langchain.llms import Replicate - replicate = Replicate(model="stability-ai/stable-diffusion: \ - 27b93a2413e7f36cd83da926f365628\ - 0b2931564ff050bf9575f1fdf9bcd7478", - input={"image_dimensions": "512x512"}) + + replicate = Replicate( + model=( + "stability-ai/stable-diffusion:" + "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478" + ), + input={"image_dimensions": "512x512"} + ) """ model: str @@ -34,6 +40,11 @@ class Replicate(LLM): model_kwargs: Dict[str, Any] = Field(default_factory=dict) replicate_api_token: Optional[str] = None prompt_key: Optional[str] = None + version_obj: Any = Field(default=None, exclude=True) + """Optionally pass in the model version object during initialization to avoid + having to make an extra API call to retrieve it during streaming. NOTE: not + serializable, so it is excluded from serialization. 
+ """ streaming: bool = Field(default=False) """Whether to stream the results.""" @@ -111,14 +122,15 @@ def _call( ) # get the model and version - model_str, version_str = self.model.split(":") - model = replicate_python.models.get(model_str) - version = model.versions.get(version_str) + if self.version_obj is None: + model_str, version_str = self.model.split(":") + model = replicate_python.models.get(model_str) + self.version_obj = model.versions.get(version_str) - if not self.prompt_key: + if self.prompt_key is None: # sort through the openapi schema to get the name of the first input input_properties = sorted( - version.openapi_schema["components"]["schemas"]["Input"][ + self.version_obj.openapi_schema["components"]["schemas"]["Input"][ "properties" ].items(), key=lambda item: item[1].get("x-order", 0), @@ -129,7 +141,7 @@ def _call( inputs: Dict = {self.prompt_key: prompt, **self.input} prediction = replicate_python.predictions.create( - version=version, input={**inputs, **kwargs} + version=self.version_obj, input={**inputs, **kwargs} ) current_completion: str = "" stop_condition_reached = False