Skip to content

Commit

Permalink
Allow replicate prompt key to be manually specified (langchain-ai#10516)
Browse files Browse the repository at this point in the history
Since inference logic doesn't work for all models

Co-authored-by: Taqi Jaffri <[email protected]>
Co-authored-by: Taqi Jaffri <[email protected]>
  • Loading branch information
3 people authored Sep 12, 2023
2 parents 57e2de2 + 7ecee78 commit eaf916f
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions libs/langchain/langchain/llms/replicate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
Expand Down Expand Up @@ -33,6 +33,7 @@ class Replicate(LLM):
input: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None

streaming: bool = Field(default=False)
"""Whether to stream the results."""
Expand Down Expand Up @@ -81,7 +82,7 @@ def validate_environment(cls, values: Dict) -> Dict:
return values

@property
def _identifying_params(self) -> Mapping[str, Any]:
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
Expand Down Expand Up @@ -114,15 +115,18 @@ def _call(
model = replicate_python.models.get(model_str)
version = model.versions.get(version_str)

# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)
first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input}
if not self.prompt_key:
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)

self.prompt_key = input_properties[0][0]

inputs: Dict = {self.prompt_key: prompt, **self.input}

prediction = replicate_python.predictions.create(
version=version, input={**inputs, **kwargs}
Expand Down

0 comments on commit eaf916f

Please sign in to comment.