Skip to content

Commit

Permalink
Fix fine-tuned replicate models with faster cold boot (langchain-ai#1…
Browse files Browse the repository at this point in the history
…0512)

With the latest support for faster cold boot in replicate
https://replicate.com/blog/fine-tune-cold-boots it looks like the
replicate LLM support in langchain is broken since some internal
replicate inputs are being returned.

Screenshot below illustrates the problem:

<img width="1917" alt="image"
src="https://github.com/langchain-ai/langchain/assets/749277/d28c27cc-40fb-4258-8710-844c00d3c2b0">

As you can see, the new replicate_weights param is being sent down with
x-order = 0 (which is causing langchain to use that param instead of
prompt which is x-order = 1)

FYI @baskaryan this requires a fix otherwise replicate is broken for
these models. I have pinged replicate whether they want to fix it on
their end by changing the x-order returned by them.

Update: per suggestion I updated the PR to just allow manually setting
the prompt_key which can be set to "prompt" in this case by callers... I
think this is going to be faster anyway than trying to dynamically query
the model every time if you know the prompt key for your model.

---------

Co-authored-by: Taqi Jaffri <[email protected]>
  • Loading branch information
tjaffri and Taqi Jaffri authored Sep 12, 2023
1 parent 57e2de2 commit 21fbbe8
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions libs/langchain/langchain/llms/replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Replicate(LLM):
input: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None

streaming: bool = Field(default=False)
"""Whether to stream the results."""
Expand Down Expand Up @@ -114,15 +115,18 @@ def _call(
model = replicate_python.models.get(model_str)
version = model.versions.get(version_str)

# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)
first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input}
if not self.prompt_key:
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
)

self.prompt_key = input_properties[0][0]

inputs = {self.prompt_key: prompt, **self.input}

prediction = replicate_python.predictions.create(
version=version, input={**inputs, **kwargs}
Expand Down

0 comments on commit 21fbbe8

Please sign in to comment.