add replicate stream (langchain-ai#10518)
support direct replicate streaming. cc @cbh123 @tjaffri
baskaryan authored Sep 14, 2023
1 parent 7f3f609 commit 9dd4cac
Showing 1 changed file with 68 additions and 29 deletions.
97 changes: 68 additions & 29 deletions libs/langchain/langchain/llms/replicate.py
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Extra, Field, root_validator
+from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env
 
+if TYPE_CHECKING:
+    from replicate.prediction import Prediction
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,10 +50,10 @@ class Replicate(LLM):
     serializable, is excluded from serialization.
     """
 
-    streaming: bool = Field(default=False)
+    streaming: bool = False
     """Whether to stream the results."""
 
-    stop: Optional[List[str]] = Field(default=[])
+    stop: List[str] = Field(default_factory=list)
     """Stop sequences to early-terminate generation."""
 
     class Config:
@@ -97,7 +101,7 @@ def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
**{"model_kwargs": self.model_kwargs},
"model_kwargs": self.model_kwargs,
}

@property
@@ -113,6 +117,63 @@ def _call(
         **kwargs: Any,
     ) -> str:
         """Call to replicate endpoint."""
+        if self.streaming:
+            completion: Optional[str] = None
+            for chunk in self._stream(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if completion is None:
+                    completion = chunk.text
+                else:
+                    completion += chunk.text
+        else:
+            prediction = self._create_prediction(prompt, **kwargs)
+            prediction.wait()
+            if prediction.status == "failed":
+                raise RuntimeError(prediction.error)
+            completion = prediction.output
+        assert completion is not None
+        stop_conditions = stop or self.stop
+        for s in stop_conditions:
+            if s in completion:
+                completion = completion[: completion.find(s)]
+        return completion
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        prediction = self._create_prediction(prompt, **kwargs)
+        stop_conditions = stop or self.stop
+        stop_condition_reached = False
+        current_completion: str = ""
+        for output in prediction.output_iterator():
+            current_completion += output
+            # test for stop conditions, if specified
+            for s in stop_conditions:
+                if s in current_completion:
+                    prediction.cancel()
+                    stop_condition_reached = True
+                    # Potentially some tokens that should still be yielded before ending
+                    # stream.
+                    stop_index = max(output.find(s), 0)
+                    output = output[:stop_index]
+                    if not output:
+                        break
+            if output:
+                yield GenerationChunk(text=output)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        output,
+                        verbose=self.verbose,
+                    )
+            if stop_condition_reached:
+                break
+
+    def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
         try:
             import replicate as replicate_python
         except ImportError:
Expand All @@ -138,29 +199,7 @@ def _call(

             self.prompt_key = input_properties[0][0]
 
-        inputs: Dict = {self.prompt_key: prompt, **self.input}
-
-        prediction = replicate_python.predictions.create(
-            version=self.version_obj, input={**inputs, **kwargs}
+        input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
+        return replicate_python.predictions.create(
+            version=self.version_obj, input=input_
         )
-        current_completion: str = ""
-        stop_condition_reached = False
-        for output in prediction.output_iterator():
-            current_completion += output
-
-            # test for stop conditions, if specified
-            if stop:
-                for s in stop:
-                    if s in current_completion:
-                        prediction.cancel()
-                        stop_index = current_completion.find(s)
-                        current_completion = current_completion[:stop_index]
-                        stop_condition_reached = True
-                        break
-
-            if stop_condition_reached:
-                break
-
-            if self.streaming and run_manager:
-                run_manager.on_llm_new_token(output)
-        return current_completion

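A minimal usage sketch of what this change enables (not part of the commit): with streaming=True, the new _stream method drives token callbacks as Replicate produces output, and stop sequences cancel the prediction early. The model identifier below is a placeholder, REPLICATE_API_TOKEN is assumed to be set in the environment, and the final loop assumes the base LLM's stream() interface from this era of langchain.

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Replicate

# Placeholder model/version string; substitute any Replicate model identifier.
llm = Replicate(
    model="<owner>/<model>:<version-hash>",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

# Tokens are printed by the callback handler as they arrive from Replicate;
# the stop sequence ends generation early via prediction.cancel().
llm("Tell me a short joke.", stop=["\n"])

# Alternatively, consume the stream directly as GenerationChunk-derived tokens.
for token in llm.stream("Tell me a short joke."):
    print(token, end="", flush=True)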