diff --git a/src/controlflow/llm/rules.py b/src/controlflow/llm/rules.py index fed84d9c..e431e5be 100644 --- a/src/controlflow/llm/rules.py +++ b/src/controlflow/llm/rules.py @@ -1,8 +1,9 @@ import textwrap -from typing import Optional +from typing import Any, Optional, Union from langchain_anthropic import ChatAnthropic from langchain_openai import AzureChatOpenAI, ChatOpenAI +from pydantic import Field from controlflow.llm.models import BaseChatModel from controlflow.utilities.general import ControlFlowModel, unwrap @@ -17,7 +18,7 @@ class LLMRules(ControlFlowModel): necessary. """ - model: Optional[BaseChatModel] + model: Any # require at least one non-system message require_at_least_one_message: bool = False @@ -54,7 +55,7 @@ def model_instructions(self) -> Optional[list[str]]: class OpenAIRules(LLMRules): require_message_name_format: str = r"[^a-zA-Z0-9_-]" - model: ChatOpenAI + model: Any def model_instructions(self) -> list[str]: instructions = []