From d5ebe31f97fb5334caa14617ee953333cab945d2 Mon Sep 17 00:00:00 2001
From: vitreuz
Date: Fri, 20 Dec 2024 22:28:37 -0800
Subject: [PATCH 1/2] Add process_response_format function

- Added `process_response_format` function to convert OpenAI-style response
  format to Ollama API format.
- Updated `OllamaConfig` and `OllamaChatConfig` to use the new
  `process_response_format` function.
- Added unit test `test_ollama_structured_format` to validate the structured
  JSON schema format.
---
 litellm/llms/ollama/common_utils.py           | 27 +++++++++++++
 .../llms/ollama/completion/transformation.py  |  5 +--
 litellm/llms/ollama_chat.py                   |  5 ++-
 tests/local_testing/test_ollama.py            | 38 ++++++++++++++++---
 4 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/litellm/llms/ollama/common_utils.py b/litellm/llms/ollama/common_utils.py
index 5cf213950c16..5ca205a98696 100644
--- a/litellm/llms/ollama/common_utils.py
+++ b/litellm/llms/ollama/common_utils.py
@@ -2,6 +2,7 @@
 
 import httpx
 
+from litellm._logging import verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
@@ -43,3 +44,29 @@ def _convert_image(image):
     image_data.convert("RGB").save(jpeg_image, "JPEG")
     jpeg_image.seek(0)
     return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
+
+
+def process_response_format(response_format: dict) -> str:
+    """
+    Process OpenAI-style response format specification into Ollama API format
+    string
+
+    Args:
+        response_format (dict): OpenAI-style response format specification
+
+    Returns:
+        str: Format string for Ollama API
+    """
+    format_type = response_format.get("type")
+    if format_type == "json_object":
+        return "json"
+    elif format_type == "json_schema":
+        schema = response_format.get("json_schema", {}).get("schema")
+        if not schema:
+            raise ValueError("Invalid JSON schema format")
+        return schema
+    else:
+        verbose_logger.warning(
+            f"Unsupported response format type: {format_type}, falling back to 'json'"
+        )
+        return "json"
diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py
index 52198893219f..be4048f8953e 100644
--- a/litellm/llms/ollama/completion/transformation.py
+++ b/litellm/llms/ollama/completion/transformation.py
@@ -22,7 +22,7 @@
     ProviderField,
 )
 
-from ..common_utils import OllamaError, _convert_image
+from ..common_utils import OllamaError, _convert_image, process_response_format
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -172,8 +172,7 @@ def map_openai_params(
             if param == "stop":
                 optional_params["stop"] = value
             if param == "response_format" and isinstance(value, dict):
-                if value["type"] == "json_object":
-                    optional_params["format"] = "json"
+                optional_params["format"] = process_response_format(value)
 
         return optional_params
 
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index 5aa26ced46dd..2f0d68f03380 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -10,6 +10,7 @@
 import litellm
 from litellm import verbose_logger
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.llms.ollama.common_utils import process_response_format
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -152,8 +153,8 @@ def map_openai_params(
                 optional_params["repeat_penalty"] = value
             if param == "stop":
                optional_params["stop"] = value
-            if param == "response_format" and value["type"] == "json_object":
-                optional_params["format"] = "json"
+            if param == "response_format" and isinstance(value, dict):
+                optional_params["format"] = process_response_format(value)
             ### FUNCTION CALLING LOGIC ###
             if param == "tools":
                 # ollama actually supports json output
diff --git a/tests/local_testing/test_ollama.py b/tests/local_testing/test_ollama.py
index 2066859091e8..06124061ddbf 100644
--- a/tests/local_testing/test_ollama.py
+++ b/tests/local_testing/test_ollama.py
@@ -1,13 +1,12 @@
 import asyncio
+import json
 import os
 import sys
-import traceback
 
 from dotenv import load_dotenv
+from pydantic import BaseModel
 
 load_dotenv()
-import io
-import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -19,7 +18,12 @@
 import litellm
 
 ## for ollama we can't test making the completion call
-from litellm.utils import EmbeddingResponse, get_llm_provider, get_optional_params
+from litellm.utils import (
+    EmbeddingResponse,
+    get_llm_provider,
+    get_optional_params,
+    type_to_response_format_param,
+)
 
 
 def test_get_ollama_params():
@@ -76,6 +80,31 @@ def test_ollama_json_mode():
 # test_ollama_json_mode()
 
 
+def test_ollama_structured_format():
+    # assert that format receives a structured json schema
+    class Country(BaseModel):
+        name: str
+        code: str
+        languages: list[str]
+
+    model_schema = type_to_response_format_param(Country) or {}
+    model_schema = model_schema.get("json_schema", {}).get("schema", {})
+    try:
+        converted_params = get_optional_params(
+            custom_llm_provider="ollama",
+            model="llama2",
+            temperature=0.5,
+            response_format=Country,
+        )
+        assert converted_params == {
+            "temperature": 0.5,
+            "format": model_schema,
+            "stream": False,
+        }, f"{converted_params} != {{'temperature': 0.5, 'format': {model_schema}, 'stream': False}}"
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
 
 
@@ -135,7 +164,6 @@ def test_ollama_aembeddings(mock_aembeddings):
 
 @pytest.mark.skip(reason="local only test")
 def test_ollama_chat_function_calling():
-    import json
 
     tools = [
         {
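
A minimal sketch of the conversion this first patch introduces, assuming the patch is applied; the schema literal is illustrative rather than taken from the repository:

    from litellm.llms.ollama.common_utils import process_response_format

    # OpenAI-style JSON mode maps to Ollama's plain "json" format string.
    assert process_response_format({"type": "json_object"}) == "json"

    # OpenAI-style structured outputs pass the inner JSON schema through
    # unchanged, so Ollama receives the schema itself as its format option.
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    openai_format = {"type": "json_schema", "json_schema": {"schema": schema}}
    assert process_response_format(openai_format) == schema

    # A "json_schema" entry without a schema raises ValueError; any other type
    # falls back to "json" here (the follow-up patch removes that fallback).
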
f"Unsupported response format type: {format_type}, falling back to 'json'" - ) - return "json" + verbose_logger.warning(f"Unsupported response format type: {format_type}") diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index be4048f8953e..33b009bf30d1 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -172,7 +172,9 @@ def map_openai_params( if param == "stop": optional_params["stop"] = value if param == "response_format" and isinstance(value, dict): - optional_params["format"] = process_response_format(value) + format = process_response_format(value) + if format is not None: + optional_params["format"] = format return optional_params diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py index 2f0d68f03380..3ce69aa805b6 100644 --- a/litellm/llms/ollama_chat.py +++ b/litellm/llms/ollama_chat.py @@ -154,7 +154,9 @@ def map_openai_params( if param == "stop": optional_params["stop"] = value if param == "response_format" and isinstance(value, dict): - optional_params["format"] = process_response_format(value) + format = process_response_format(value) + if format is not None: + optional_params["format"] = format ### FUNCTION CALLING LOGIC ### if param == "tools": # ollama actually supports json output