diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py
index 94d8c34fea999..1a06f1c61876d 100644
--- a/libs/langchain/langchain/llms/replicate.py
+++ b/libs/langchain/langchain/llms/replicate.py
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Extra, Field, root_validator
+from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env
 
+if TYPE_CHECKING:
+    from replicate.prediction import Prediction
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,10 +50,10 @@ class Replicate(LLM):
     serializable, is excluded from serialization.
     """
 
-    streaming: bool = Field(default=False)
+    streaming: bool = False
     """Whether to stream the results."""
 
-    stop: Optional[List[str]] = Field(default=[])
+    stop: List[str] = Field(default_factory=list)
     """Stop sequences to early-terminate generation."""
 
     class Config:
@@ -97,7 +101,7 @@ def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            **{"model_kwargs": self.model_kwargs},
+            "model_kwargs": self.model_kwargs,
         }
 
     @property
@@ -113,6 +117,63 @@ def _call(
         **kwargs: Any,
     ) -> str:
         """Call to replicate endpoint."""
+        if self.streaming:
+            completion: Optional[str] = None
+            for chunk in self._stream(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if completion is None:
+                    completion = chunk.text
+                else:
+                    completion += chunk.text
+        else:
+            prediction = self._create_prediction(prompt, **kwargs)
+            prediction.wait()
+            if prediction.status == "failed":
+                raise RuntimeError(prediction.error)
+            completion = prediction.output
+        assert completion is not None
+        stop_conditions = stop or self.stop
+        for s in stop_conditions:
+            if s in completion:
+                completion = completion[: completion.find(s)]
+        return completion
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        prediction = self._create_prediction(prompt, **kwargs)
+        stop_conditions = stop or self.stop
+        stop_condition_reached = False
+        current_completion: str = ""
+        for output in prediction.output_iterator():
+            current_completion += output
+            # test for stop conditions, if specified
+            for s in stop_conditions:
+                if s in current_completion:
+                    prediction.cancel()
+                    stop_condition_reached = True
+                    # Potentially some tokens that should still be yielded before ending
+                    # stream.
+                    stop_index = max(output.find(s), 0)
+                    output = output[:stop_index]
+                    if not output:
+                        break
+            if output:
+                yield GenerationChunk(text=output)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        output,
+                        verbose=self.verbose,
+                    )
+            if stop_condition_reached:
+                break
+
+    def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
         try:
             import replicate as replicate_python
         except ImportError:
@@ -138,29 +199,7 @@ def _call(
 
             self.prompt_key = input_properties[0][0]
 
-        inputs: Dict = {self.prompt_key: prompt, **self.input}
-
-        prediction = replicate_python.predictions.create(
-            version=self.version_obj, input={**inputs, **kwargs}
+        input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
+        return replicate_python.predictions.create(
+            version=self.version_obj, input=input_
         )
-        current_completion: str = ""
-        stop_condition_reached = False
-        for output in prediction.output_iterator():
-            current_completion += output
-
-            # test for stop conditions, if specified
-            if stop:
-                for s in stop:
-                    if s in current_completion:
-                        prediction.cancel()
-                        stop_index = current_completion.find(s)
-                        current_completion = current_completion[:stop_index]
-                        stop_condition_reached = True
-                        break
-
-            if stop_condition_reached:
-                break
-
-            if self.streaming and run_manager:
-                run_manager.on_llm_new_token(output)
-        return current_completion
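
A minimal usage sketch of the streaming behavior this diff adds. The model slug and sampling parameters below are placeholders (not part of the change), and a valid REPLICATE_API_TOKEN is assumed to be set in the environment:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Replicate

# Placeholder model identifier; substitute a real "owner/name:version" slug.
llm = Replicate(
    model="<model-owner>/<model-name>:<version-hash>",
    streaming=True,
    model_kwargs={"temperature": 0.75, "max_length": 500},
    callbacks=[StreamingStdOutCallbackHandler()],
)

# With streaming=True, _call delegates to _stream: tokens arrive via
# on_llm_new_token, and the prediction is cancelled as soon as a stop
# sequence appears in the accumulated output.
text = llm("Explain what Replicate is in one sentence.", stop=["\n\n"])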