Skip to content

Commit

Permalink
Replicate params fix (langchain-ai#10603)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Sep 14, 2023
1 parent 50bb704 commit ecbb1ed
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions libs/langchain/langchain/llms/replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
logger.warning(
"Init param `input` is deprecated, please use `model_kwargs` instead."
)
extra = {**values.get("model_kwargs", {}), **input}
extra = {**values.pop("model_kwargs", {}), **input}
for field_name in list(values):
if field_name not in all_required_field_names:
if field_name in extra:
Expand All @@ -96,7 +96,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
replicate_api_token = get_from_dict_or_env(
values, "REPLICATE_API_TOKEN", "REPLICATE_API_TOKEN"
values, "replicate_api_token", "REPLICATE_API_TOKEN"
)
values["replicate_api_token"] = replicate_api_token
return values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_replicate_model_kwargs() -> None:
)
short_output = llm("What is LangChain")
assert len(short_output) < len(long_output)
assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01}


def test_replicate_input() -> None:
Expand Down

0 comments on commit ecbb1ed

Please sign in to comment.