From ecbb1ed8cb4b2ccaac3e01ad87dd74a318fa134c Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:04:42 -0700 Subject: [PATCH] Replicate params fix (#10603) --- libs/langchain/langchain/llms/replicate.py | 4 ++-- libs/langchain/tests/integration_tests/llms/test_replicate.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py index 5d407c40b4142..7a146070ecbb9 100644 --- a/libs/langchain/langchain/llms/replicate.py +++ b/libs/langchain/langchain/llms/replicate.py @@ -79,7 +79,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: logger.warning( "Init param `input` is deprecated, please use `model_kwargs` instead." ) - extra = {**values.get("model_kwargs", {}), **input} + extra = {**values.pop("model_kwargs", {}), **input} for field_name in list(values): if field_name not in all_required_field_names: if field_name in extra: @@ -96,7 +96,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" replicate_api_token = get_from_dict_or_env( - values, "REPLICATE_API_TOKEN", "REPLICATE_API_TOKEN" + values, "replicate_api_token", "REPLICATE_API_TOKEN" ) values["replicate_api_token"] = replicate_api_token return values diff --git a/libs/langchain/tests/integration_tests/llms/test_replicate.py b/libs/langchain/tests/integration_tests/llms/test_replicate.py index 9bc183bb8b022..eaa09fc597b7c 100644 --- a/libs/langchain/tests/integration_tests/llms/test_replicate.py +++ b/libs/langchain/tests/integration_tests/llms/test_replicate.py @@ -37,6 +37,7 @@ def test_replicate_model_kwargs() -> None: ) short_output = llm("What is LangChain") assert len(short_output) < len(long_output) + assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01} def test_replicate_input() -> None: