Skip to content

Commit

Permalink
Merge pull request #626 from PrefectHQ/fix-ai-fn-dict-return
Browse files Browse the repository at this point in the history
fix ai_fn type returns for generic aliases and dict
  • Loading branch information
zzstoatzz authored Oct 30, 2023
2 parents 70f7232 + 9c4e067 commit 5794ecd
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 8 deletions.
10 changes: 5 additions & 5 deletions src/marvin/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
BaseSettings,
PrivateAttr,
SecretStr,
ValidationError,
validate_arguments,
)

Expand All @@ -41,7 +40,6 @@
SecretStr,
validate_arguments,
validator as field_validator,
ValidationError,
PrivateAttr,
)

Expand Down Expand Up @@ -135,8 +133,10 @@ def cast_to_model(
"""
Casts a type or callable to a Pydantic model.
"""
origin = get_origin(function_or_type) or function_or_type

response = BaseModel
if get_origin(function_or_type) is Annotated:
if origin is Annotated:
metadata: Any = next(iter(function_or_type.__metadata__), None) # type: ignore
annotated_field_name: Optional[str] = field_name

Expand Down Expand Up @@ -167,11 +167,11 @@ def cast_to_model(
field_name=annotated_field_name, # type: ignore
)
response.__doc__ = annotated_field_description or ""
if isinstance(function_or_type, GenericAlias):
elif origin in {dict, list, tuple, set, frozenset}:
response = cast_type_or_alias_to_model(
function_or_type, name, description, field_name
)
elif isinstance(function_or_type, type):
elif isinstance(origin, type):
if issubclass(function_or_type, BaseModel):
response = create_model(
name or function_or_type.__name__,
Expand Down
6 changes: 4 additions & 2 deletions src/marvin/core/ChatCompletion/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,12 @@ def to_model(self, model_cls: Optional[type[T]] = None) -> T:
pairs = self.get_function_call()
try:
return model(**pairs[0][1])
except ValueError: # ValidationError is a subclass of ValueError
return model(output=pairs[0][1])
except TypeError:
pass
try:
return model.parse_raw(pairs[0][1]) # type: ignore
return model.parse_raw(pairs[0][1])
except TypeError:
pass
return model.construct(**pairs[0][1]) # type: ignore
return model.construct(**pairs[0][1])
65 changes: 64 additions & 1 deletion tests/test_components/test_ai_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from typing import Dict, List

import pytest
from marvin import ai_fn
Expand Down Expand Up @@ -39,7 +40,7 @@ async def list_fruit(n: int) -> list[str]:

def test_list_fruit_with_generic_type_hints(self):
@ai_fn
def list_fruit(n: int) -> list[str]:
def list_fruit(n: int) -> List[str]:
"""Returns a list of `n` fruit"""

result = list_fruit(3)
Expand All @@ -66,6 +67,68 @@ def is_fruit(name: str) -> bool:

assert is_fruit(name) == expected

def test_plain_dict_return_type(self):
@ai_fn
def get_fruit(name: str) -> dict:
"""Returns a fruit with the provided name and color"""

fruit = get_fruit("banana")
assert fruit["name"].lower() == "banana"
assert fruit["color"].lower() == "yellow"

def test_annotated_dict_return_type(self):
@ai_fn
def get_fruit(name: str) -> dict[str, str]:
"""Returns a fruit with the provided name and color"""

fruit = get_fruit("banana")
assert fruit["name"].lower() == "banana"
assert fruit["color"].lower() == "yellow"

def test_generic_dict_return_type(self):
@ai_fn
def get_fruit(name: str) -> Dict[str, str]:
"""Returns a fruit with the provided name and color"""

fruit = get_fruit("banana")
assert fruit["name"].lower() == "banana"
assert fruit["color"].lower() == "yellow"

def test_int_return_type(self):
@ai_fn
def get_fruit(name: str) -> int:
"""Returns the number of letters in the provided fruit name"""

assert get_fruit("banana") == 6

def test_float_return_type(self):
@ai_fn
def get_fruit(name: str) -> float:
"""Returns the number of letters in the provided fruit name"""

assert get_fruit("banana") == 6.0

def test_tuple_return_type(self):
@ai_fn
def get_fruit(name: str) -> tuple:
"""Returns the number of letters in the provided fruit name"""

assert get_fruit("banana") == (6,)

def test_set_return_type(self):
@ai_fn
def get_fruit(name: str) -> set:
"""Returns the letters in the provided fruit name"""

assert get_fruit("banana") == {"a", "b", "n"}

def test_frozenset_return_type(self):
@ai_fn
def get_fruit(name: str) -> frozenset:
"""Returns the letters in the provided fruit name"""

assert get_fruit("banana") == frozenset({"a", "b", "n"})


@pytest_mark_class("llm")
class TestAIFunctionsMap:
Expand Down

0 comments on commit 5794ecd

Please sign in to comment.