Skip to content

Commit

Permalink
Merge pull request #506 from PrefectHQ/response-model
Browse files Browse the repository at this point in the history
fix function_call params
  • Loading branch information
jlowin authored Jul 25, 2023
2 parents b18642c + 071e0db commit 1bdb7b3
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions src/marvin/openai/ChatCompletion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import warnings

from typing import Type, Optional, Union, Literal
from pydantic import Extra


class ChatCompletionConfig(BaseSettings):
class ChatCompletionConfig(BaseSettings, extra=Extra.allow):
model: str = "gpt-3.5-turbo"
temperature: float = 0
functions: list = Field(default_factory=list)
Expand All @@ -30,7 +31,7 @@ def merge(self, *args, **kwargs):
setattr(self, key, getattr(self, key, []) + value)
else:
setattr(self, key, value)
return {k: v for k, v in self.__dict__.items() if v != []}
return {k: v for k, v in self.__dict__.items() if v}


def process_list(lst):
Expand Down Expand Up @@ -68,18 +69,19 @@ def create(cls, *args, response_model: Optional[Type[BaseModel]] = None, **kwarg
}
payload = config.merge(**kwargs)
response = cls.observer(super(ChatCompletion, cls).create)(*args, **payload)
response.to_model = lambda: (
process_list(
list(
map(
lambda x: response_model.parse_raw(
x.message.function_call.arguments
),
response.choices,
if response_model is not None:
response.to_model = lambda: (
process_list(
list(
map(
lambda x: response_model.parse_raw(
x.message.function_call.arguments
),
response.choices,
)
)
)
)
)
return response

@classmethod
Expand All @@ -105,18 +107,19 @@ async def acreate(
response = await cls.observer(super(ChatCompletion, cls).acreate)(
*args, **payload
)
response.to_model = lambda: (
process_list(
list(
map(
lambda x: response_model.parse_raw(
x.message.function_call.arguments
),
response.choices,
if response_model is not None:
response.to_model = lambda: (
process_list(
list(
map(
lambda x: response_model.parse_raw(
x.message.function_call.arguments
),
response.choices,
)
)
)
)
)
return response

@staticmethod
Expand Down

0 comments on commit 1bdb7b3

Please sign in to comment.