@@ -231,16 +259,25 @@ In addition, you can find:
## Related Papers
-[AutoGen](https://arxiv.org/abs/2308.08155)
+[AutoGen Studio](https://www.microsoft.com/en-us/research/publication/autogen-studio-a-no-code-developer-tool-for-building-and-debugging-multi-agent-systems/)
+
+```
+@inproceedings{dibia2024studio,
+ title={AutoGen Studio: A No-Code Developer Tool for Building and Debugging Multi-Agent Systems},
+ author={Victor Dibia and Jingya Chen and Gagan Bansal and Suff Syed and Adam Fourney and Erkang (Eric) Zhu and Chi Wang and Saleema Amershi},
+ year={2024},
+ booktitle={Pre-Print}
+}
+```
+
+[AutoGen](https://aka.ms/autogen-pdf)
```
@inproceedings{wu2023autogen,
title={AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework},
author={Qingyun Wu and Gagan Bansal and Jieyu Zhang and Yiran Wu and Beibin Li and Erkang Zhu and Li Jiang and Xiaoyun Zhang and Shaokun Zhang and Jiale Liu and Ahmed Hassan Awadallah and Ryen W White and Doug Burger and Chi Wang},
- year={2023},
- eprint={2308.08155},
- archivePrefix={arXiv},
- primaryClass={cs.AI}
+ year={2024},
+ booktitle={COLM},
}
```
@@ -266,6 +303,27 @@ In addition, you can find:
}
```
+[AgentOptimizer](https://arxiv.org/pdf/2402.11359)
+
+```
+@article{zhang2024training,
+ title={Training Language Model Agents without Modifying Language Models},
+ author={Zhang, Shaokun and Zhang, Jieyu and Liu, Jiale and Song, Linxin and Wang, Chi and Krishna, Ranjay and Wu, Qingyun},
+ journal={ICML'24},
+ year={2024}
+}
+```
+
+[StateFlow](https://arxiv.org/abs/2403.11322)
+```
+@article{wu2024stateflow,
+ title={StateFlow: Enhancing LLM Task-Solving through State-Driven Workflows},
+ author={Wu, Yiran and Yue, Tianwei and Zhang, Shaokun and Wang, Chi and Wu, Qingyun},
+ journal={arXiv preprint arXiv:2403.11322},
+ year={2024}
+}
+```
+
↑ Back to Top ↑
@@ -317,7 +375,7 @@ may be either trademarks or registered trademarks of Microsoft in the United Sta
The licenses for this project do not grant you rights to use any Microsoft names, logos, or trademarks.
Microsoft's general trademark guidelines can be found at http://go.microsoft.com/fwlink/?LinkID=254653.
-Privacy information can be found at https://privacy.microsoft.com/en-us/
+Privacy information can be found at https://go.microsoft.com/fwlink/?LinkId=521839
Microsoft and any contributors reserve all other rights, whether under their respective copyrights, patents,
or trademarks, whether by implication, estoppel, or otherwise.
diff --git a/TRANSPARENCY_FAQS.md b/TRANSPARENCY_FAQS.md
index 206af084748..addf29d8b8d 100644
--- a/TRANSPARENCY_FAQS.md
+++ b/TRANSPARENCY_FAQS.md
@@ -31,6 +31,8 @@ While AutoGen automates LLM workflows, decisions about how to use specific LLM o
- Current version of AutoGen was evaluated on six applications to illustrate its potential in simplifying the development of high-performance multi-agent applications. These applications are selected based on their real-world relevance, problem difficulty and problem solving capabilities enabled by AutoGen, and innovative potential.
- These applications involve using AutoGen to solve math problems, question answering, decision making in text world environments, supply chain optimization, etc. For each of these domains AutoGen was evaluated on various success based metrics (i.e., how often the AutoGen based implementation solved the task). And, in some cases, AutoGen based approach was also evaluated on implementation efficiency (e.g., to track reductions in developer effort to build). More details can be found at: https://aka.ms/AutoGen/TechReport
- The team has conducted tests where a “red” agent attempts to get the default AutoGen assistant to break from its alignment and guardrails. The team has observed that out of 70 attempts to break guardrails, only 1 was successful in producing text that would have been flagged as problematic by Azure OpenAI filters. The team has not observed any evidence that AutoGen (or GPT models as hosted by OpenAI or Azure) can produce novel code exploits or jailbreak prompts, since direct prompts to “be a hacker”, “write exploits”, or “produce a phishing email” are refused by existing filters.
+- We also evaluated [a team of AutoGen agents](https://github.com/microsoft/autogen/tree/gaia_multiagent_v01_march_1st/samples/tools/autogenbench/scenarios/GAIA/Templates/Orchestrator) on the [GAIA benchmark](https://arxiv.org/abs/2311.12983), achieving [state-of-the-art results](https://huggingface.co/spaces/gaia-benchmark/leaderboard) as of
+  March 1, 2024.
## What are the limitations of AutoGen? How can users minimize the impact of AutoGen’s limitations when using the system?
AutoGen relies on existing LLMs. Experimenting with AutoGen would retain common limitations of large language models; including:
diff --git a/autogen/_pydantic.py b/autogen/_pydantic.py
index 9a37208c406..c463dbb3875 100644
--- a/autogen/_pydantic.py
+++ b/autogen/_pydantic.py
@@ -64,27 +64,27 @@ def type2schema(t: Any) -> JsonSchemaValue:
Returns:
JsonSchemaValue: The JSON schema
"""
- if PYDANTIC_V1:
- if t is None:
- return {"type": "null"}
- elif get_origin(t) is Union:
- return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
- elif get_origin(t) in [Tuple, tuple]:
- prefixItems = [type2schema(tt) for tt in get_args(t)]
- return {
- "maxItems": len(prefixItems),
- "minItems": len(prefixItems),
- "prefixItems": prefixItems,
- "type": "array",
- }
-
- d = schema_of(t)
- if "title" in d:
- d.pop("title")
- if "description" in d:
- d.pop("description")
-
- return d
+
+ if t is None:
+ return {"type": "null"}
+ elif get_origin(t) is Union:
+ return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
+ elif get_origin(t) in [Tuple, tuple]:
+ prefixItems = [type2schema(tt) for tt in get_args(t)]
+ return {
+ "maxItems": len(prefixItems),
+ "minItems": len(prefixItems),
+ "prefixItems": prefixItems,
+ "type": "array",
+ }
+ else:
+ d = schema_of(t)
+ if "title" in d:
+ d.pop("title")
+ if "description" in d:
+ d.pop("description")
+
+ return d
def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict
diff --git a/autogen/agentchat/assistant_agent.py b/autogen/agentchat/assistant_agent.py
index b5ec7de90c7..c1601ea9ba8 100644
--- a/autogen/agentchat/assistant_agent.py
+++ b/autogen/agentchat/assistant_agent.py
@@ -38,7 +38,7 @@ def __init__(
llm_config: Optional[Union[Dict, Literal[False]]] = None,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "NEVER",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
description: Optional[str] = None,
**kwargs,
):
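A minimal sketch of the narrowed parameter; only the three literal values are accepted now, and the config lookup below is a placeholder.

```python
# "ALWAYS", "NEVER", and "TERMINATE" are the only accepted values now.
import autogen

config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")  # placeholder config source
assistant = autogen.AssistantAgent(
    name="assistant",
    llm_config={"config_list": config_list},
    human_input_mode="TERMINATE",  # ask for human input when a termination condition is hit
)
```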
diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py
index a07f3302ae9..d07b4d15cb6 100644
--- a/autogen/agentchat/chat.py
+++ b/autogen/agentchat/chat.py
@@ -21,14 +21,16 @@ class ChatResult:
chat_id: int = None
"""chat id"""
- chat_history: List[Dict[str, any]] = None
+ chat_history: List[Dict[str, Any]] = None
"""The chat history."""
summary: str = None
"""A summary obtained from the chat."""
- cost: tuple = None # (dict, dict) - (total_cost, actual_cost_with_cache)
- """The cost of the chat. a tuple of (total_cost, total_actual_cost), where total_cost is a
- dictionary of cost information, and total_actual_cost is a dictionary of information on
- the actual incurred cost with cache."""
+ cost: Dict[str, dict] = None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference"
+ """The cost of the chat.
+ The value for each usage type is a dictionary containing cost information for that specific type.
+ - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
+ - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
+ """
human_input: List[str] = None
"""A list of human input solicited during the chat."""
@@ -105,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
return chat_order
+def _post_process_carryover_item(carryover_item):
+ if isinstance(carryover_item, str):
+ return carryover_item
+ elif isinstance(carryover_item, dict) and "content" in carryover_item:
+ return str(carryover_item["content"])
+ else:
+ return str(carryover_item)
+
+
def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
iostream = IOStream.get_default()
@@ -114,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
UserWarning,
)
print_carryover = (
- ("\n").join([t for t in chat_info["carryover"]])
+ ("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
@@ -151,7 +162,7 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
For example:
- `"sender"` - the sender agent.
- `"recipient"` - the recipient agent.
- - `"clear_history" (bool) - whether to clear the chat history with the agent.
+ - `"clear_history"` (bool) - whether to clear the chat history with the agent.
Default is True.
- `"silent"` (bool or None) - (Experimental) whether to print the messages in this
conversation. Default is False.
@@ -169,6 +180,9 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
- `"carryover"` - It can be used to specify the carryover information to be passed
to this chat. If provided, we will combine this carryover with the "message" content when
generating the initial chat message in `generate_init_message`.
+            - `"finished_chat_indexes_to_exclude_from_carryover"` - A list of indexes into the finished_chats list whose summaries
+              should be excluded from the carryover. If not provided, or provided as an empty list,
+              the summaries of all finished chats are carried over.
Returns:
(list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
"""
@@ -180,10 +194,19 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
while current_chat_queue:
chat_info = current_chat_queue.pop(0)
_chat_carryover = chat_info.get("carryover", [])
+ finished_chat_indexes_to_exclude_from_carryover = chat_info.get(
+ "finished_chat_indexes_to_exclude_from_carryover", []
+ )
+
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
- chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats]
- __post_carryover_processing(chat_info)
+ chat_info["carryover"] = _chat_carryover + [
+ r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover
+ ]
+
+ if not chat_info.get("silent", False):
+ __post_carryover_processing(chat_info)
+
sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info)
finished_chats.append(chat_res)
@@ -212,6 +235,9 @@ async def _dependent_chat_future(
"""
logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
_chat_carryover = chat_info.get("carryover", [])
+ finished_chat_indexes_to_exclude_from_carryover = chat_info.get(
+ "finished_chat_indexes_to_exclude_from_carryover", []
+ )
finished_chats = dict()
for chat in prerequisite_chat_futures:
chat_future = prerequisite_chat_futures[chat]
@@ -223,8 +249,15 @@ async def _dependent_chat_future(
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
- chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
- __post_carryover_processing(chat_info)
+ data = [
+ chat_result.summary
+ for chat_id, chat_result in finished_chats.items()
+ if chat_id not in finished_chat_indexes_to_exclude_from_carryover
+ ]
+ chat_info["carryover"] = _chat_carryover + data
+ if not chat_info.get("silent", False):
+ __post_carryover_processing(chat_info)
+
sender = chat_info["sender"]
chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
diff --git a/autogen/agentchat/contrib/agent_builder.py b/autogen/agentchat/contrib/agent_builder.py
index a257a6dcf61..c9a2d79607d 100644
--- a/autogen/agentchat/contrib/agent_builder.py
+++ b/autogen/agentchat/contrib/agent_builder.py
@@ -1,12 +1,20 @@
import hashlib
+import importlib
import json
+import logging
+import re
import socket
import subprocess as sp
import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
+
+import requests
+from termcolor import colored
import autogen
+logger = logging.getLogger(__name__)
+
def _config_check(config: Dict):
# check config loading
@@ -16,113 +24,162 @@ def _config_check(config: Dict):
for agent_config in config["agent_configs"]:
assert agent_config.get("name", None) is not None, 'Missing agent "name" in your agent_configs.'
- assert agent_config.get("model", None) is not None, 'Missing agent "model" in your agent_configs.'
assert (
agent_config.get("system_message", None) is not None
), 'Missing agent "system_message" in your agent_configs.'
assert agent_config.get("description", None) is not None, 'Missing agent "description" in your agent_configs.'
+def _retrieve_json(text):
+ match = re.findall(autogen.code_utils.CODE_BLOCK_PATTERN, text, flags=re.DOTALL)
+ if not match:
+ return text
+ code_blocks = []
+ for _, code in match:
+ code_blocks.append(code)
+ return code_blocks[0]
+
+
class AgentBuilder:
"""
AgentBuilder can help user build an automatic task solving process powered by multi-agent system.
Specifically, our building pipeline includes initialize and build.
- In build(), we prompt a LLM to create multiple participant agents, and specify whether this task need programming to solve.
- User can save the built agents' config by calling save(), and load the saved configs by load(), which can skip the
- building process.
"""
online_server_name = "online"
+    DEFAULT_PROXY_AUTO_REPLY = 'There is no code from the last 1 message for me to execute. The group chat manager should let other participants continue the conversation. If the group chat manager wants to end the conversation, it should let another participant reply to me with only "TERMINATE".'
+
+ GROUP_CHAT_DESCRIPTION = """ # Group chat instruction
+You are now working in a group chat with different experts and a group chat manager.
+You should refer to the previous messages from the other participants or yourself, follow their topic, and reply to them.
+
+**Your role is**: {name}
+Group chat members: {members}{user_proxy_desc}
+
+When the task is complete and the result has been carefully verified, after obtaining agreement from the other members, you can end the conversation by replying only with "TERMINATE".
+
+# Your profile
+{sys_msg}
+"""
+
+ DEFAULT_DESCRIPTION = """## Your role
+[Complete this part with expert's name and skill description]
+
+## Task and skill instructions
+- [Complete this part with task description]
+- [Complete this part with skill description]
+- [(Optional) Complete this part with other information]
+"""
+
+ CODING_AND_TASK_SKILL_INSTRUCTION = """## Useful instructions for task-solving
+- Solve the task step by step if you need to.
+- When you find an answer, verify the answer carefully. Include verifiable evidence, with possible test cases, in your response if possible.
+- All your replies should be based on the provided facts.
+
+## How to verify?
+**You have to keep believing that everyone else's answers are wrong until they provide clear enough evidence.**
+- Verify with step-by-step backward reasoning.
+- Write test cases according to the general task.
+
+## How to use code?
+- Suggest python code (in a python coding block) or shell script (in a sh coding block) for the Computer_terminal to execute.
+- If python packages are missing, you can install them by suggesting `pip install` code in a ```sh ... ``` block.
+- When using code, you must indicate the script type in the coding block.
+- Do not suggest a coding block that requires users to modify it.
+- Do not suggest a coding block if it's not intended to be executed by the Computer_terminal.
+- The Computer_terminal cannot modify your code.
+- **Use 'print' function for the output when relevant**.
+- Check the execution result returned by the Computer_terminal.
+- Do not ask Computer_terminal to copy and paste the result.
+- If the result indicates there is an error, fix the error and output the code again. """
+
CODING_PROMPT = """Does the following task need programming (i.e., access external API or tool by coding) to solve,
- or coding may help the following task become easier?
+or coding may help the following task become easier?
- TASK: {task}
+TASK: {task}
- Hint:
- # Answer only YES or NO.
- """
+Answer only YES or NO.
+"""
- AGENT_NAME_PROMPT = """To complete the following task, what positions/jobs should be set to maximize efficiency?
-
- TASK: {task}
-
- Hint:
- # Considering the effort, the position in this task should be no more than {max_agents}; less is better.
- # These positions' name should include enough information that can help a group chat manager know when to let this position speak.
- # The position name should be as specific as possible. For example, use "python_programmer" instead of "programmer".
- # Do not use ambiguous position name, such as "domain expert" with no specific description of domain or "technical writer" with no description of what it should write.
- # Each position should have a unique function and the position name should reflect this.
- # The positions should relate to the task and significantly different in function.
- # Add ONLY ONE programming related position if the task needs coding.
- # Generated agent's name should follow the format of ^[a-zA-Z0-9_-]{{1,64}}$, use "_" to split words.
- # Answer the names of those positions/jobs, separated names by commas.
- # Only return the list of positions.
- """
+ AGENT_NAME_PROMPT = """# Your task
+Suggest no more than {max_agents} experts, with their names, according to the following user requirement.
- AGENT_SYS_MSG_PROMPT = """Considering the following position and task:
+## User requirement
+{task}
- TASK: {task}
- POSITION: {position}
+# Task requirement
+- Expert's name should follow the format: [skill]_Expert.
+- Only reply with the names of the experts, separated by ",".
+For example: Python_Expert, Math_Expert, ... """
- Modify the following position requirement, making it more suitable for the above task and position:
+ AGENT_SYS_MSG_PROMPT = """# Your goal
+- According to the task and expert name, write a high-quality description for the expert by filling the given template.
+- Ensure that your description is clear and unambiguous, and includes all necessary information.
- REQUIREMENT: {default_sys_msg}
+# Task
+{task}
- Hint:
- # Your answer should be natural, starting from "You are now in a group chat. You need to complete a task with other participants. As a ...".
- # [IMPORTANT] You should let them reply "TERMINATE" when they think the task is completed (the user's need has actually been satisfied).
- # The modified requirement should not contain the code interpreter skill.
- # You should remove the related skill description when the position is not a programmer or developer.
- # Coding skill is limited to Python.
- # Your answer should omit the word "REQUIREMENT".
- # People with the above position can doubt previous messages or code in the group chat (for example, if there is no
-output after executing the code) and provide a corrected answer or code.
- # People in the above position should ask for help from the group chat manager when confused and let the manager select another participant.
- """
+# Expert name
+{position}
- AGENT_DESCRIPTION_PROMPT = """Considering the following position:
+# Template
+{default_sys_msg}
+"""
- POSITION: {position}
+ AGENT_DESCRIPTION_PROMPT = """# Your goal
+Summarize the following expert's description in a sentence.
- What requirements should this position be satisfied?
+# Expert name
+{position}
- Hint:
- # This description should include enough information that can help a group chat manager know when to let this position speak.
- # People with the above position can doubt previous messages or code in the group chat (for example, if there is no
-output after executing the code) and provide a corrected answer or code.
- # Your answer should be in at most three sentences.
- # Your answer should be natural, starting from "[POSITION's name] is a ...".
- # Your answer should include the skills that this position should have.
- # Your answer should not contain coding-related skills when the position is not a programmer or developer.
- # Coding skills should be limited to Python.
- """
+# Expert's description
+{sys_msg}
+"""
- AGENT_SEARCHING_PROMPT = """Considering the following task:
+ AGENT_SEARCHING_PROMPT = """# Your goal
+Considering the following task, what experts should be involved in the task?
- TASK: {task}
+# TASK
+{task}
- What following agents should be involved to the task?
+# EXPERT LIST
+{agent_list}
- AGENT LIST:
- {agent_list}
+# Requirement
+- You should consider whether the experts' names and profiles match the task.
+- Considering the effort, you should select fewer than {max_agents} experts; fewer is better.
+- Separate expert names by commas and use "_" instead of space. For example, Product_manager,Programmer
+- Only return the list of expert names.
+"""
- Hint:
- # You should consider if the agent's name and profile match the task.
- # Considering the effort, you should select less then {max_agents} agents; less is better.
- # Separate agent names by commas and use "_" instead of space. For example, Product_manager,Programmer
- # Only return the list of agent names.
- """
+ AGENT_SELECTION_PROMPT = """# Your goal
+Match each skill in the skill set to an expert in the expert pool.
+
+# Skill set
+{skills}
+
+# Expert pool (formatting with name: description)
+{expert_pool}
+
+# Answer format
+```json
+{{
+    "skill_1 description": "expert_name: expert_description", // if there exists an expert that is suitable for skill_1
+    "skill_2 description": "None", // if no expert is suitable for skill_2
+ ...
+}}
+```
+"""
def __init__(
self,
config_file_or_env: Optional[str] = "OAI_CONFIG_LIST",
config_file_location: Optional[str] = "",
- builder_model: Optional[str] = "gpt-4",
- agent_model: Optional[str] = "gpt-4",
- host: Optional[str] = "localhost",
- endpoint_building_timeout: Optional[int] = 600,
- max_tokens: Optional[int] = 945,
+ builder_model: Optional[Union[str, list]] = [],
+ agent_model: Optional[Union[str, list]] = [],
+ builder_model_tags: Optional[list] = [],
+ agent_model_tags: Optional[list] = [],
max_agents: Optional[int] = 5,
):
"""
@@ -131,17 +188,27 @@ def __init__(
config_file_or_env: path or environment of the OpenAI api configs.
builder_model: specify a model as the backbone of build manager.
agent_model: specify a model as the backbone of participant agents.
- host: endpoint host.
endpoint_building_timeout: timeout for building up an endpoint server.
- max_tokens: max tokens for each agent.
max_agents: max agents for each task.
"""
- self.host = host
- self.builder_model = builder_model
- self.agent_model = agent_model
+ builder_model = builder_model if isinstance(builder_model, list) else [builder_model]
+ builder_filter_dict = {}
+ if len(builder_model) != 0:
+ builder_filter_dict.update({"model": builder_model})
+ if len(builder_model_tags) != 0:
+ builder_filter_dict.update({"tags": builder_model_tags})
+ builder_config_list = autogen.config_list_from_json(config_file_or_env, filter_dict=builder_filter_dict)
+ if len(builder_config_list) == 0:
+ raise RuntimeError(
+ f"Fail to initialize build manager: {builder_model}{builder_model_tags} does not exist in {config_file_or_env}. "
+ f'If you want to change this model, please specify the "builder_model" in the constructor.'
+ )
+ self.builder_model = autogen.OpenAIWrapper(config_list=builder_config_list)
+
+ self.agent_model = agent_model if isinstance(agent_model, list) else [agent_model]
+ self.agent_model_tags = agent_model_tags
self.config_file_or_env = config_file_or_env
self.config_file_location = config_file_location
- self.endpoint_building_timeout = endpoint_building_timeout
self.building_task: str = None
self.agent_configs: List[Dict] = []
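A sketch of the updated constructor, which now accepts model lists and tag filters instead of a single model name; the model name and tag below are examples of entries one might have in OAI_CONFIG_LIST, not values required by this change.

```python
# Filter the build manager by model name and the participant agents by tags.
from autogen.agentchat.contrib.agent_builder import AgentBuilder

builder = AgentBuilder(
    config_file_or_env="OAI_CONFIG_LIST",
    builder_model=["gpt-4-1106-preview"],  # filter the build manager's config list by model name
    agent_model_tags=["gpt-4"],            # or filter participant agents' config lists by tags
    max_agents=5,
)
# Raises RuntimeError at construction time if no config matches the builder filter.
```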
@@ -150,40 +217,20 @@ def __init__(
self.agent_procs_assign: Dict[str, Tuple[autogen.ConversableAgent, str]] = {}
self.cached_configs: Dict = {}
- self.max_tokens = max_tokens
self.max_agents = max_agents
- for port in range(8000, 65535):
- if self._is_port_open(host, port):
- self.open_ports.append(str(port))
-
def set_builder_model(self, model: str):
self.builder_model = model
def set_agent_model(self, model: str):
self.agent_model = model
- @staticmethod
- def _is_port_open(host, port):
- """Check if a tcp port is open."""
- try:
- s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- s.settimeout(10)
- s.bind((host, int(port)))
- s.close()
- return True
- except OSError:
- return False
-
def _create_agent(
self,
- agent_name: str,
- model_name_or_hf_repo: str,
+ agent_config: Dict,
+ member_name: List[str],
llm_config: dict,
- system_message: Optional[str] = autogen.AssistantAgent.DEFAULT_SYSTEM_MESSAGE,
- description: Optional[str] = autogen.AssistantAgent.DEFAULT_DESCRIPTION,
use_oai_assistant: Optional[bool] = False,
- world_size: Optional[int] = 1,
) -> autogen.AssistantAgent:
"""
Create a group chat participant agent.
@@ -192,100 +239,46 @@ def _create_agent(
The API address of that endpoint will be "localhost:{free port}".
Args:
- agent_name: the name that identify the function of the agent (e.g., Coder, Product Manager,...)
- model_name_or_hf_repo: the name of the model or the huggingface repo.
+ agent_config: agent's config. It should include the following information:
+                1. model: the backbone model(s) of an agent, e.g., gpt-4-1106-preview, meta/Llama-2-70b-chat
+                2. name: used to identify an agent in the group chat.
+                3. system_message: the agent's persona, task-solving instructions, etc.
+                4. description: a brief description of the agent that helps the group chat manager pick the next speaker.
llm_config: specific configs for LLM (e.g., config_list, seed, temperature, ...).
- system_message: system prompt use to format an agent's behavior.
- description: a brief description of the agent. This will improve the group chat performance.
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
world_size: the max size of parallel tensors (in most of the cases, this is identical to the amount of GPUs).
Returns:
agent: a set-up agent.
"""
- from huggingface_hub import HfApi
- from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
-
+ model_name_or_hf_repo = agent_config.get("model", [])
+ model_name_or_hf_repo = (
+ model_name_or_hf_repo if isinstance(model_name_or_hf_repo, list) else [model_name_or_hf_repo]
+ )
+ model_tags = agent_config.get("tags", [])
+ agent_name = agent_config["name"]
+ system_message = agent_config["system_message"]
+ description = agent_config["description"]
+
+        # Path to the customized **ConversableAgent** class.
+ model_path = agent_config.get("model_path", None)
+ filter_dict = {}
+ if len(model_name_or_hf_repo) > 0:
+ filter_dict.update({"model": model_name_or_hf_repo})
+ if len(model_tags) > 0:
+ filter_dict.update({"tags": model_tags})
config_list = autogen.config_list_from_json(
- self.config_file_or_env,
- file_location=self.config_file_location,
- filter_dict={"model": [model_name_or_hf_repo]},
+ self.config_file_or_env, file_location=self.config_file_location, filter_dict=filter_dict
)
if len(config_list) == 0:
raise RuntimeError(
- f"Fail to initialize agent {agent_name}: {model_name_or_hf_repo} does not exist in {self.config_file_or_env}.\n"
+ f"Fail to initialize agent {agent_name}: {model_name_or_hf_repo}{model_tags} does not exist in {self.config_file_or_env}.\n"
f'If you would like to change this model, please specify the "agent_model" in the constructor.\n'
f"If you load configs from json, make sure the model in agent_configs is in the {self.config_file_or_env}."
)
- try:
- hf_api = HfApi()
- hf_api.model_info(model_name_or_hf_repo)
- model_name = model_name_or_hf_repo.split("/")[-1]
- server_id = f"{model_name}_{self.host}"
- except GatedRepoError as e:
- raise e
- except RepositoryNotFoundError:
- server_id = self.online_server_name
-
- if server_id != self.online_server_name:
- # The code in this block is uncovered by tests because online environment does not support gpu use.
- if self.agent_procs.get(server_id, None) is None:
- while True:
- port = self.open_ports.pop()
- if self._is_port_open(self.host, port):
- break
-
- # Use vLLM to set up a server with OpenAI API support.
- agent_proc = sp.Popen(
- [
- "python",
- "-m",
- "vllm.entrypoints.openai.api_server",
- "--host",
- f"{self.host}",
- "--port",
- f"{port}",
- "--model",
- f"{model_name_or_hf_repo}",
- "--tensor-parallel-size",
- f"{world_size}",
- ],
- stdout=sp.PIPE,
- stderr=sp.STDOUT,
- )
- timeout_start = time.time()
-
- while True:
- server_stdout = agent_proc.stdout.readline()
- if server_stdout != b"":
- print(server_stdout)
- timeout_end = time.time()
- if b"running" in server_stdout:
- print(
- f"Running {model_name_or_hf_repo} on http://{self.host}:{port} "
- f"with tensor parallel size {world_size}."
- )
- break
- elif b"address already in use" in server_stdout:
- raise RuntimeError(
- f"{self.host}:{port} already in use. Fail to set up the endpoint for "
- f"{model_name_or_hf_repo} on {self.host}:{port}."
- )
- elif timeout_end - timeout_start > self.endpoint_building_timeout:
- raise RuntimeError(
- f"Timeout exceed. Fail to set up the endpoint for "
- f"{model_name_or_hf_repo} on {self.host}:{port}."
- )
- self.agent_procs[server_id] = (agent_proc, port)
- else:
- port = self.agent_procs[server_id][1]
-
- config_list[0]["base_url"] = f"http://{self.host}:{port}/v1"
-
+ server_id = self.online_server_name
current_config = llm_config.copy()
- current_config.update(
- {"config_list": config_list, "model": model_name_or_hf_repo, "max_tokens": self.max_tokens}
- )
+ current_config.update({"config_list": config_list})
if use_oai_assistant:
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
@@ -296,12 +289,38 @@ def _create_agent(
overwrite_instructions=False,
)
else:
- agent = autogen.AssistantAgent(
- name=agent_name,
- llm_config=current_config.copy(),
- system_message=system_message,
- description=description,
+ user_proxy_desc = ""
+ if self.cached_configs["coding"] is True:
+ user_proxy_desc = (
+                    "\nThe group also includes a Computer_terminal to help you run python and shell code."
+ )
+
+ model_class = autogen.AssistantAgent
+ if model_path:
+ module_path, model_class_name = model_path.replace("/", ".").rsplit(".", 1)
+ module = importlib.import_module(module_path)
+ model_class = getattr(module, model_class_name)
+ if not issubclass(model_class, autogen.ConversableAgent):
+ logger.error(f"{model_class} is not a ConversableAgent. Use AssistantAgent as default")
+ model_class = autogen.AssistantAgent
+
+ additional_config = {
+ k: v
+ for k, v in agent_config.items()
+ if k not in ["model", "name", "system_message", "description", "model_path", "tags"]
+ }
+ agent = model_class(
+ name=agent_name, llm_config=current_config.copy(), description=description, **additional_config
)
+ if system_message == "":
+ system_message = agent.system_message
+ else:
+ system_message = f"{system_message}\n\n{self.CODING_AND_TASK_SKILL_INSTRUCTION}"
+
+ enhanced_sys_msg = self.GROUP_CHAT_DESCRIPTION.format(
+ name=agent_name, members=member_name, user_proxy_desc=user_proxy_desc, sys_msg=system_message
+ )
+ agent.update_system_message(enhanced_sys_msg)
self.agent_procs_assign[agent_name] = (agent, server_id)
return agent
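A hypothetical `agent_config` for the reworked `_create_agent`; every value here (model name, tag, and the custom class path) is illustrative only.

```python
# Keys not consumed by _create_agent itself are forwarded to the agent class.
agent_config = {
    "name": "Data_Analysis_Expert",
    "model": ["gpt-4-1106-preview"],   # optional filter on the "model" field
    "tags": ["gpt-4"],                 # optional filter on the "tags" field
    "system_message": "You analyze tabular data and report your findings.",
    "description": "An expert in exploratory data analysis.",
    # Optional dotted (or slash-separated) path to a custom ConversableAgent
    # subclass, e.g. class MyAssistant defined in a local module my_agents.py.
    "model_path": "my_agents.MyAssistant",
}
```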
@@ -325,7 +344,7 @@ def clear_agent(self, agent_name: str, recycle_endpoint: Optional[bool] = True):
return
self.agent_procs[server_id][0].terminate()
self.open_ports.append(server_id.split("_")[-1])
- print(f"Agent {agent_name} has been cleared.")
+ print(colored(f"Agent {agent_name} has been cleared.", "yellow"), flush=True)
def clear_all_agents(self, recycle_endpoint: Optional[bool] = True):
"""
@@ -333,7 +352,7 @@ def clear_all_agents(self, recycle_endpoint: Optional[bool] = True):
"""
for agent_name in [agent_name for agent_name in self.agent_procs_assign.keys()]:
self.clear_agent(agent_name, recycle_endpoint)
- print("All agents have been cleared.")
+ print(colored("All agents have been cleared.", "yellow"), flush=True)
def build(
self,
@@ -342,6 +361,8 @@ def build(
coding: Optional[bool] = None,
code_execution_config: Optional[Dict] = None,
use_oai_assistant: Optional[bool] = False,
+ user_proxy: Optional[autogen.ConversableAgent] = None,
+ max_agents: Optional[int] = None,
**kwargs,
) -> Tuple[List[autogen.ConversableAgent], Dict]:
"""
@@ -353,6 +374,7 @@ def build(
code_execution_config: specific configs for user proxy (e.g., last_n_messages, work_dir, ...).
default_llm_config: specific configs for LLM (e.g., config_list, seed, temperature, ...).
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
+            user_proxy: a user proxy agent instance that can be used to replace the default user proxy.
Returns:
agent_list: a list of agents.
@@ -360,34 +382,25 @@ def build(
"""
if code_execution_config is None:
code_execution_config = {
- "last_n_messages": 2,
+ "last_n_messages": 1,
"work_dir": "groupchat",
"use_docker": False,
- "timeout": 60,
+ "timeout": 10,
}
+ if max_agents is None:
+ max_agents = self.max_agents
+
agent_configs = []
self.building_task = building_task
- config_list = autogen.config_list_from_json(
- self.config_file_or_env,
- file_location=self.config_file_location,
- filter_dict={"model": [self.builder_model]},
- )
- if len(config_list) == 0:
- raise RuntimeError(
- f"Fail to initialize build manager: {self.builder_model} does not exist in {self.config_file_or_env}. "
- f'If you want to change this model, please specify the "builder_model" in the constructor.'
- )
- build_manager = autogen.OpenAIWrapper(config_list=config_list)
-
- print("==> Generating agents...")
+ print(colored("==> Generating agents...", "green"), flush=True)
resp_agent_name = (
- build_manager.create(
+ self.builder_model.create(
messages=[
{
"role": "user",
- "content": self.AGENT_NAME_PROMPT.format(task=building_task, max_agents=self.max_agents),
+ "content": self.AGENT_NAME_PROMPT.format(task=building_task, max_agents=max_agents),
}
]
)
@@ -395,21 +408,21 @@ def build(
.message.content
)
agent_name_list = [agent_name.strip().replace(" ", "_") for agent_name in resp_agent_name.split(",")]
- print(f"{agent_name_list} are generated.")
+ print(f"{agent_name_list} are generated.", flush=True)
- print("==> Generating system message...")
+ print(colored("==> Generating system message...", "green"), flush=True)
agent_sys_msg_list = []
for name in agent_name_list:
- print(f"Preparing system message for {name}")
+ print(f"Preparing system message for {name}", flush=True)
resp_agent_sys_msg = (
- build_manager.create(
+ self.builder_model.create(
messages=[
{
"role": "user",
"content": self.AGENT_SYS_MSG_PROMPT.format(
task=building_task,
position=name,
- default_sys_msg=autogen.AssistantAgent.DEFAULT_SYSTEM_MESSAGE,
+ default_sys_msg=self.DEFAULT_DESCRIPTION,
),
}
]
@@ -419,16 +432,16 @@ def build(
)
agent_sys_msg_list.append(resp_agent_sys_msg)
- print("==> Generating description...")
+ print(colored("==> Generating description...", "green"), flush=True)
agent_description_list = []
- for name in agent_name_list:
- print(f"Preparing description for {name}")
+ for name, sys_msg in list(zip(agent_name_list, agent_sys_msg_list)):
+ print(f"Preparing description for {name}", flush=True)
resp_agent_description = (
- build_manager.create(
+ self.builder_model.create(
messages=[
{
"role": "user",
- "content": self.AGENT_DESCRIPTION_PROMPT.format(position=name),
+ "content": self.AGENT_DESCRIPTION_PROMPT.format(position=name, sys_msg=sys_msg),
}
]
)
@@ -439,12 +452,18 @@ def build(
for name, sys_msg, description in list(zip(agent_name_list, agent_sys_msg_list, agent_description_list)):
agent_configs.append(
- {"name": name, "model": self.agent_model, "system_message": sys_msg, "description": description}
+ {
+ "name": name,
+ "model": self.agent_model,
+ "tags": self.agent_model_tags,
+ "system_message": sys_msg,
+ "description": description,
+ }
)
if coding is None:
resp = (
- build_manager.create(
+ self.builder_model.create(
messages=[{"role": "user", "content": self.CODING_PROMPT.format(task=building_task)}]
)
.choices[0]
@@ -461,18 +480,20 @@ def build(
"code_execution_config": code_execution_config,
}
)
-
- return self._build_agents(use_oai_assistant, **kwargs)
+ _config_check(self.cached_configs)
+ return self._build_agents(use_oai_assistant, user_proxy=user_proxy, **kwargs)
def build_from_library(
self,
building_task: str,
library_path_or_json: str,
default_llm_config: Dict,
- coding: Optional[bool] = True,
+ top_k: int = 3,
+ coding: Optional[bool] = None,
code_execution_config: Optional[Dict] = None,
use_oai_assistant: Optional[bool] = False,
- embedding_model: Optional[str] = None,
+ embedding_model: Optional[str] = "all-mpnet-base-v2",
+ user_proxy: Optional[autogen.ConversableAgent] = None,
**kwargs,
) -> Tuple[List[autogen.ConversableAgent], Dict]:
"""
@@ -488,81 +509,83 @@ def build_from_library(
code_execution_config: specific configs for user proxy (e.g., last_n_messages, work_dir, ...).
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
embedding_model: a Sentence-Transformers model use for embedding similarity to select agents from library.
- if None, an openai model will be prompted to select agents. As reference, chromadb use "all-mpnet-base-
- v2" as default.
+ As reference, chromadb use "all-mpnet-base-v2" as default.
+            user_proxy: a user proxy agent instance that can be used to replace the default user proxy.
Returns:
agent_list: a list of agents.
cached_configs: cached configs.
"""
+ import sqlite3
+
+    # Some systems ship a sqlite3 library that is too old for chromadb.
+    # Check if the user has installed pysqlite3 as a drop-in replacement.
+    if sqlite3.sqlite_version_info < (3, 35, 0):  # chromadb requires SQLite >= 3.35
+ try:
+ __import__("pysqlite3")
+ import sys
+
+ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+ except Exception as e:
+ raise e
import chromadb
from chromadb.utils import embedding_functions
if code_execution_config is None:
code_execution_config = {
- "last_n_messages": 2,
+ "last_n_messages": 1,
"work_dir": "groupchat",
"use_docker": False,
- "timeout": 60,
+ "timeout": 120,
}
- agent_configs = []
-
- config_list = autogen.config_list_from_json(
- self.config_file_or_env,
- file_location=self.config_file_location,
- filter_dict={"model": [self.builder_model]},
- )
- if len(config_list) == 0:
- raise RuntimeError(
- f"Fail to initialize build manager: {self.builder_model} does not exist in {self.config_file_or_env}. "
- f'If you want to change this model, please specify the "builder_model" in the constructor.'
- )
- build_manager = autogen.OpenAIWrapper(config_list=config_list)
-
try:
agent_library = json.loads(library_path_or_json)
except json.decoder.JSONDecodeError:
with open(library_path_or_json, "r") as f:
agent_library = json.load(f)
+ except Exception as e:
+ raise e
- print("==> Looking for suitable agents in library...")
- if embedding_model is not None:
- chroma_client = chromadb.Client()
- collection = chroma_client.create_collection(
- name="agent_list",
- embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model),
- )
- collection.add(
- documents=[agent["profile"] for agent in agent_library],
- metadatas=[{"source": "agent_profile"} for _ in range(len(agent_library))],
- ids=[f"agent_{i}" for i in range(len(agent_library))],
- )
- agent_profile_list = collection.query(query_texts=[building_task], n_results=self.max_agents)["documents"][
- 0
- ]
-
- # search name from library
- agent_name_list = []
- for profile in agent_profile_list:
- for agent in agent_library:
- if agent["profile"] == profile:
- agent_name_list.append(agent["name"])
- break
- chroma_client.delete_collection(collection.name)
- print(f"{agent_name_list} are selected.")
- else:
- agent_profiles = [
- f"No.{i + 1} AGENT's NAME: {agent['name']}\nNo.{i + 1} AGENT's PROFILE: {agent['profile']}\n\n"
- for i, agent in enumerate(agent_library)
- ]
- resp_agent_name = (
- build_manager.create(
+ print(colored("==> Looking for suitable agents in the library...", "green"), flush=True)
+ skills = building_task.replace(":", " ").split("\n")
+ # skills = [line.split("-", 1)[1].strip() if line.startswith("-") else line for line in lines]
+ if len(skills) == 0:
+ skills = [building_task]
+
+ chroma_client = chromadb.Client()
+ collection = chroma_client.create_collection(
+ name="agent_list",
+ embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model),
+ )
+ collection.add(
+ documents=[agent["description"] for agent in agent_library],
+ metadatas=[{"source": "agent_profile"} for _ in range(len(agent_library))],
+ ids=[f"agent_{i}" for i in range(len(agent_library))],
+ )
+ agent_desc_list = set()
+ for skill in skills:
+ recall = set(collection.query(query_texts=[skill], n_results=top_k)["documents"][0])
+ agent_desc_list = agent_desc_list.union(recall)
+
+ agent_config_list = []
+ for description in list(agent_desc_list):
+ for agent in agent_library:
+ if agent["description"] == description:
+ agent_config_list.append(agent.copy())
+ break
+ chroma_client.delete_collection(collection.name)
+
+    # Second-stage selection: ask the builder model to match the required skills to the recalled experts
+ expert_pool = [f"{agent['name']}: {agent['description']}" for agent in agent_config_list]
+ while True:
+ skill_agent_pair_json = (
+ self.builder_model.create(
messages=[
{
"role": "user",
- "content": self.AGENT_SEARCHING_PROMPT.format(
- task=building_task, agent_list="".join(agent_profiles), max_agents=self.max_agents
+ "content": self.AGENT_SELECTION_PROMPT.format(
+ skills=building_task, expert_pool=expert_pool, max_agents=self.max_agents
),
}
]
@@ -570,48 +593,45 @@ def build_from_library(
.choices[0]
.message.content
)
- agent_name_list = [agent_name.strip().replace(" ", "_") for agent_name in resp_agent_name.split(",")]
-
- # search profile from library
- agent_profile_list = []
- for name in agent_name_list:
- for agent in agent_library:
- if agent["name"] == name:
- agent_profile_list.append(agent["profile"])
- break
- print(f"{agent_name_list} are selected.")
-
- print("==> Generating system message...")
- # generate system message from profile
- agent_sys_msg_list = []
- for name, profile in list(zip(agent_name_list, agent_profile_list)):
- print(f"Preparing system message for {name}...")
- resp_agent_sys_msg = (
- build_manager.create(
- messages=[
- {
- "role": "user",
- "content": self.AGENT_SYS_MSG_PROMPT.format(
- task=building_task,
- position=f"{name}\nPOSITION PROFILE: {profile}",
- default_sys_msg=autogen.AssistantAgent.DEFAULT_SYSTEM_MESSAGE,
- ),
- }
- ]
+ try:
+ skill_agent_pair_json = _retrieve_json(skill_agent_pair_json)
+ skill_agent_pair = json.loads(skill_agent_pair_json)
+ break
+ except Exception as e:
+ print(e, flush=True)
+ time.sleep(5)
+ continue
+
+ recalled_agent_config_list = []
+ recalled_name_desc = []
+ for skill, agent_profile in skill_agent_pair.items():
+ # If no suitable agent, generate an agent
+ if agent_profile == "None":
+ _, agent_config_temp = self.build(
+ building_task=skill,
+ default_llm_config=default_llm_config.copy(),
+ coding=False,
+ use_oai_assistant=use_oai_assistant,
+ max_agents=1,
)
- .choices[0]
- .message.content
- )
- agent_sys_msg_list.append(resp_agent_sys_msg)
-
- for name, sys_msg, description in list(zip(agent_name_list, agent_sys_msg_list, agent_profile_list)):
- agent_configs.append(
- {"name": name, "model": self.agent_model, "system_message": sys_msg, "description": description}
- )
+ self.clear_agent(agent_config_temp["agent_configs"][0]["name"])
+ recalled_agent_config_list.append(agent_config_temp["agent_configs"][0])
+ else:
+ if agent_profile in recalled_name_desc:
+ # prevent identical agents
+ continue
+ recalled_name_desc.append(agent_profile)
+ name = agent_profile.split(":")[0].strip()
+ desc = agent_profile.split(":")[1].strip()
+ for agent in agent_config_list:
+ if name == agent["name"] and desc == agent["description"]:
+ recalled_agent_config_list.append(agent.copy())
+
+ print(f"{[agent['name'] for agent in recalled_agent_config_list]} are selected.", flush=True)
if coding is None:
resp = (
- build_manager.create(
+ self.builder_model.create(
messages=[{"role": "user", "content": self.CODING_PROMPT.format(task=building_task)}]
)
.choices[0]
@@ -622,23 +642,25 @@ def build_from_library(
self.cached_configs.update(
{
"building_task": building_task,
- "agent_configs": agent_configs,
+ "agent_configs": recalled_agent_config_list,
"coding": coding,
"default_llm_config": default_llm_config,
"code_execution_config": code_execution_config,
}
)
+ _config_check(self.cached_configs)
- return self._build_agents(use_oai_assistant, **kwargs)
+ return self._build_agents(use_oai_assistant, user_proxy=user_proxy, **kwargs)
def _build_agents(
- self, use_oai_assistant: Optional[bool] = False, **kwargs
+ self, use_oai_assistant: Optional[bool] = False, user_proxy: Optional[autogen.ConversableAgent] = None, **kwargs
) -> Tuple[List[autogen.ConversableAgent], Dict]:
"""
Build agents with generated configs.
Args:
use_oai_assistant: use OpenAI assistant api instead of self-constructed agent.
+            user_proxy: a user proxy agent instance that can be used to replace the default user proxy.
Returns:
agent_list: a list of agents.
@@ -649,37 +671,29 @@ def _build_agents(
coding = self.cached_configs["coding"]
code_execution_config = self.cached_configs["code_execution_config"]
- print("==> Creating agents...")
+ print(colored("==> Creating agents...", "green"), flush=True)
for config in agent_configs:
- print(f"Creating agent {config['name']} with backbone {config['model']}...")
+ print(f"Creating agent {config['name']}...", flush=True)
self._create_agent(
- config["name"],
- config["model"],
- default_llm_config,
- system_message=config["system_message"],
- description=config["description"],
+ agent_config=config.copy(),
+ member_name=[agent["name"] for agent in agent_configs],
+ llm_config=default_llm_config,
use_oai_assistant=use_oai_assistant,
**kwargs,
)
agent_list = [agent_config[0] for agent_config in self.agent_procs_assign.values()]
if coding is True:
- print("Adding user console proxy...")
- agent_list = (
- [
- autogen.UserProxyAgent(
- name="User_console_and_code_interpreter",
- is_termination_msg=lambda x: "TERMINATE" in x.get("content"),
- system_message="User console with a python code interpreter interface.",
- description="""A user console with a code interpreter interface.
-It can provide the code execution results. Select this player when other players provide some code that needs to be executed.
-DO NOT SELECT THIS PLAYER WHEN NO CODE TO EXECUTE; IT WILL NOT ANSWER ANYTHING.""",
- code_execution_config=code_execution_config,
- human_input_mode="NEVER",
- )
- ]
- + agent_list
- )
+ print("Adding user console proxy...", flush=True)
+ if user_proxy is None:
+ user_proxy = autogen.UserProxyAgent(
+ name="Computer_terminal",
+                    is_termination_msg=lambda x: x.get("content", "") in ("TERMINATE", "TERMINATE."),  # messages are dicts
+ code_execution_config=code_execution_config,
+ human_input_mode="NEVER",
+ default_auto_reply=self.DEFAULT_PROXY_AUTO_REPLY,
+ )
+ agent_list = agent_list + [user_proxy]
return agent_list, self.cached_configs.copy()
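A hedged end-to-end sketch: build agents for a task, then run them in a group chat. The task text, model name, and config file name are placeholders.

```python
# Build, run in a group chat, then persist and clear the generated agents.
import autogen
from autogen.agentchat.contrib.agent_builder import AgentBuilder

task = "Find recent papers on LLM-based multi-agent systems and summarize their main ideas."

builder = AgentBuilder(config_file_or_env="OAI_CONFIG_LIST", builder_model=["gpt-4-1106-preview"])
agent_list, agent_configs = builder.build(building_task=task, default_llm_config={"temperature": 0})

group_chat = autogen.GroupChat(agents=agent_list, messages=[], max_round=12)
manager = autogen.GroupChatManager(
    groupchat=group_chat,
    llm_config={"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")},
)
agent_list[0].initiate_chat(manager, message=task)

builder.save()             # persist the generated configs for later load()
builder.clear_all_agents()
```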
@@ -698,7 +712,7 @@ def save(self, filepath: Optional[str] = None) -> str:
filepath = f'./save_config_{hashlib.md5(self.building_task.encode("utf-8")).hexdigest()}.json'
with open(filepath, "w") as save_file:
json.dump(self.cached_configs, save_file, indent=4)
- print(f"Building config saved to {filepath}")
+ print(colored(f"Building config saved to {filepath}", "green"), flush=True)
return filepath
@@ -723,12 +737,12 @@ def load(
"""
# load json string.
if config_json is not None:
- print("Loading config from JSON...")
+ print(colored("Loading config from JSON...", "green"), flush=True)
cached_configs = json.loads(config_json)
# load from path.
if filepath is not None:
- print(f"Loading config from {filepath}")
+ print(colored(f"Loading config from {filepath}", "green"), flush=True)
with open(filepath) as f:
cached_configs = json.load(f)
diff --git a/autogen/agentchat/contrib/agent_eval/README.md b/autogen/agentchat/contrib/agent_eval/README.md
new file mode 100644
index 00000000000..478f28fd74e
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/README.md
@@ -0,0 +1,9 @@
+Agents for running the [AgentEval](https://microsoft.github.io/autogen/blog/2023/11/20/AgentEval/) pipeline.
+
+AgentEval is a process for evaluating an LLM-based system's performance on a given task.
+
+When given a task to evaluate and a few example runs, the critic and subcritic agents create evaluation criteria for evaluating a system's solution. Once the criteria have been created, the quantifier agent can evaluate subsequent task solutions based on the generated criteria.
+
+For more information see: [AgentEval Integration Roadmap](https://github.com/microsoft/autogen/issues/2162)
+
+See our [blog post](https://microsoft.github.io/autogen/blog/2024/06/21/AgentEval) for usage examples and general explanations.
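A hedged sketch of the flow described above. The `Task` field names below are assumptions about the accompanying task module, and the strings are placeholders.

```python
# Generate criteria with the critic, then quantify a new test case with the quantifier.
import autogen
from autogen.agentchat.contrib.agent_eval.agent_eval import generate_criteria, quantify_criteria
from autogen.agentchat.contrib.agent_eval.task import Task

llm_config = {"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")}

task = Task(  # field names are assumed, not confirmed by this diff
    name="math word problems",
    description="Solve grade-school math word problems.",
    successful_response="...",  # an example of a successful run
    failed_response="...",      # an example of a failed run
)

criteria = generate_criteria(llm_config=llm_config, task=task, use_subcritic=False)
result = quantify_criteria(
    llm_config=llm_config,
    criteria=criteria,
    task=task,
    test_case="Question: ... Proposed answer: ...",
    ground_truth="true",
)
print(result["estimated_performance"])
```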
diff --git a/autogen/agentchat/contrib/agent_eval/agent_eval.py b/autogen/agentchat/contrib/agent_eval/agent_eval.py
new file mode 100644
index 00000000000..b48c65a66d2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/agent_eval.py
@@ -0,0 +1,101 @@
+from typing import Dict, List, Literal, Optional, Union
+
+import autogen
+from autogen.agentchat.contrib.agent_eval.criterion import Criterion
+from autogen.agentchat.contrib.agent_eval.critic_agent import CriticAgent
+from autogen.agentchat.contrib.agent_eval.quantifier_agent import QuantifierAgent
+from autogen.agentchat.contrib.agent_eval.subcritic_agent import SubCriticAgent
+from autogen.agentchat.contrib.agent_eval.task import Task
+
+
+def generate_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ task: Task = None,
+ additional_instructions: str = "",
+ max_round=2,
+ use_subcritic: bool = False,
+):
+ """
+ Creates a list of criteria for evaluating the utility of a given task.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ task (Task): The task to evaluate.
+ additional_instructions (str): Additional instructions for the criteria agent.
+ max_round (int): The maximum number of rounds to run the conversation.
+ use_subcritic (bool): Whether to use the subcritic agent to generate subcriteria.
+ Returns:
+ list: A list of Criterion objects for evaluating the utility of the given task.
+ """
+ critic = CriticAgent(
+ system_message=CriticAgent.DEFAULT_SYSTEM_MESSAGE + "\n" + additional_instructions,
+ llm_config=llm_config,
+ )
+
+ critic_user = autogen.UserProxyAgent(
+ name="critic_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ agents = [critic_user, critic]
+
+ if use_subcritic:
+ subcritic = SubCriticAgent(
+ llm_config=llm_config,
+ )
+ agents.append(subcritic)
+
+ groupchat = autogen.GroupChat(
+ agents=agents, messages=[], max_round=max_round, speaker_selection_method="round_robin"
+ )
+ critic_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+ critic_user.initiate_chat(critic_manager, message=task.get_sys_message())
+ criteria = critic_user.last_message()
+ content = criteria["content"]
+ # need to strip out any extra code around the returned json
+ content = content[content.find("[") : content.rfind("]") + 1]
+ criteria = Criterion.parse_json_str(content)
+ return criteria
+
+
+def quantify_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ criteria: List[Criterion] = None,
+ task: Task = None,
+ test_case: str = "",
+ ground_truth: str = "",
+):
+ """
+ Quantifies the performance of a system using the provided criteria.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ criteria ([Criterion]): A list of criteria for evaluating the utility of a given task.
+ task (Task): The task to evaluate.
+ test_case (str): The test case to evaluate.
+ ground_truth (str): The ground truth for the test case.
+ Returns:
+        dict: A dictionary where the keys are the criteria and the values are the assessed performance based on the accepted values for each criterion.
+ """
+ quantifier = QuantifierAgent(
+ llm_config=llm_config,
+ )
+
+ quantifier_user = autogen.UserProxyAgent(
+ name="quantifier_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ quantifier_user.initiate_chat( # noqa: F841
+ quantifier,
+ message=task.get_sys_message()
+ + "Evaluation dictionary: "
+ + Criterion.write_json(criteria)
+ + "actual test case to evaluate: "
+ + test_case,
+ )
+ quantified_results = quantifier_user.last_message()
+ return {"actual_success": ground_truth, "estimated_performance": quantified_results["content"]}
diff --git a/autogen/agentchat/contrib/agent_eval/criterion.py b/autogen/agentchat/contrib/agent_eval/criterion.py
new file mode 100644
index 00000000000..5efd121ec07
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/criterion.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+import json
+from typing import List
+
+import pydantic_core
+from pydantic import BaseModel
+from pydantic.json import pydantic_encoder
+
+
+class Criterion(BaseModel):
+ """
+ A class that represents a criterion for agent evaluation.
+ """
+
+ name: str
+ description: str
+ accepted_values: List[str]
+ sub_criteria: List[Criterion] = list()
+
+ @staticmethod
+ def parse_json_str(criteria: str):
+ """
+ Create a list of Criterion objects from a json string.
+ Args:
+ criteria (str): Json string that represents the criteria
+ returns:
+ [Criterion]: A list of Criterion objects that represents the json criteria information.
+ """
+ return [Criterion(**crit) for crit in json.loads(criteria)]
+
+ @staticmethod
+ def write_json(criteria):
+ """
+ Create a json string from a list of Criterion objects.
+ Args:
+ criteria ([Criterion]): A list of Criterion objects.
+ Returns:
+ str: A json string that represents the list of Criterion objects.
+ """
+ return json.dumps([crit.model_dump() for crit in criteria], indent=2)
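A small round-trip sketch for `Criterion`; the sample criterion is made up.

```python
# Parse a criteria JSON string into Criterion objects and serialize it back.
from autogen.agentchat.contrib.agent_eval.criterion import Criterion

raw = """[
  {"name": "accuracy",
   "description": "How correct the final answer is.",
   "accepted_values": ["poor", "fair", "good", "excellent"]}
]"""

criteria = Criterion.parse_json_str(raw)
print(criteria[0].name)              # -> accuracy
print(Criterion.write_json(criteria))
```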
diff --git a/autogen/agentchat/contrib/agent_eval/critic_agent.py b/autogen/agentchat/contrib/agent_eval/critic_agent.py
new file mode 100644
index 00000000000..2f5e5598ba6
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/critic_agent.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class CriticAgent(ConversableAgent):
+ """
+    An agent for creating a list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant. You suggest criteria for evaluating different tasks. They should be distinguishable, quantifiable and not redundant.
+    Convert the evaluation criteria into a list where each item is a criterion represented by the following dictionary:
+    {"name": name of the criterion, "description": criterion description, "accepted_values": possible accepted inputs for this key}
+ Make sure "accepted_values" include the acceptable inputs for each key that are fine-grained and preferably multi-graded levels and "description" includes the criterion description.
+ Output just the criteria string you have created, no code.
+ """
+
+    DEFAULT_DESCRIPTION = "An AI agent for creating a list of criteria for evaluating the utility of a given task."
+
+ def __init__(
+ self,
+ name="critic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/quantifier_agent.py b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
new file mode 100644
index 00000000000..02a8f650fab
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class QuantifierAgent(ConversableAgent):
+ """
+ An agent for quantifying the performance of a system using the provided criteria.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """"You are a helpful assistant. You quantify the output of different tasks based on the given criteria.
+ The criterion is given in a json list format where each element is a distinct criteria.
+ The each element is a dictionary as follows {"name": name of the criterion, "description": criteria description , "accepted_values": possible accepted inputs for this key}
+ You are going to quantify each of the crieria for a given task based on the task description.
+ Return a dictionary where the keys are the criteria and the values are the assessed performance based on accepted values for each criteria.
+ Return only the dictionary, no code."""
+
+ DEFAULT_DESCRIPTION = "An AI agent for quantifing the performance of a system using the provided criteria."
+
+ def __init__(
+ self,
+ name="quantifier",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(name=name, system_message=system_message, description=description, **kwargs)
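
Mirroring the `quantifier_user.last_message()` pattern used by `agent_eval.py` above, a hedged sketch of quantification; the config, `criteria_json`, and `test_case` are placeholder values:

```python
from autogen import UserProxyAgent
from autogen.agentchat.contrib.agent_eval.quantifier_agent import QuantifierAgent

llm_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}  # placeholder config

criteria_json = '[{"name": "accuracy", "description": "...", "accepted_values": ["poor", "good"]}]'
test_case = "Question: 2 + 2? Model answer: 4."

quantifier = QuantifierAgent(llm_config=llm_config)
quantifier_user = UserProxyAgent(
    name="quantifier_user",
    human_input_mode="NEVER",
    code_execution_config=False,
)

# The quantifier is prompted to return a dictionary mapping each criterion to an assessed value.
quantifier_user.initiate_chat(
    quantifier,
    message=f"Evaluation criteria: {criteria_json}\nTask result to quantify: {test_case}",
    max_turns=1,
)
quantified = quantifier_user.last_message()["content"]
```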
diff --git a/autogen/agentchat/contrib/agent_eval/subcritic_agent.py b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
new file mode 100755
index 00000000000..fa994ee7bda
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class SubCriticAgent(ConversableAgent):
+ """
+ An agent for creating subcriteria from a given list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant to the critic agent. You suggest sub criteria for evaluating different tasks based on the criteria provided by the critic agent (if you feel it is needed).
+ They should be distinguishable, quantifiable, and related to the overall theme of the critic's provided criteria.
+ You operate by taking in the description of the criteria. You then create a new key called sub criteria where you provide the sub criteria for the given criteria.
+ The value of the sub_criteria is a dictionary where the keys are the subcriteria and each value is as follows {"description": sub criteria description , "accepted_values": possible accepted inputs for this key}
+    Do this for each criterion provided by the critic (removing the criterion's accepted values). "accepted_values" includes the acceptable inputs for each key that are fine-grained and preferably multi-graded levels. "description" includes the criterion description.
+ Once you have created the sub criteria for the given criteria, you return the json (make sure to include the contents of the critic's dictionary in the final dictionary as well).
+    Make sure to return valid json and no code."""
+
+ DEFAULT_DESCRIPTION = "An AI agent for creating subcriteria from a given list of criteria."
+
+ def __init__(
+ self,
+ name="subcritic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/task.py b/autogen/agentchat/contrib/agent_eval/task.py
new file mode 100644
index 00000000000..9f96fbf79e2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/task.py
@@ -0,0 +1,37 @@
+import json
+
+from pydantic import BaseModel
+
+
+class Task(BaseModel):
+ """
+ Class representing a task for agent completion, includes example agent execution for criteria generation.
+ """
+
+ name: str
+ description: str
+ successful_response: str
+ failed_response: str
+
+ def get_sys_message(self):
+ return f"""Task: {self.name}.
+ Task description: {self.description}
+ Task successful example: {self.successful_response}
+ Task failed example: {self.failed_response}
+ """
+
+ @staticmethod
+ def parse_json_str(task: str):
+ """
+        Create a Task object from a json string.
+        Args:
+            task (str): A json string that represents the task.
+ Returns:
+ Task: A Task object that represents the json task information.
+ """
+ json_data = json.loads(task)
+ name = json_data.get("name")
+ description = json_data.get("description")
+ successful_response = json_data.get("successful_response")
+ failed_response = json_data.get("failed_response")
+        # Task is a pydantic model, so fields must be passed as keyword arguments.
+        return Task(
+            name=name,
+            description=description,
+            successful_response=successful_response,
+            failed_response=failed_response,
+        )
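
A small sketch of parsing a task definition with the new `Task` model; the JSON payload below is invented for illustration:

```python
import json

from autogen.agentchat.contrib.agent_eval.task import Task

task_json = json.dumps(
    {
        "name": "Math problem solving",
        "description": "Solve grade-school math word problems.",
        "successful_response": "The answer is 42.",
        "failed_response": "I am not sure how to solve this.",
    }
)

task = Task.parse_json_str(task_json)
print(task.get_sys_message())  # formatted block used to seed criteria generation
```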
diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py
deleted file mode 100644
index 173811842eb..00000000000
--- a/autogen/agentchat/contrib/capabilities/context_handling.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import sys
-from typing import Dict, List, Optional
-from warnings import warn
-
-import tiktoken
-from termcolor import colored
-
-from autogen import ConversableAgent, token_count_utils
-
-warn(
- "Context handling with TransformChatHistory is deprecated. "
- "Please use TransformMessages from autogen/agentchat/contrib/capabilities/transform_messages.py instead.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class TransformChatHistory:
- """
- An agent's chat history with other agents is a common context that it uses to generate a reply.
- This capability allows the agent to transform its chat history prior to using it to generate a reply.
- It does not permanently modify the chat history, but rather processes it on every invocation.
-
- This capability class enables various strategies to transform chat history, such as:
- - Truncate messages: Truncate each message to first maximum number of tokens.
- - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
- - Limit number of tokens: Truncate the chat history to number of recent N messages that fit in
- maximum number of tokens.
- Note that the system message, because of its special significance, is always kept as is.
-
- The three strategies can be combined. For example, when each of these parameters are specified
- they are used in the following order:
- 1. First truncate messages to a maximum number of tokens
- 2. Second, it limits the number of message to keep
- 3. Third, it limits the total number of tokens in the chat history
-
- When adding this capability to an agent, the following are modified:
- - A hook is added to the hookable method `process_all_messages_before_reply` to transform the
- received messages for possible truncation.
- Not modifying the stored message history.
- """
-
- def __init__(
- self,
- *,
- max_tokens_per_message: Optional[int] = None,
- max_messages: Optional[int] = None,
- max_tokens: Optional[int] = None,
- ):
- """
- Args:
- max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
- max_messages (Optional[int]): Maximum number of messages to keep in the context.
- max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
- """
- self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
- self.max_messages = max_messages if max_messages else sys.maxsize
- self.max_tokens = max_tokens if max_tokens else sys.maxsize
-
- def add_to_agent(self, agent: ConversableAgent):
- """
- Adds TransformChatHistory capability to the given agent.
- """
- agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
-
- def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
- """
- Args:
- messages: List of messages to process.
-
- Returns:
- List of messages with the first system message and the last max_messages messages,
- ensuring each message does not exceed max_tokens_per_message.
- """
- temp_messages = messages.copy()
- processed_messages = []
- system_message = None
- processed_messages_tokens = 0
-
- if messages[0]["role"] == "system":
- system_message = messages[0].copy()
- temp_messages.pop(0)
-
- total_tokens = sum(
- token_count_utils.count_token(msg["content"]) for msg in temp_messages
- ) # Calculate tokens for all messages
-
- # Truncate each message's content to a maximum token limit of each message
-
- # Process recent messages first
- for msg in reversed(temp_messages[-self.max_messages :]):
- msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
- msg_tokens = token_count_utils.count_token(msg["content"])
- if processed_messages_tokens + msg_tokens > self.max_tokens:
- break
- # append the message to the beginning of the list to preserve order
- processed_messages = [msg] + processed_messages
- processed_messages_tokens += msg_tokens
- if system_message:
- processed_messages.insert(0, system_message)
- # Optionally, log the number of truncated messages and tokens if needed
- num_truncated = len(messages) - len(processed_messages)
-
- if num_truncated > 0 or total_tokens > processed_messages_tokens:
- print(
- colored(
- f"Truncated {num_truncated} messages. Reduced from {len(messages)} to {len(processed_messages)}.",
- "yellow",
- )
- )
- print(
- colored(
- f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
- "yellow",
- )
- )
- return processed_messages
-
-
-def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
- """Truncate a string so that the number of tokens is less than or equal to max_tokens using tiktoken.
-
- Args:
- text: The string to truncate.
- max_tokens: The maximum number of tokens to keep.
- model: The target OpenAI model for tokenization alignment.
-
- Returns:
- The truncated string.
- """
-
- encoding = tiktoken.encoding_for_model(model) # Get the appropriate tokenizer
-
- encoded_tokens = encoding.encode(text)
- truncated_tokens = encoded_tokens[:max_tokens]
- truncated_text = encoding.decode(truncated_tokens) # Decode back to text
-
- return truncated_text
diff --git a/autogen/agentchat/contrib/capabilities/teachability.py b/autogen/agentchat/contrib/capabilities/teachability.py
index 3a64f061963..596e449ce34 100644
--- a/autogen/agentchat/contrib/capabilities/teachability.py
+++ b/autogen/agentchat/contrib/capabilities/teachability.py
@@ -86,7 +86,7 @@ def prepopulate_db(self):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()
- def process_last_received_message(self, text):
+ def process_last_received_message(self, text: Union[Dict, str]):
"""
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
@@ -103,7 +103,7 @@ def process_last_received_message(self, text):
# Return the (possibly) expanded message text.
return expanded_text
- def _consider_memo_storage(self, comment):
+ def _consider_memo_storage(self, comment: Union[Dict, str]):
"""Decides whether to store something from one user comment in the DB."""
memo_added = False
@@ -161,7 +161,7 @@ def _consider_memo_storage(self, comment):
# Yes. Save them to disk.
self.memo_store._save_memos()
- def _consider_memo_retrieval(self, comment):
+ def _consider_memo_retrieval(self, comment: Union[Dict, str]):
"""Decides whether to retrieve memos from the DB, and add them to the chat context."""
# First, use the comment directly as the lookup key.
@@ -195,7 +195,7 @@ def _consider_memo_retrieval(self, comment):
# Append the memos to the text of the last message.
return comment + self._concatenate_memo_texts(memo_list)
- def _retrieve_relevant_memos(self, input_text):
+ def _retrieve_relevant_memos(self, input_text: str) -> list:
"""Returns semantically related memos from the DB."""
memo_list = self.memo_store.get_related_memos(
input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
@@ -213,7 +213,7 @@ def _retrieve_relevant_memos(self, input_text):
memo_list = [memo[1] for memo in memo_list]
return memo_list
- def _concatenate_memo_texts(self, memo_list):
+ def _concatenate_memo_texts(self, memo_list: list) -> str:
"""Concatenates the memo texts into a single string for inclusion in the chat context."""
memo_texts = ""
if len(memo_list) > 0:
@@ -225,7 +225,7 @@ def _concatenate_memo_texts(self, memo_list):
memo_texts = memo_texts + "\n" + info
return memo_texts
- def _analyze(self, text_to_analyze, analysis_instructions):
+ def _analyze(self, text_to_analyze: Union[Dict, str], analysis_instructions: Union[Dict, str]):
"""Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
self.analyzer.reset() # Clear the analyzer's list of messages.
self.teachable_agent.send(
@@ -246,10 +246,16 @@ class MemoStore:
Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
"""
- def __init__(self, verbosity, reset, path_to_db_dir):
+ def __init__(
+ self,
+ verbosity: Optional[int] = 0,
+ reset: Optional[bool] = False,
+ path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
+ ):
"""
Args:
- verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
+ - reset (Optional, bool): True to clear the DB before starting. Default False.
- path_to_db_dir (Optional, str): path to the directory where the DB is stored.
"""
self.verbosity = verbosity
@@ -304,7 +310,7 @@ def reset_db(self):
self.uid_text_dict = {}
self._save_memos()
- def add_input_output_pair(self, input_text, output_text):
+ def add_input_output_pair(self, input_text: str, output_text: str):
"""Adds an input-output pair to the vector DB."""
self.last_memo_id += 1
self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
@@ -321,7 +327,7 @@ def add_input_output_pair(self, input_text, output_text):
if self.verbosity >= 3:
self.list_memos()
- def get_nearest_memo(self, query_text):
+ def get_nearest_memo(self, query_text: str):
"""Retrieves the nearest memo to the given query text."""
results = self.vec_db.query(query_texts=[query_text], n_results=1)
uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
@@ -338,7 +344,7 @@ def get_nearest_memo(self, query_text):
)
return input_text, output_text, distance
- def get_related_memos(self, query_text, n_results, threshold):
+ def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
"""Retrieves memos that are related to the given query text within the specified distance threshold."""
if n_results > len(self.uid_text_dict):
n_results = len(self.uid_text_dict)
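
For context on the annotated defaults above, a hedged sketch of attaching the teachability capability to an agent; the `llm_config` is a placeholder and the optional chromadb dependency is assumed to be installed:

```python
from autogen import AssistantAgent
from autogen.agentchat.contrib.capabilities.teachability import Teachability

llm_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}  # placeholder config

assistant = AssistantAgent(name="assistant", llm_config=llm_config)

# With the new MemoStore defaults, the memo DB lands in ./tmp/teachable_agent_db
# unless path_to_db_dir is overridden.
teachability = Teachability(reset_db=False, llm_config=llm_config)
teachability.add_to_agent(assistant)
```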
diff --git a/autogen/agentchat/contrib/capabilities/text_compressors.py b/autogen/agentchat/contrib/capabilities/text_compressors.py
new file mode 100644
index 00000000000..78554bdc935
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/text_compressors.py
@@ -0,0 +1,68 @@
+from typing import Any, Dict, Optional, Protocol
+
+IMPORT_ERROR: Optional[Exception] = None
+try:
+ import llmlingua
+except ImportError:
+ IMPORT_ERROR = ImportError(
+ "LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
+ )
+ PromptCompressor = object
+else:
+ from llmlingua import PromptCompressor
+
+
+class TextCompressor(Protocol):
+ """Defines a protocol for text compression to optimize agent interactions."""
+
+ def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+ """This method takes a string as input and returns a dictionary containing the compressed text and other
+ relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
+ To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
+ """
+ ...
+
+
+class LLMLingua:
+ """Compresses text messages using LLMLingua for improved efficiency in processing and response generation.
+
+ NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
+ and the specific configurations used for the PromptCompressor.
+ """
+
+ def __init__(
+ self,
+ prompt_compressor_kwargs: Dict = dict(
+ model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2=True,
+ device_map="cpu",
+ ),
+ structured_compression: bool = False,
+ ) -> None:
+ """
+ Args:
+ prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
+ dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2 set to True, and device_map set to "cpu".
+ structured_compression (bool): A flag indicating whether to use structured compression. If True, the
+ structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
+                is used. Defaults to False.
+
+ Raises:
+ ImportError: If the llmlingua library is not installed.
+ """
+ if IMPORT_ERROR:
+ raise IMPORT_ERROR
+
+ self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)
+
+ assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
+ self._compression_method = (
+ self._prompt_compressor.structured_compress_prompt
+ if structured_compression
+ else self._prompt_compressor.compress_prompt
+ )
+
+ def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+ return self._compression_method([text], **compression_params)
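
A minimal sketch of the `LLMLingua` wrapper on its own; `rate` is an llmlingua compression parameter passed through `**compression_params`, and the example assumes `pip install pyautogen[long-context]` has been run:

```python
from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua

compressor = LLMLingua()  # downloads the default llmlingua-2 model on first use

result = compressor.compress_text(
    "A long, repetitive prompt that we would like to shrink before sending it to the model ...",
    rate=0.5,
)

# Per the TextCompressor protocol, the result carries the compressed text and token counts.
print(result["compressed_prompt"])
print(result["origin_tokens"], "->", result["compressed_tokens"])
```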
diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py
index 46c8d4e0a4d..1ce219bdadf 100644
--- a/autogen/agentchat/contrib/capabilities/transform_messages.py
+++ b/autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -1,10 +1,8 @@
import copy
from typing import Dict, List
-from termcolor import colored
-
-from autogen import ConversableAgent
-
+from ....formatting_utils import colored
+from ...conversable_agent import ConversableAgent
from .transforms import MessageTransform
@@ -43,12 +41,14 @@ class TransformMessages:
```
"""
- def __init__(self, *, transforms: List[MessageTransform] = []):
+ def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True):
"""
Args:
transforms: A list of message transformations to apply.
+ verbose: Whether to print logs of each transformation or not.
"""
self._transforms = transforms
+ self._verbose = verbose
def add_to_agent(self, agent: ConversableAgent):
"""Adds the message transformations capability to the specified ConversableAgent.
@@ -61,31 +61,26 @@ def add_to_agent(self, agent: ConversableAgent):
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
- temp_messages = copy.deepcopy(messages)
+ post_transform_messages = copy.deepcopy(messages)
system_message = None
if messages[0]["role"] == "system":
system_message = copy.deepcopy(messages[0])
- temp_messages.pop(0)
+ post_transform_messages.pop(0)
for transform in self._transforms:
- temp_messages = transform.apply_transform(temp_messages)
-
- if system_message:
- temp_messages.insert(0, system_message)
-
- self._print_stats(messages, temp_messages)
+ # deepcopy in case pre_transform_messages will later be used for logs printing
+ pre_transform_messages = (
+ copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
+ )
+ post_transform_messages = transform.apply_transform(pre_transform_messages)
- return temp_messages
+ if self._verbose:
+ logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
+ if had_effect:
+ print(colored(logs_str, "yellow"))
- def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]):
- pre_transform_messages_len = len(pre_transform_messages)
- post_transform_messages_len = len(post_transform_messages)
+ if system_message:
+ post_transform_messages.insert(0, system_message)
- if pre_transform_messages_len < post_transform_messages_len:
- print(
- colored(
- f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.",
- "yellow",
- )
- )
+ return post_transform_messages
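
A hedged sketch of how the reworked capability might be wired up, including the new `verbose` flag; the agent and limits below are placeholders:

```python
from autogen import AssistantAgent
from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

llm_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}  # placeholder config

agent = AssistantAgent(name="assistant", llm_config=llm_config)

# verbose=True (the default) prints each transform's get_logs() output whenever it had an effect.
context_handling = TransformMessages(
    transforms=[
        MessageHistoryLimiter(max_messages=10, keep_first_message=True),
        MessageTokenLimiter(max_tokens=3000, max_tokens_per_message=500, min_tokens=1000),
    ],
    verbose=True,
)
context_handling.add_to_agent(agent)
```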
diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index cc4faace3f1..d9ad365b91b 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -1,10 +1,16 @@
+import copy
import sys
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
import tiktoken
from termcolor import colored
from autogen import token_count_utils
+from autogen.cache import AbstractCache, Cache
+from autogen.types import MessageContentType
+
+from . import transforms_util
+from .text_compressors import LLMLingua, TextCompressor
class MessageTransform(Protocol):
@@ -25,6 +31,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""
...
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ """Creates the string including the logs of the transformation
+
+ Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.
+
+ Args:
+ pre_transform_messages: A list of dictionaries representing messages before the transformation.
+            post_transform_messages: A list of dictionaries representing messages after the transformation.
+
+ Returns:
+ A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
+ """
+ ...
+
class MessageHistoryLimiter:
"""Limits the number of messages considered by an agent for response generation.
@@ -33,14 +53,16 @@ class MessageHistoryLimiter:
It trims the conversation history by removing older messages, retaining only the most recent messages.
"""
- def __init__(self, max_messages: Optional[int] = None):
+ def __init__(self, max_messages: Optional[int] = None, keep_first_message: bool = False):
"""
Args:
- max_messages (None or int): Maximum number of messages to keep in the context.
- Must be greater than 0 if not None.
+            max_messages (Optional[int]): Maximum number of messages to keep in the context. Must be greater than 0 if not None.
+            keep_first_message (bool): Whether to keep the original first message in the conversation history.
+                Defaults to False.
"""
self._validate_max_messages(max_messages)
self._max_messages = max_messages
+ self._keep_first_message = keep_first_message
def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Truncates the conversation history to the specified maximum number of messages.
@@ -55,10 +77,44 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
Returns:
List[Dict]: A new list containing the most recent messages up to the specified maximum.
"""
- if self._max_messages is None:
+
+ if self._max_messages is None or len(messages) <= self._max_messages:
return messages
- return messages[-self._max_messages :]
+ truncated_messages = []
+ remaining_count = self._max_messages
+
+ # Start with the first message if we need to keep it
+ if self._keep_first_message:
+ truncated_messages = [messages[0]]
+ remaining_count -= 1
+
+ # Loop through messages in reverse
+ for i in range(len(messages) - 1, 0, -1):
+ if remaining_count > 1:
+ truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+ if remaining_count == 1:
+ # If there's only 1 slot left and it's a 'tools' message, ignore it.
+ if messages[i].get("role") != "tool":
+                    # Insert after the kept first message, or at the front when no first message is kept,
+                    # so the original chronological order is preserved.
+                    truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
+
+ remaining_count -= 1
+ if remaining_count == 0:
+ break
+
+ return truncated_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ pre_transform_messages_len = len(pre_transform_messages)
+ post_transform_messages_len = len(post_transform_messages)
+
+ if post_transform_messages_len < pre_transform_messages_len:
+ logs_str = (
+ f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
+ f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
+ )
+ return logs_str, True
+ return "No messages were removed.", False
def _validate_max_messages(self, max_messages: Optional[int]):
if max_messages is not None and max_messages < 1:
@@ -81,13 +137,15 @@ class MessageTokenLimiter:
The truncation process follows these steps in order:
- 1. Messages are processed in reverse order (newest to oldest).
- 2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
+    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in the messages
+        is less than this threshold, the messages are returned as is. Otherwise, the following process is applied.
+ 2. Messages are processed in reverse order (newest to oldest).
+ 3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
and other types of content, only the text content is truncated.
- 3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
+ 4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
exceeds this limit, the current message being processed get truncated to meet the total token count and any
remaining messages get discarded.
- 4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
+ 5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
original message order.
"""
@@ -95,7 +153,10 @@ def __init__(
self,
max_tokens_per_message: Optional[int] = None,
max_tokens: Optional[int] = None,
+ min_tokens: Optional[int] = None,
model: str = "gpt-3.5-turbo-0613",
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
):
"""
Args:
@@ -103,11 +164,20 @@ def __init__(
Must be greater than or equal to 0 if not None.
max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
Must be greater than or equal to 0 if not None.
+ min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
+ Must be greater than or equal to 0 if not None.
model (str): The target OpenAI model for tokenization alignment.
+            filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to truncate.
+ If None, no filters will be applied.
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+ excluded from token truncation. If False, messages that match the filter will be truncated.
"""
self._model = model
self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
self._max_tokens = self._validate_max_tokens(max_tokens)
+ self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies token truncation to the conversation history.
@@ -120,20 +190,25 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""
assert self._max_tokens_per_message is not None
assert self._max_tokens is not None
+ assert self._min_tokens is not None
+
+ # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+ if not transforms_util.min_tokens_reached(messages, self._min_tokens):
+ return messages
- temp_messages = messages.copy()
+ temp_messages = copy.deepcopy(messages)
processed_messages = []
processed_messages_tokens = 0
- # calculate tokens for all messages
- total_tokens = sum(
- _count_tokens(msg["content"]) for msg in temp_messages if isinstance(msg.get("content"), (str, list))
- )
-
for msg in reversed(temp_messages):
# Some messages may not have content.
- if not isinstance(msg.get("content"), (str, list)):
+ if not transforms_util.is_content_right_type(msg.get("content")):
+ processed_messages.insert(0, msg)
+ continue
+
+ if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
processed_messages.insert(0, msg)
+ processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
continue
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
@@ -148,22 +223,30 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
break
msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
- msg_tokens = _count_tokens(msg["content"])
+ msg_tokens = transforms_util.count_text_tokens(msg["content"])
# prepend the message to the list to preserve order
processed_messages_tokens += msg_tokens
processed_messages.insert(0, msg)
- if total_tokens > processed_messages_tokens:
- print(
- colored(
- f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
- "yellow",
- )
- )
-
return processed_messages
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ pre_transform_messages_tokens = sum(
+ transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
+ )
+ post_transform_messages_tokens = sum(
+ transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
+ )
+
+ if post_transform_messages_tokens < pre_transform_messages_tokens:
+ logs_str = (
+ f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
+ f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
+ )
+ return logs_str, True
+ return "No tokens were truncated.", False
+
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
@@ -214,12 +297,243 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int
return max_tokens if max_tokens is not None else sys.maxsize
+ def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
+ if min_tokens is None:
+ return 0
+ if min_tokens < 0:
+ raise ValueError("min_tokens must be None or greater than or equal to 0.")
+ if max_tokens is not None and min_tokens > max_tokens:
+ raise ValueError("min_tokens must not be more than max_tokens.")
+ return min_tokens
+
+
+class TextMessageCompressor:
+ """A transform for compressing text messages in a conversation history.
+
+ It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
+ processing and response generation by downstream models.
+ """
+
+ def __init__(
+ self,
+ text_compressor: Optional[TextCompressor] = None,
+ min_tokens: Optional[int] = None,
+ compression_params: Dict = dict(),
+ cache: Optional[AbstractCache] = Cache.disk(),
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
+ ):
+ """
+ Args:
+ text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
+ protocol. If None, it defaults to LLMLingua.
+ min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
+ than or equal to 0 if not None. If None, no threshold-based compression is applied.
+            compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
+ dictionary.
+ cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
+ If None, no caching will be used.
+ filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
+ If None, no filters will be applied.
+ exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+ excluded from compression. If False, messages that match the filter will be compressed.
+ """
+
+ if text_compressor is None:
+ text_compressor = LLMLingua()
+
+ self._validate_min_tokens(min_tokens)
-def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
- token_count = 0
- if isinstance(content, str):
- token_count = token_count_utils.count_token(content)
- elif isinstance(content, list):
+ self._text_compressor = text_compressor
+ self._min_tokens = min_tokens
+ self._compression_args = compression_params
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
+ self._cache = cache
+
+        # Cache the most recent token savings so get_logs() can report them without recomputing.
+ self._recent_tokens_savings = 0
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies compression to messages in a conversation history based on the specified configuration.
+
+ The function processes each message according to the `compression_args` and `min_tokens` settings, applying
+ the specified compression configuration and returning a new list of messages with reduced token counts
+ where possible.
+
+ Args:
+ messages (List[Dict]): A list of message dictionaries to be compressed.
+
+ Returns:
+ List[Dict]: A list of dictionaries with the message content compressed according to the configured
+ method and scope.
+ """
+ # Make sure there is at least one message
+ if not messages:
+ return messages
+
+ # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+ if not transforms_util.min_tokens_reached(messages, self._min_tokens):
+ return messages
+
+ total_savings = 0
+ processed_messages = messages.copy()
+ for message in processed_messages:
+ # Some messages may not have content.
+ if not transforms_util.is_content_right_type(message.get("content")):
+ continue
+
+ if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
+ continue
+
+ if transforms_util.is_content_text_empty(message["content"]):
+ continue
+
+ cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
+ cached_content = transforms_util.cache_content_get(self._cache, cache_key)
+ if cached_content is not None:
+ message["content"], savings = cached_content
+ else:
+ message["content"], savings = self._compress(message["content"])
+
+ transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)
+
+ assert isinstance(savings, int)
+ total_savings += savings
+
+ self._recent_tokens_savings = total_savings
+ return processed_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ if self._recent_tokens_savings > 0:
+ return f"{self._recent_tokens_savings} tokens saved with text compression.", True
+ else:
+ return "No tokens saved with text compression.", False
+
+ def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
+ """Compresses the given text or multimodal content using the specified compression method."""
+ if isinstance(content, str):
+ return self._compress_text(content)
+ elif isinstance(content, list):
+ return self._compress_multimodal(content)
+ else:
+ return content, 0
+
+ def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
+ tokens_saved = 0
for item in content:
- token_count += _count_tokens(item.get("text", ""))
- return token_count
+ if isinstance(item, dict) and "text" in item:
+ item["text"], savings = self._compress_text(item["text"])
+ tokens_saved += savings
+
+ elif isinstance(item, str):
+ item, savings = self._compress_text(item)
+ tokens_saved += savings
+
+ return content, tokens_saved
+
+ def _compress_text(self, text: str) -> Tuple[str, int]:
+ """Compresses the given text using the specified compression method."""
+ compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
+
+ savings = 0
+ if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
+ savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
+
+ return compressed_text["compressed_prompt"], savings
+
+ def _validate_min_tokens(self, min_tokens: Optional[int]):
+ if min_tokens is not None and min_tokens <= 0:
+ raise ValueError("min_tokens must be greater than 0 or None")
+
+
+class TextMessageContentName:
+ """A transform for including the agent's name in the content of a message."""
+
+ def __init__(
+ self,
+ position: str = "start",
+ format_string: str = "{name}:\n",
+ deduplicate: bool = True,
+ filter_dict: Optional[Dict] = None,
+ exclude_filter: bool = True,
+ ):
+ """
+ Args:
+ position (str): The position to add the name to the content. The possible options are 'start' or 'end'. Defaults to 'start'.
+            format_string (str): The format string used to add the agent's name. Use '{name}' as a placeholder for the agent's name. Defaults to '{name}:\n' and must contain '{name}'.
+ deduplicate (bool): Whether to deduplicate the formatted string so it doesn't appear twice (sometimes the LLM will add it to new messages itself). Defaults to True.
+            filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to transform.
+                If None, no filters will be applied.
+            exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+                excluded from the name transformation. If False, only messages that match the filter will have the name added.
+ """
+
+ assert isinstance(position, str) and position is not None
+ assert position in ["start", "end"]
+ assert isinstance(format_string, str) and format_string is not None
+ assert "{name}" in format_string
+ assert isinstance(deduplicate, bool) and deduplicate is not None
+
+ self._position = position
+ self._format_string = format_string
+ self._deduplicate = deduplicate
+ self._filter_dict = filter_dict
+ self._exclude_filter = exclude_filter
+
+ # Track the number of messages changed for logging
+ self._messages_changed = 0
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies the name change to the message based on the position and format string.
+
+ Args:
+ messages (List[Dict]): A list of message dictionaries.
+
+ Returns:
+ List[Dict]: A list of dictionaries with the message content updated with names.
+ """
+ # Make sure there is at least one message
+ if not messages:
+ return messages
+
+ messages_changed = 0
+ processed_messages = copy.deepcopy(messages)
+ for message in processed_messages:
+ # Some messages may not have content.
+ if not transforms_util.is_content_right_type(
+ message.get("content")
+ ) or not transforms_util.is_content_right_type(message.get("name")):
+ continue
+
+ if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
+ continue
+
+ if transforms_util.is_content_text_empty(message["content"]) or transforms_util.is_content_text_empty(
+ message["name"]
+ ):
+ continue
+
+ # Get and format the name in the content
+ content = message["content"]
+ formatted_name = self._format_string.format(name=message["name"])
+
+ if self._position == "start":
+ if not self._deduplicate or not content.startswith(formatted_name):
+ message["content"] = f"{formatted_name}{content}"
+
+ messages_changed += 1
+ else:
+ if not self._deduplicate or not content.endswith(formatted_name):
+ message["content"] = f"{content}{formatted_name}"
+
+ messages_changed += 1
+
+ self._messages_changed = messages_changed
+ return processed_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ if self._messages_changed > 0:
+ return f"{self._messages_changed} message(s) changed to incorporate name.", True
+ else:
+ return "No messages changed to incorporate name.", False
diff --git a/autogen/agentchat/contrib/capabilities/transforms_util.py b/autogen/agentchat/contrib/capabilities/transforms_util.py
new file mode 100644
index 00000000000..8678dec654c
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/transforms_util.py
@@ -0,0 +1,114 @@
+from typing import Any, Dict, Hashable, List, Optional, Tuple
+
+from autogen import token_count_utils
+from autogen.cache.abstract_cache_base import AbstractCache
+from autogen.oai.openai_utils import filter_config
+from autogen.types import MessageContentType
+
+
+def cache_key(content: MessageContentType, *args: Hashable) -> str:
+ """Calculates the cache key for the given message content and any other hashable args.
+
+ Args:
+ content (MessageContentType): The message content to calculate the cache key for.
+ *args: Any additional hashable args to include in the cache key.
+ """
+ str_keys = [str(key) for key in (content, *args)]
+ return "".join(str_keys)
+
+
+def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
+ """Retrieves cachedd content from the cache.
+
+ Args:
+ cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
+ key (str): The key to retrieve the content from.
+ """
+ if cache:
+ cached_value = cache.get(key)
+ if cached_value:
+ return cached_value
+
+
+def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
+ """Sets content into the cache.
+
+ Args:
+ cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
+ key (str): The key to set the content into.
+ content (MessageContentType): The message content to set into the cache.
+ *extra_values: Additional values to be passed to the cache.
+ """
+ if cache:
+ cache_value = (content, *extra_values)
+ cache.set(key, cache_value)
+
+
+def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
+ """Returns True if the total number of tokens in the messages is greater than or equal to the specified value.
+
+ Args:
+        messages (List[Dict]): A list of messages to check.
+        min_tokens (Optional[int]): The minimum token threshold. If None or 0, the check always passes.
+ """
+ if not min_tokens:
+ return True
+
+ messages_tokens = sum(count_text_tokens(msg["content"]) for msg in messages if "content" in msg)
+ return messages_tokens >= min_tokens
+
+
+def count_text_tokens(content: MessageContentType) -> int:
+ """Calculates the number of text tokens in the given message content.
+
+ Args:
+ content (MessageContentType): The message content to calculate the number of text tokens for.
+ """
+ token_count = 0
+ if isinstance(content, str):
+ token_count = token_count_utils.count_token(content)
+ elif isinstance(content, list):
+ for item in content:
+ if isinstance(item, str):
+ token_count += token_count_utils.count_token(item)
+ else:
+ token_count += count_text_tokens(item.get("text", ""))
+ return token_count
+
+
+def is_content_right_type(content: Any) -> bool:
+ """A helper function to check if the passed in content is of the right type."""
+ return isinstance(content, (str, list))
+
+
+def is_content_text_empty(content: MessageContentType) -> bool:
+ """Checks if the content of the message does not contain any text.
+
+ Args:
+ content (MessageContentType): The message content to check.
+ """
+ if isinstance(content, str):
+ return content == ""
+ elif isinstance(content, list):
+ texts = []
+ for item in content:
+ if isinstance(item, str):
+ texts.append(item)
+ elif isinstance(item, dict):
+ texts.append(item.get("text", ""))
+ return not any(texts)
+ else:
+ return True
+
+
+def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
+ """Validates whether the transform should be applied according to the filter dictionary.
+
+ Args:
+ message (Dict[str, Any]): The message to validate.
+ filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
+ exclude (bool): Whether to exclude messages that match the filter dictionary.
+ """
+ if not filter_dict:
+ return True
+
+ return len(filter_config([message], filter_dict, exclude)) > 0
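
A few of these helpers in isolation, with invented messages, to show the shapes they expect:

```python
from autogen.agentchat.contrib.capabilities import transforms_util

messages = [
    {"role": "user", "content": "Hello there"},
    {"role": "assistant", "content": [{"type": "text", "text": "General Kenobi"}]},
]

# count_text_tokens handles both plain-string and multimodal (list-of-dict) content.
total_tokens = sum(transforms_util.count_text_tokens(m["content"]) for m in messages)
print(total_tokens)

# min_tokens_reached short-circuits to True when no threshold is set.
print(transforms_util.min_tokens_reached(messages, min_tokens=None))

# should_transform_message defers to filter_config; with exclude=True, matching messages are skipped.
print(transforms_util.should_transform_message(messages[0], {"role": ["user"]}, exclude=True))
```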
diff --git a/autogen/agentchat/contrib/compressible_agent.py b/autogen/agentchat/contrib/compressible_agent.py
deleted file mode 100644
index 9c4e78af852..00000000000
--- a/autogen/agentchat/contrib/compressible_agent.py
+++ /dev/null
@@ -1,437 +0,0 @@
-import asyncio
-import copy
-import inspect
-import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from warnings import warn
-
-from autogen import Agent, ConversableAgent, OpenAIWrapper
-from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions
-
-from ...formatting_utils import colored
-
-logger = logging.getLogger(__name__)
-
-warn(
- "Context handling with CompressibleAgent is deprecated. "
- "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/reference/agentchat/contrib/capabilities/transform_messages",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CompressibleAgent(ConversableAgent):
- """CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`,
- it also provides the added feature of compression when activated through the `compress_config` setting.
-
- `compress_config` is set to False by default, making this agent equivalent to the `AssistantAgent`.
- This agent does not work well in a GroupChat: The compressed messages will not be sent to all the agents in the group.
- The default system message is the same as AssistantAgent.
- `human_input_mode` is default to "NEVER"
- and `code_execution_config` is default to False.
- This agent doesn't execute code or function call by default.
- """
-
- DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant.
-Solve tasks using your coding and language skills.
-In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
- 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
- 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
-Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
-When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
-If you want the user to save the code in a file before executing it, put # filename: inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
-If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
-When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
-Reply "TERMINATE" in the end when everything is done.
- """
- DEFAULT_COMPRESS_CONFIG = {
- "mode": "TERMINATE",
- "compress_function": None,
- "trigger_count": 0.7,
- "async": False,
- "broadcast": True,
- "verbose": False,
- "leave_last_n": 2,
- }
-
- def __init__(
- self,
- name: str,
- system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
- is_termination_msg: Optional[Callable[[Dict], bool]] = None,
- max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "NEVER",
- function_map: Optional[Dict[str, Callable]] = None,
- code_execution_config: Optional[Union[Dict, bool]] = False,
- llm_config: Optional[Union[Dict, bool]] = None,
- default_auto_reply: Optional[Union[str, Dict, None]] = "",
- compress_config: Optional[Dict] = False,
- description: Optional[str] = None,
- **kwargs,
- ):
- """
- Args:
- name (str): agent name.
- system_message (str): system message for the ChatCompletion inference.
- Please override this attribute if you want to reprogram the agent.
- llm_config (dict): llm inference configuration.
- Note: you must set `model` in llm_config. It will be used to compute the token count.
- Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
- for available options.
- is_termination_msg (function): a function that takes a message in the form of a dictionary
- and returns a boolean value indicating if this received message is a termination message.
- The dict can contain the following keys: "content", "role", "name", "function_call".
- max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
- default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
- The limit only plays a role when human_input_mode is not "ALWAYS".
- compress_config (dict or True/False): config for compression before oai_reply. Default to False.
- You should contain the following keys:
- - "mode" (Optional, str, default to "TERMINATE"): Choose from ["COMPRESS", "TERMINATE", "CUSTOMIZED"].
- 1. `TERMINATE`: terminate the conversation ONLY when token count exceeds the max limit of current model. `trigger_count` is NOT used in this mode.
- 2. `COMPRESS`: compress the messages when the token count exceeds the limit.
- 3. `CUSTOMIZED`: pass in a customized function to compress the messages.
- - "compress_function" (Optional, callable, default to None): Must be provided when mode is "CUSTOMIZED".
- The function should takes a list of messages and returns a tuple of (is_compress_success: bool, compressed_messages: List[Dict]).
- - "trigger_count" (Optional, float, int, default to 0.7): the threshold to trigger compression.
- If a float between (0, 1], it is the percentage of token used. if a int, it is the number of tokens used.
- - "async" (Optional, bool, default to False): whether to compress asynchronously.
- - "broadcast" (Optional, bool, default to True): whether to update the compressed message history to sender.
- - "verbose" (Optional, bool, default to False): Whether to print the content before and after compression. Used when mode="COMPRESS".
- - "leave_last_n" (Optional, int, default to 0): If provided, the last n messages will not be compressed. Used when mode="COMPRESS".
- description (str): a short description of the agent. This description is used by other agents
- (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
- **kwargs (dict): Please refer to other kwargs in
- [ConversableAgent](../conversable_agent#__init__).
- """
- super().__init__(
- name=name,
- system_message=system_message,
- is_termination_msg=is_termination_msg,
- max_consecutive_auto_reply=max_consecutive_auto_reply,
- human_input_mode=human_input_mode,
- function_map=function_map,
- code_execution_config=code_execution_config,
- llm_config=llm_config,
- default_auto_reply=default_auto_reply,
- description=description,
- **kwargs,
- )
-
- self._set_compress_config(compress_config)
-
- # create a separate client for compression.
- if llm_config is False:
- self.llm_compress_config = False
- self.compress_client = None
- else:
- if "model" not in llm_config:
- raise ValueError("llm_config must contain the 'model' field.")
- self.llm_compress_config = self.llm_config.copy()
- # remove functions
- if "functions" in self.llm_compress_config:
- del self.llm_compress_config["functions"]
- self.compress_client = OpenAIWrapper(**self.llm_compress_config)
-
- self._reply_func_list.clear()
- self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
- self.register_reply([Agent], CompressibleAgent.on_oai_token_limit) # check token limit
- self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
- self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
- self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
-
- def _set_compress_config(self, compress_config: Optional[Dict] = False):
- if compress_config:
- if compress_config is True:
- compress_config = {}
- if not isinstance(compress_config, dict):
- raise ValueError("compress_config must be a dict or True/False.")
-
- allowed_modes = ["COMPRESS", "TERMINATE", "CUSTOMIZED"]
- if compress_config.get("mode", "TERMINATE") not in allowed_modes:
- raise ValueError(f"Invalid compression mode. Allowed values are: {', '.join(allowed_modes)}")
-
- self.compress_config = self.DEFAULT_COMPRESS_CONFIG.copy()
- self.compress_config.update(compress_config)
-
- if not isinstance(self.compress_config["leave_last_n"], int) or self.compress_config["leave_last_n"] < 0:
- raise ValueError("leave_last_n must be a non-negative integer.")
-
- # convert trigger_count to int, default to 0.7
- trigger_count = self.compress_config["trigger_count"]
- if not (isinstance(trigger_count, int) or isinstance(trigger_count, float)) or trigger_count <= 0:
- raise ValueError("trigger_count must be a positive number.")
- if isinstance(trigger_count, float) and 0 < trigger_count <= 1:
- self.compress_config["trigger_count"] = int(
- trigger_count * get_max_token_limit(self.llm_config["model"])
- )
- trigger_count = self.compress_config["trigger_count"]
- init_count = self._compute_init_token_count()
- if trigger_count < init_count:
- print(
- f"Warning: trigger_count {trigger_count} is less than the initial token count {init_count} (system message + function description if passed), compression will be disabled. Please increase trigger_count if you want to enable compression."
- )
- self.compress_config = False
-
- if self.compress_config["mode"] == "CUSTOMIZED" and self.compress_config["compress_function"] is None:
- raise ValueError("compress_function must be provided when mode is CUSTOMIZED.")
- if self.compress_config["mode"] != "CUSTOMIZED" and self.compress_config["compress_function"] is not None:
- print("Warning: compress_function is provided but mode is not 'CUSTOMIZED'.")
-
- else:
- self.compress_config = False
-
- def generate_reply(
- self,
- messages: Optional[List[Dict]] = None,
- sender: Optional[Agent] = None,
- exclude: Optional[List[Callable]] = None,
- ) -> Union[str, Dict, None]:
- """
-
- Adding to line 202:
- ```
- if messages is not None and messages != self._oai_messages[sender]:
- messages = self._oai_messages[sender]
- ```
- """
- if all((messages is None, sender is None)):
- error_msg = f"Either {messages=} or {sender=} must be provided."
- logger.error(error_msg)
- raise AssertionError(error_msg)
-
- if messages is None:
- messages = self._oai_messages[sender]
-
- for reply_func_tuple in self._reply_func_list:
- reply_func = reply_func_tuple["reply_func"]
- if exclude and reply_func in exclude:
- continue
- if inspect.iscoroutinefunction(reply_func):
- continue
- if self._match_trigger(reply_func_tuple["trigger"], sender):
- final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
- if messages is not None and sender is not None and messages != self._oai_messages[sender]:
- messages = self._oai_messages[sender]
- if final:
- return reply
- return self._default_auto_reply
-
- def _compute_init_token_count(self):
- """Check if the agent is LLM-based and compute the initial token count."""
- if self.llm_config is False:
- return 0
-
- func_count = 0
- if "functions" in self.llm_config:
- func_count = num_tokens_from_functions(self.llm_config["functions"], self.llm_config["model"])
-
- return func_count + count_token(self._oai_system_message, self.llm_config["model"])
-
- def _manage_history_on_token_limit(self, messages, token_used, max_token_allowed, model):
- """Manage the message history with different modes when token limit is reached.
- Return:
- final (bool): whether to terminate the agent.
- compressed_messages (List[Dict]): the compressed messages. None if no compression or compression failed.
- """
- # 1. mode = "TERMINATE", terminate the agent if no token left.
- if self.compress_config["mode"] == "TERMINATE":
- if max_token_allowed - token_used <= 0:
- # Terminate if no token left.
- print(
- colored(
- f'Warning: Terminate Agent "{self.name}" due to no token left for oai reply. max token for {model}: {max_token_allowed}, existing token count: {token_used}',
- "yellow",
- ),
- flush=True,
- )
- return True, None
- return False, None
-
- # if token_used is less than trigger_count, no compression will be used.
- if token_used < self.compress_config["trigger_count"]:
- return False, None
-
- # 2. mode = "COMPRESS" or mode = "CUSTOMIZED", compress the messages
- copied_messages = copy.deepcopy(messages)
- if self.compress_config["mode"] == "COMPRESS":
- _, compress_messages = self.compress_messages(copied_messages)
- elif self.compress_config["mode"] == "CUSTOMIZED":
- _, compress_messages = self.compress_config["compress_function"](copied_messages)
- else:
- raise ValueError(f"Unknown compression mode: {self.compress_config['mode']}")
-
- if compress_messages is not None:
- for i in range(len(compress_messages)):
- compress_messages[i] = self._get_valid_oai_message(compress_messages[i])
- return False, compress_messages
-
- def _get_valid_oai_message(self, message):
- """Convert a message into a valid OpenAI ChatCompletion message."""
- oai_message = {k: message[k] for k in ("content", "function_call", "name", "context", "role") if k in message}
- if "content" not in oai_message:
- if "function_call" in oai_message:
- oai_message["content"] = None # if only function_call is provided, content will be set to None.
- else:
- raise ValueError(
- "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
- )
- if "function_call" in oai_message:
- oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call.
- oai_message["function_call"] = dict(oai_message["function_call"])
- return oai_message
-
- def _print_compress_info(self, init_token_count, token_used, token_after_compression):
- to_print = "Token Count (including {} tokens from system msg and function descriptions). Before compression : {} | After: {}".format(
- init_token_count,
- token_used,
- token_after_compression,
- )
- print(colored(to_print, "magenta"), flush=True)
- print("-" * 80, flush=True)
-
- def on_oai_token_limit(
- self,
- messages: Optional[List[Dict]] = None,
- sender: Optional[Agent] = None,
- config: Optional[Any] = None,
- ) -> Tuple[bool, Union[str, Dict, None]]:
- """(Experimental) Compress previous messages when a threshold of tokens is reached.
-
- TODO: async compress
- TODO: maintain a list for old oai messages (messages before compression)
- """
- llm_config = self.llm_config if config is None else config
- if self.compress_config is False:
- return False, None
- if messages is None:
- messages = self._oai_messages[sender]
-
- model = llm_config["model"]
- init_token_count = self._compute_init_token_count()
- token_used = init_token_count + count_token(messages, model)
- final, compressed_messages = self._manage_history_on_token_limit(
- messages, token_used, get_max_token_limit(model), model
- )
-
- # update message history with compressed messages
- if compressed_messages is not None:
- self._print_compress_info(
- init_token_count, token_used, count_token(compressed_messages, model) + init_token_count
- )
- self._oai_messages[sender] = compressed_messages
- if self.compress_config["broadcast"]:
- # update the compressed message history to sender
- sender._oai_messages[self] = copy.deepcopy(compressed_messages)
- # switching the role of the messages for the sender
- for i in range(len(sender._oai_messages[self])):
- cmsg = sender._oai_messages[self][i]
- if "function_call" in cmsg or cmsg["role"] == "user":
- cmsg["role"] = "assistant"
- elif cmsg["role"] == "assistant":
- cmsg["role"] = "user"
- sender._oai_messages[self][i] = cmsg
-
- # successfully compressed, return False, None for generate_oai_reply to be called with the updated messages
- return False, None
- return final, None
-
- def compress_messages(
- self,
- messages: Optional[List[Dict]] = None,
- config: Optional[Any] = None,
- ) -> Tuple[bool, Union[str, Dict, None, List]]:
- """Compress a list of messages into one message.
-
- The first message (the initial prompt) will not be compressed.
- The rest of the messages will be compressed into one message, the model is asked to distinguish the role of each message: USER, ASSISTANT, FUNCTION_CALL, FUNCTION_RETURN.
- Check out the compress_sys_msg.
-
- TODO: the model used by the compression agent may differ from the assistant agent's model. For example, if the original model is gpt-4 and compression starts at 70% of usage (70% of 8192 = 5734), but gpt-3.5 with a 4096-token limit is used here, it will raise an error. Choose the compression model automatically?
- """
- # 1. use the compression client
- client = self.compress_client if config is None else config
-
- # 2. stop if there is only one message in the list
- leave_last_n = self.compress_config.get("leave_last_n", 0)
- if leave_last_n + 1 >= len(messages):
- logger.warning(
- f"Warning: Compression skipped at trigger count threshold. The first msg and last {leave_last_n} msgs will not be compressed. current msg count: {len(messages)}. Consider raising trigger_count."
- )
- return False, None
-
- # 3. put all history into one, except the first one
- if self.compress_config["verbose"]:
- print(colored("*" * 30 + "Start compressing the following content:" + "*" * 30, "magenta"), flush=True)
-
- compressed_prompt = "Below is the compressed content from the previous conversation, evaluate the process and continue if necessary:\n"
- chat_to_compress = "To be compressed:\n"
-
- for m in messages[1 : len(messages) - leave_last_n]: # skip the first message and the last leave_last_n messages
- # Handle function role
- if m.get("role") == "function":
- chat_to_compress += f"##FUNCTION_RETURN## (from function \"{m['name']}\"): \n{m['content']}\n"
-
- # If name exists in the message
- elif "name" in m:
- chat_to_compress += f"##{m['name']}({m['role'].upper()})## {m['content']}\n"
-
- # Handle case where content is not None and name is absent
- elif m.get("content"): # This condition will also handle None and empty string
- if compressed_prompt in m["content"]:
- chat_to_compress += m["content"].replace(compressed_prompt, "") + "\n"
- else:
- chat_to_compress += f"##{m['role'].upper()}## {m['content']}\n"
-
- # Handle function_call in the message
- if "function_call" in m:
- function_name = m["function_call"].get("name")
- function_args = m["function_call"].get("arguments")
-
- if not function_name or not function_args:
- chat_to_compress += f"##FUNCTION_CALL## {m['function_call']}\n"
- else:
- chat_to_compress += f"##FUNCTION_CALL## \nName: {function_name}\nArgs: {function_args}\n"
-
- chat_to_compress = [{"role": "user", "content": chat_to_compress}]
-
- if self.compress_config["verbose"]:
- print(chat_to_compress[0]["content"])
-
- # 4. use LLM to compress
- compress_sys_msg = """You are a helpful assistant that will summarize and compress conversation history.
-Rules:
-1. Please summarize each of the messages and preserve the exact titles: ##USER##, ##ASSISTANT##, ##FUNCTION_CALL##, ##FUNCTION_RETURN##, ##SYSTEM##, ##()## (e.g. ##Bob(ASSISTANT)##).
-2. Try to compress the content but preserve important information (a link, a specific number, etc.).
-3. Use words to summarize the code blocks or function calls (##FUNCTION_CALL##) and their goals. For code blocks, please use ##CODE## to mark it.
-4. For returns from functions (##FUNCTION_RETURN##) or returns from code execution: summarize the content and indicate the status of the return (e.g. success, error, etc.).
-"""
- try:
- response = client.create(
- context=None,
- messages=[{"role": "system", "content": compress_sys_msg}] + chat_to_compress,
- )
- except Exception as e:
- print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
- return False, None
-
- compressed_message = self.client.extract_text_or_completion_object(response)[0]
- assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
- if self.compress_config["verbose"]:
- print(
- colored("*" * 30 + "Content after compressing:" + "*" * 30, "magenta"),
- flush=True,
- )
- print(compressed_message, colored("\n" + "*" * 80, "magenta"))
-
- # 5. add compressed message to the first message and return
- return (
- True,
- [
- messages[0],
- {
- "content": compressed_prompt + compressed_message,
- "role": "system",
- },
- ]
- + messages[len(messages) - leave_last_n :],
- )
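For readers migrating off the removed `CompressibleAgent`, the behaviour worth preserving is how `_set_compress_config` turned a fractional `trigger_count` into an absolute token budget before enabling compression. A minimal standalone sketch of that normalization, using an illustrative token-limit table in place of `autogen.token_count_utils.get_max_token_limit`:

```python
# Sketch only: mirrors the trigger_count handling of the removed
# CompressibleAgent._set_compress_config. The token-limit table is illustrative.
from typing import Union

MODEL_MAX_TOKENS = {"gpt-4": 8192, "gpt-3.5-turbo": 4096}  # assumed limits


def normalize_trigger_count(trigger_count: Union[int, float], model: str) -> int:
    """Convert a fractional trigger_count (0 < x <= 1) into an absolute token budget."""
    if not isinstance(trigger_count, (int, float)) or trigger_count <= 0:
        raise ValueError("trigger_count must be a positive number.")
    if isinstance(trigger_count, float) and 0 < trigger_count <= 1:
        return int(trigger_count * MODEL_MAX_TOKENS[model])
    return int(trigger_count)


# 0.7 of gpt-4's 8192-token limit -> 5734 tokens before compression triggers
print(normalize_trigger_count(0.7, "gpt-4"))
```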
diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
index 253d4d18e2e..0dcad27b16d 100644
--- a/autogen/agentchat/contrib/gpt_assistant_agent.py
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -5,12 +5,11 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
-import openai
-
from autogen import OpenAIWrapper
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
-from autogen.oai.openai_utils import retrieve_assistants_by_name
+from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant
+from autogen.runtime_logging import log_new_agent, logging_enabled
logger = logging.getLogger(__name__)
@@ -50,7 +49,8 @@ def __init__(
- check_every_ms: check thread run status interval
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
- - file_ids: files used by retrieval in run
+ - file_ids: (Deprecated) files used by retrieval in a run. Deprecated; use tool_resources instead. See https://platform.openai.com/docs/assistants/migration/what-has-changed.
+ - tool_resources: A set of resources that are used by the assistant's tools. The resources are specific to the type of tool.
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
overwrite_tools (bool): whether to overwrite the tools of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
kwargs (dict): Additional configuration options for the agent.
@@ -64,6 +64,8 @@ def __init__(
super().__init__(
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs
)
+ if logging_enabled():
+ log_new_agent(self, locals())
# GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list()
# See: https://github.com/microsoft/autogen/pull/1721
@@ -90,7 +92,6 @@ def __init__(
candidate_assistants,
instructions,
openai_assistant_cfg.get("tools", []),
- openai_assistant_cfg.get("file_ids", []),
)
if len(candidate_assistants) == 0:
@@ -101,12 +102,12 @@ def __init__(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
- self._openai_assistant = self._openai_client.beta.assistants.create(
+ self._openai_assistant = create_gpt_assistant(
+ self._openai_client,
name=name,
instructions=instructions,
- tools=openai_assistant_cfg.get("tools", []),
model=model_name,
- file_ids=openai_assistant_cfg.get("file_ids", []),
+ assistant_config=openai_assistant_cfg,
)
else:
logger.warning(
@@ -127,9 +128,12 @@ def __init__(
logger.warning(
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
)
- self._openai_assistant = self._openai_client.beta.assistants.update(
+ self._openai_assistant = update_gpt_assistant(
+ self._openai_client,
assistant_id=openai_assistant_id,
- instructions=instructions,
+ assistant_config={
+ "instructions": instructions,
+ },
)
else:
logger.warning(
@@ -154,18 +158,23 @@ def __init__(
logger.warning(
"overwrite_tools is True. Provided tools will be used and will modify the assistant in the API"
)
- self._openai_assistant = self._openai_client.beta.assistants.update(
+ self._openai_assistant = update_gpt_assistant(
+ self._openai_client,
assistant_id=openai_assistant_id,
- tools=openai_assistant_cfg.get("tools", []),
+ assistant_config={
+ "tools": specified_tools,
+ "tool_resources": openai_assistant_cfg.get("tool_resources", None),
+ },
)
else:
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
logger.warning("overwrite_tools is False. Using existing tools from assistant API.")
+ self.update_system_message(self._openai_assistant.instructions)
# lazily create threads
self._openai_threads = {}
self._unread_index = defaultdict(int)
- self.register_reply(Agent, GPTAssistantAgent._invoke_assistant, position=2)
+ self.register_reply([Agent, None], GPTAssistantAgent._invoke_assistant, position=2)
def _invoke_assistant(
self,
@@ -198,6 +207,8 @@ def _invoke_assistant(
assistant_thread = self._openai_threads[sender]
# Process each unread message
for message in pending_messages:
+ if message["content"].strip() == "":
+ continue
self._openai_client.beta.threads.messages.create(
thread_id=assistant_thread.id,
content=message["content"],
@@ -426,22 +437,23 @@ def delete_assistant(self):
logger.warning("Permanently deleting assistant...")
self._openai_client.beta.assistants.delete(self.assistant_id)
- def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
+ def find_matching_assistant(self, candidate_assistants, instructions, tools):
"""
Find the matching assistant from a list of candidate assistants.
- Filter out candidates with the same name but different instructions, file IDs, and function names.
- TODO: implement accurate match based on assistant metadata fields.
+ Filter out candidates with the same name but different instructions and function names.
"""
matching_assistants = []
# Preprocess the required tools for faster comparison
- required_tool_types = set(tool.get("type") for tool in tools)
+ required_tool_types = set(
+ "file_search" if tool.get("type") in ["retrieval", "file_search"] else tool.get("type") for tool in tools
+ )
+
required_function_names = set(
tool.get("function", {}).get("name")
for tool in tools
- if tool.get("type") not in ["code_interpreter", "retrieval"]
+ if tool.get("type") not in ["code_interpreter", "retrieval", "file_search"]
)
- required_file_ids = set(file_ids) # Convert file_ids to a set for unordered comparison
for assistant in candidate_assistants:
# Check if instructions are similar
@@ -454,11 +466,12 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
continue
# Preprocess the assistant's tools
- assistant_tool_types = set(tool.type for tool in assistant.tools)
+ assistant_tool_types = set(
+ "file_search" if tool.type in ["retrieval", "file_search"] else tool.type for tool in assistant.tools
+ )
assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
- assistant_file_ids = set(getattr(assistant, "file_ids", [])) # Convert to set for comparison
- # Check if the tool types, function names, and file IDs match
+ # Check if the tool types and function names match
if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
logger.warning(
"tools not match, skip assistant(%s): tools %s, functions %s",
@@ -467,9 +480,6 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
assistant_function_names,
)
continue
- if required_file_ids != assistant_file_ids:
- logger.warning("file_ids not match, skip assistant(%s): %s", assistant.id, assistant_file_ids)
- continue
# Append assistant to matching list if all conditions are met
matching_assistants.append(assistant)
@@ -496,7 +506,7 @@ def _process_assistant_config(self, llm_config, assistant_config):
# Move the assistant related configurations to assistant_config
# It's important to keep forward compatibility
- assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
+ assistant_config_items = ["assistant_id", "tools", "file_ids", "tool_resources", "check_every_ms"]
for item in assistant_config_items:
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
openai_assistant_cfg[item] = openai_client_cfg[item]
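With `file_ids` deprecated in favour of `tool_resources`, a `GPTAssistantAgent` built against the Assistants v2 API now attaches files through tool resources. A hedged sketch; the config-list values and vector-store id are placeholders, and the `tool_resources` shape follows the OpenAI migration guide linked above:

```python
# Sketch: file_search replaces the old retrieval tool, and files are attached via
# tool_resources rather than the deprecated file_ids. All ids and keys are placeholders.
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

assistant = GPTAssistantAgent(
    name="doc_assistant",
    instructions="Answer questions using the attached documents.",
    llm_config={"config_list": [{"model": "gpt-4o", "api_key": "sk-..."}]},
    assistant_config={
        "tools": [{"type": "file_search"}],
        "tool_resources": {"file_search": {"vector_store_ids": ["vs_placeholder"]}},
    },
)
```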
diff --git a/autogen/agentchat/contrib/llamaindex_conversable_agent.py b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
new file mode 100644
index 00000000000..dbf6f274ae8
--- /dev/null
+++ b/autogen/agentchat/contrib/llamaindex_conversable_agent.py
@@ -0,0 +1,108 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+from autogen import OpenAIWrapper
+from autogen.agentchat import Agent, ConversableAgent
+from autogen.agentchat.contrib.vectordb.utils import get_logger
+
+logger = get_logger(__name__)
+
+try:
+ from llama_index.core.agent.runner.base import AgentRunner
+ from llama_index.core.base.llms.types import ChatMessage
+ from llama_index.core.chat_engine.types import AgentChatResponse
+except ImportError as e:
+ logger.fatal("Failed to import llama-index. Try running 'pip install llama-index'")
+ raise e
+
+
+class LLamaIndexConversableAgent(ConversableAgent):
+ def __init__(
+ self,
+ name: str,
+ llama_index_agent: AgentRunner,
+ description: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ llama_index_agent (AgentRunner): the llama-index agent runner that this agent delegates reply generation to.
+ description (str): a short description of the agent. This description is used by other agents
+ (e.g. the GroupChatManager) to decide when to call upon this agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../conversable_agent#__init__).
+ """
+
+ if llama_index_agent is None:
+ raise ValueError("llama_index_agent must be provided")
+
+ if description is None or description.isspace():
+ raise ValueError("description must be provided")
+
+ super().__init__(
+ name,
+ description=description,
+ **kwargs,
+ )
+
+ self._llama_index_agent = llama_index_agent
+
+ # Override the `generate_oai_reply`
+ self.replace_reply_func(ConversableAgent.generate_oai_reply, LLamaIndexConversableAgent._generate_oai_reply)
+
+ self.replace_reply_func(ConversableAgent.a_generate_oai_reply, LLamaIndexConversableAgent._a_generate_oai_reply)
+
+ def _generate_oai_reply(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[OpenAIWrapper] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Generate a reply using autogen.oai."""
+ user_message, history = self._extract_message_and_history(messages=messages, sender=sender)
+
+ chatResponse: AgentChatResponse = self._llama_index_agent.chat(message=user_message, chat_history=history)
+
+ extracted_response = chatResponse.response
+
+ return (True, extracted_response)
+
+ async def _a_generate_oai_reply(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[OpenAIWrapper] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Generate a reply using autogen.oai."""
+ user_message, history = self._extract_message_and_history(messages=messages, sender=sender)
+
+ chatResponse: AgentChatResponse = await self._llama_index_agent.achat(
+ message=user_message, chat_history=history
+ )
+
+ extracted_response = chatResponse.response
+
+ return (True, extracted_response)
+
+ def _extract_message_and_history(
+ self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None
+ ) -> Tuple[str, List[ChatMessage]]:
+ """Extract the message and history from the messages."""
+ if not messages:
+ messages = self._oai_messages[sender]
+
+ if not messages:
+ return "", []
+
+ message = messages[-1].get("content", "")
+
+ history = messages[:-1]
+ history_messages: List[ChatMessage] = []
+ for history_message in history:
+ content = history_message.get("content", "")
+ role = history_message.get("role", "user")
+ if role:
+ if role == "user" or role == "assistant":
+ history_messages.append(ChatMessage(content=content, role=role, additional_kwargs={}))
+ return message, history_messages
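The new `LLamaIndexConversableAgent` replaces `generate_oai_reply` with calls into a llama-index `AgentRunner`, so an existing llama-index agent can join an AutoGen conversation unchanged. A usage sketch, assuming llama-index's `ReActAgent`/`FunctionTool` APIs and the OpenAI LLM wrapper are installed (`pip install llama-index llama-index-llms-openai`); the tool and model choices are placeholders:

```python
# Sketch: wrap a llama-index agent so it can chat with other AutoGen agents.
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI

from autogen import UserProxyAgent
from autogen.agentchat.contrib.llamaindex_conversable_agent import LLamaIndexConversableAgent


def multiply(a: float, b: float) -> float:
    """Multiply two numbers."""
    return a * b


llama_agent = ReActAgent.from_tools(
    tools=[FunctionTool.from_defaults(fn=multiply)],
    llm=OpenAI(model="gpt-4o-mini"),  # placeholder model
)

calculator = LLamaIndexConversableAgent(
    name="calculator",
    llama_index_agent=llama_agent,
    description="Does arithmetic using llama-index tools.",
)

user = UserProxyAgent(name="user", human_input_mode="NEVER", code_execution_config=False)
user.initiate_chat(calculator, message="What is 12.3 * 4.5?", max_turns=2)
```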
diff --git a/autogen/agentchat/contrib/math_user_proxy_agent.py b/autogen/agentchat/contrib/math_user_proxy_agent.py
index d2b6b7cde00..699caeb85b3 100644
--- a/autogen/agentchat/contrib/math_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/math_user_proxy_agent.py
@@ -1,7 +1,7 @@
import os
import re
from time import sleep
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Extra, root_validator
@@ -136,7 +136,7 @@ def __init__(
is_termination_msg: Optional[
Callable[[Dict], bool]
] = _is_termination_msg_mathchat, # terminate if \boxed{} in message
- human_input_mode: Optional[str] = "NEVER", # Fully automated
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", # Fully automated
default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_REPLY,
max_invalid_q_per_step=3, # a parameter needed in MathChat
**kwargs,
diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
index c68ce809d8d..f1cc6947d50 100644
--- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
@@ -1,17 +1,22 @@
-import logging
-from typing import Callable, Dict, List, Optional
+import warnings
+from typing import Callable, Dict, List, Literal, Optional
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
+from autogen.agentchat.contrib.vectordb.utils import (
+ chroma_results_to_query_results,
+ filter_results_by_distance,
+ get_logger,
+)
from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
try:
import fastembed
from qdrant_client import QdrantClient, models
from qdrant_client.fastembed_common import QueryResponse
except ImportError as e:
- logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
+ logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
raise e
@@ -19,7 +24,7 @@ class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
- human_input_mode: Optional[str] = "ALWAYS",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "ALWAYS",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
retrieve_config: Optional[Dict] = None, # config for the retrieve agent
**kwargs,
@@ -89,6 +94,11 @@ def __init__(
**kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
"""
+ warnings.warn(
+ "The QdrantRetrieveUserProxyAgent is deprecated. Please use the RetrieveUserProxyAgent instead, set `vector_db` to `qdrant`.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__(name, human_input_mode, is_termination_msg, retrieve_config, **kwargs)
self._client = self._retrieve_config.get("client", QdrantClient(":memory:"))
self._embedding_model = self._retrieve_config.get("embedding_model", "BAAI/bge-small-en-v1.5")
@@ -136,6 +146,11 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
collection_name=self._collection_name,
embedding_model=self._embedding_model,
)
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results, "distances")
+ results = filter_results_by_distance(results, self._distance_threshold)
+
+ self._search_string = search_string
self._results = results
@@ -298,6 +313,7 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore
data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
+ "distances": [[result.score for result in sublist] for sublist in results],
"metadatas": [[result.metadata for result in sublist] for sublist in results],
}
return data
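As the new deprecation warning says, the Qdrant-specific proxy is superseded by the generic `RetrieveUserProxyAgent` with `vector_db` set to `qdrant`. A hedged sketch of the replacement; the `db_config` key shown is an assumption about what the Qdrant backend accepts, and the docs path is a placeholder:

```python
# Sketch: replace QdrantRetrieveUserProxyAgent with RetrieveUserProxyAgent + vector_db="qdrant".
from qdrant_client import QdrantClient

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # placeholder
        "vector_db": "qdrant",  # resolved through VectorDBFactory
        "db_config": {"client": QdrantClient(":memory:")},  # assumed constructor kwarg
    },
)
```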
diff --git a/autogen/agentchat/contrib/retrieve_assistant_agent.py b/autogen/agentchat/contrib/retrieve_assistant_agent.py
index 9b5ace200dc..173bc4432e7 100644
--- a/autogen/agentchat/contrib/retrieve_assistant_agent.py
+++ b/autogen/agentchat/contrib/retrieve_assistant_agent.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from autogen.agentchat.agent import Agent
@@ -16,6 +17,11 @@ class RetrieveAssistantAgent(AssistantAgent):
"""
def __init__(self, *args, **kwargs):
+ warnings.warn(
+ "The RetrieveAssistantAgent is deprecated. Please use the AssistantAgent instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__(*args, **kwargs)
self.register_reply(Agent, RetrieveAssistantAgent._generate_retrieve_assistant_reply)
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 34dbe28d098..10b70e0e972 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -1,21 +1,37 @@
+import hashlib
+import os
import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import uuid
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from IPython import get_ipython
try:
import chromadb
-except ImportError:
- raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
-from autogen import logger
+except ImportError as e:
+ raise ImportError(f"{e}. You can try `pip install pyautogen[retrievechat]`, or install `chromadb` manually.")
from autogen.agentchat import UserProxyAgent
from autogen.agentchat.agent import Agent
+from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory
+from autogen.agentchat.contrib.vectordb.utils import (
+ chroma_results_to_query_results,
+ filter_results_by_distance,
+ get_logger,
+)
from autogen.code_utils import extract_code
-from autogen.retrieve_utils import TEXT_FORMATS, create_vector_db_from_dir, query_vector_db
+from autogen.retrieve_utils import (
+ TEXT_FORMATS,
+ create_vector_db_from_dir,
+ get_files_from_dir,
+ query_vector_db,
+ split_files_to_chunks,
+)
from autogen.token_count_utils import count_token
from ...formatting_utils import colored
+logger = get_logger(__name__)
+
PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the
context provided by the user. You should follow the following steps to answer a question:
Step 1, you estimate the user's intent based on the question and context. The intent can be a code generation task or
@@ -65,6 +81,9 @@
Context is: {input_context}
"""
+HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))
+UPDATE_CONTEXT_IN_PROMPT = "you should reply exactly `UPDATE CONTEXT`"
+
class RetrieveUserProxyAgent(UserProxyAgent):
"""(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding
@@ -74,7 +93,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
- human_input_mode: Optional[str] = "ALWAYS",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "ALWAYS",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
retrieve_config: Optional[Dict] = None, # config for the retrieve agent
**kwargs,
@@ -107,9 +126,17 @@ def __init__(
"code", "qa" and "default". System prompt will be different for different tasks.
The default value is `default`, which supports both code and qa, and provides
source information in the end of the response.
+ - `vector_db` (Optional, Union[str, VectorDB]) - the vector db for the retrieve chat.
+ If it's a string, it should be the type of the vector db, such as "chroma"; otherwise,
+ it should be an instance of the VectorDB protocol. Default is "chroma".
+ Set it to `None` to use the deprecated `client`.
+ - `db_config` (Optional, Dict) - the config for the vector db. Default is `{}`. Please make
+ sure you understand the config for the vector db you are using, otherwise, leave it as `{}`.
+ Only valid when `vector_db` is a string.
- `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
default client `chromadb.Client()` will be used. If you want to use other
vector db, extend this class and override the `retrieve_docs` function.
+ *[Deprecated]* use `vector_db` instead.
- `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It
can also be the path to a single file, the url to a single file or a list
of directories, files and urls. Default is None, which works only if the
@@ -123,8 +150,11 @@ def __init__(
By default, "extra_docs" is set to false, starting document IDs from zero.
This poses a risk as new documents might overwrite existing ones, potentially
causing unintended loss or alteration of data in the collection.
- - `collection_name` (Optional, str) - the name of the collection.
- If key not provided, a default name `autogen-docs` will be used.
+ *[Deprecated]* use `new_docs` when using `vector_db` instead of `client`.
+ - `new_docs` (Optional, bool) - when True, only adds new documents to the collection;
+ when False, updates existing documents and adds new ones. Default is True.
+ Document id is used to determine if a document is new or existing. By default, the
+ id is the hash value of the content.
- `model` (Optional, str) - the model to use for the retrieve chat.
If key not provided, a default model `gpt-4` will be used.
- `chunk_token_size` (Optional, int) - the chunk token size for the retrieve chat.
@@ -143,6 +173,7 @@ def __init__(
models can be found at `https://www.sbert.net/docs/pretrained_models.html`.
The default model is a fast model. If you want to use a high performance model,
`all-mpnet-base-v2` is recommended.
+ *[Deprecated]* not needed when using `vector_db` instead of `client`.
- `embedding_function` (Optional, Callable) - the embedding function for creating the
vector db. Default is None, SentenceTransformer with the given `embedding_model`
will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
@@ -156,10 +187,14 @@ def __init__(
`Update Context` will be triggered.
- `update_context` (Optional, bool) - if False, will not apply `Update Context` for
interactive retrieval. Default is True.
- - `get_or_create` (Optional, bool) - if True, will create/return a collection for the
- retrieve chat. This is the same as that used in chromadb.
- Default is False. Will raise ValueError if the collection already exists and
- get_or_create is False. Will be set to True if docs_path is None.
+ - `collection_name` (Optional, str) - the name of the collection.
+ If key not provided, a default name `autogen-docs` will be used.
+ - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is False.
+ - `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raises a ValueError.
- `custom_token_count_function` (Optional, Callable) - a custom function to count the
number of tokens in a string.
The function should take (text:str, model:str) as input and return the
@@ -176,6 +211,8 @@ def __init__(
included files and urls will be chunked regardless of their types.
- `recursive` (Optional, bool) - whether to search documents recursively in the
docs_path. Default is True.
+ - `distance_threshold` (Optional, float) - the threshold for the distance score; only
+ results with a distance smaller than it will be returned. Will be ignored if < 0. Default is -1.
`**kwargs` (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
@@ -183,6 +220,7 @@ def __init__(
Example of overriding retrieve_docs - If you have set up a customized vector db, and it's
not compatible with chromadb, you can easily plug in it with below code.
+ *[Deprecated]* use `vector_db` instead. You can extend VectorDB and pass it to the agent.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
@@ -215,9 +253,14 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._task = self._retrieve_config.get("task", "default")
- self._client = self._retrieve_config.get("client", chromadb.Client())
+ self._vector_db = self._retrieve_config.get("vector_db", "chroma")
+ self._db_config = self._retrieve_config.get("db_config", {})
+ self._client = self._retrieve_config.get("client", None)
+ if self._client is None:
+ self._client = chromadb.Client()
self._docs_path = self._retrieve_config.get("docs_path", None)
self._extra_docs = self._retrieve_config.get("extra_docs", False)
+ self._new_docs = self._retrieve_config.get("new_docs", True)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
if "docs_path" not in self._retrieve_config:
logger.warning(
@@ -236,6 +279,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.update_context = self._retrieve_config.get("update_context", True)
self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
+ self._overwrite = self._retrieve_config.get("overwrite", False)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
@@ -244,17 +288,102 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
- self._results = {} # the results of the current query
+ self._results = [] # the results of the current query
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
self._current_docs_in_context = [] # the ids of the current context sources
self._search_string = "" # the search string used in the current query
+ self._distance_threshold = self._retrieve_config.get("distance_threshold", -1)
# update the termination message function
self._is_termination_msg = (
self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
)
+ if isinstance(self._vector_db, str):
+ if not isinstance(self._db_config, dict):
+ raise ValueError("`db_config` should be a dictionary.")
+ if "embedding_function" in self._retrieve_config:
+ self._db_config["embedding_function"] = self._embedding_function
+ self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config)
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2)
+ self.register_hook(
+ hookable_method="process_message_before_send",
+ hook=self._check_update_context_before_send,
+ )
+
+ def _init_db(self):
+ if not self._vector_db:
+ return
+
+ IS_TO_CHUNK = False # whether to chunk the raw files
+ if self._new_docs:
+ IS_TO_CHUNK = True
+ if not self._docs_path:
+ try:
+ self._vector_db.get_collection(self._collection_name)
+ logger.warning(f"`docs_path` is not provided. Use the existing collection `{self._collection_name}`.")
+ self._overwrite = False
+ self._get_or_create = True
+ IS_TO_CHUNK = False
+ except ValueError:
+ raise ValueError(
+ "`docs_path` is not provided. "
+ f"The collection `{self._collection_name}` doesn't exist either. "
+ "Please provide `docs_path` or create the collection first."
+ )
+ elif self._get_or_create and not self._overwrite:
+ try:
+ self._vector_db.get_collection(self._collection_name)
+ logger.info(f"Use the existing collection `{self._collection_name}`.", color="green")
+ except ValueError:
+ IS_TO_CHUNK = True
+ else:
+ IS_TO_CHUNK = True
+
+ self._vector_db.active_collection = self._vector_db.create_collection(
+ self._collection_name, overwrite=self._overwrite, get_or_create=self._get_or_create
+ )
+
+ docs = None
+ if IS_TO_CHUNK:
+ if self.custom_text_split_function is not None:
+ chunks, sources = split_files_to_chunks(
+ get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+ custom_text_split_function=self.custom_text_split_function,
+ )
+ else:
+ chunks, sources = split_files_to_chunks(
+ get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+ self._chunk_token_size,
+ self._chunk_mode,
+ self._must_break_at_empty_line,
+ )
+ logger.info(f"Found {len(chunks)} chunks.")
+
+ if self._new_docs:
+ all_docs_ids = set(
+ [
+ doc["id"]
+ for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name)
+ ]
+ )
+ else:
+ all_docs_ids = set()
+
+ chunk_ids = (
+ [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+ if not self._vector_db.type == "qdrant"
+ else [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]
+ )
+ chunk_ids_set = set(chunk_ids)
+ chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
+ docs = [
+ Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx])
+ for idx in chunk_ids_set_idx
+ if chunk_ids[idx] not in all_docs_ids
+ ]
+
+ self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True)
def _is_termination_msg_retrievechat(self, message):
"""Check if a message is a termination message.
@@ -275,6 +404,34 @@ def _is_termination_msg_retrievechat(self, message):
update_context_case1, update_context_case2 = self._check_update_context(message)
return not (contain_code or update_context_case1 or update_context_case2)
+ def _check_update_context_before_send(self, sender, message, recipient, silent):
+ if not isinstance(message, (str, dict)):
+ return message
+ elif isinstance(message, dict):
+ msg_text = message.get("content", message)
+ else:
+ msg_text = message
+
+ if "UPDATE CONTEXT" == msg_text.strip().upper():
+ doc_contents = self._get_context(self._results)
+
+ # Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
+ # next similar docs in the retrieved doc results.
+ if not doc_contents:
+ for _tmp_retrieve_count in range(1, 5):
+ self._reset(intermediate=True)
+ self.retrieve_docs(
+ self.problem, self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
+ )
+ doc_contents = self._get_context(self._results)
+ if doc_contents or self.n_results * (2 * _tmp_retrieve_count + 1) >= len(self._results[0]):
+ break
+ msg_text = self._generate_message(doc_contents, task=self._task)
+
+ if isinstance(message, dict):
+ message["content"] = msg_text
+ return message
+
@staticmethod
def get_max_tokens(model="gpt-3.5-turbo"):
if "32k" in model:
@@ -288,41 +445,42 @@ def get_max_tokens(model="gpt-3.5-turbo"):
def _reset(self, intermediate=False):
self._doc_idx = -1 # the index of the current used doc
- self._results = {} # the results of the current query
+ self._results = [] # the results of the current query
if not intermediate:
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
- def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
+ def _get_context(self, results: QueryResults):
doc_contents = ""
self._current_docs_in_context = []
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
- for idx, doc in enumerate(results["documents"][0]):
+ for idx, doc in enumerate(results[0]):
+ doc = doc[0]
if idx <= _doc_idx:
continue
- if results["ids"][0][idx] in self._doc_ids:
+ if doc["id"] in self._doc_ids:
continue
- _doc_tokens = self.custom_token_count_function(doc, self._model)
+ _doc_tokens = self.custom_token_count_function(doc["content"], self._model)
if _doc_tokens > self._context_max_tokens:
- func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
+ func_print = f"Skip doc_id {doc['id']} as it is too long to fit in the context."
print(colored(func_print, "green"), flush=True)
self._doc_idx = idx
continue
if current_tokens + _doc_tokens > self._context_max_tokens:
break
- func_print = f"Adding doc_id {results['ids'][0][idx]} to context."
+ func_print = f"Adding content of doc {doc['id']} to context."
print(colored(func_print, "green"), flush=True)
current_tokens += _doc_tokens
- doc_contents += doc + "\n"
- _metadatas = results.get("metadatas")
- if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict):
- self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", ""))
+ doc_contents += doc["content"] + "\n"
+ _metadata = doc.get("metadata")
+ if isinstance(_metadata, dict):
+ self._current_docs_in_context.append(_metadata.get("source", ""))
self._doc_idx = idx
- self._doc_ids.append(results["ids"][0][idx])
- self._doc_contents.append(doc)
+ self._doc_ids.append(doc["id"])
+ self._doc_contents.append(doc["content"])
_tmp_retrieve_count += 1
if _tmp_retrieve_count >= self.n_results:
break
@@ -351,7 +509,7 @@ def _check_update_context(self, message):
message = message.get("content", "")
elif not isinstance(message, str):
message = ""
- update_context_case1 = "UPDATE CONTEXT" in message[-20:].upper() or "UPDATE CONTEXT" in message[:20].upper()
+ update_context_case1 = "UPDATE CONTEXT" in message.upper() and UPDATE_CONTEXT_IN_PROMPT not in message
update_context_case2 = self.customized_answer_prefix and self.customized_answer_prefix not in message.upper()
return update_context_case1, update_context_case2
@@ -393,7 +551,7 @@ def _generate_retrieve_user_reply(
self.problem, self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
)
doc_contents = self._get_context(self._results)
- if doc_contents:
+ if doc_contents or self.n_results * (2 * _tmp_retrieve_count + 1) >= len(self._results[0]):
break
elif update_context_case2:
# Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
@@ -405,7 +563,7 @@ def _generate_retrieve_user_reply(
)
self._get_context(self._results)
doc_contents = "\n".join(self._doc_contents) # + "\n" + "\n".join(self._intermediate_answers)
- if doc_contents:
+ if doc_contents or self.n_results * (2 * _tmp_retrieve_count + 1) >= len(self._results[0]):
break
self.clear_history()
@@ -416,21 +574,40 @@ def _generate_retrieve_user_reply(
def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
"""Retrieve docs based on the given problem and assign the results to the class property `_results`.
- In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
- compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
- parameters and add your own parameters with default values, and keep the results in below type.
-
- Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
- the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
- to `chromadb.api.types.QueryResult` as an example.
- ids: List[string]
- documents: List[List[string]]
+ The retrieved docs should be of type `QueryResults`, which is a list of tuples containing the document and
+ the distance.
Args:
problem (str): the problem to be solved.
n_results (int): the number of results to be retrieved. Default is 20.
search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "".
+ Not used if the vector_db doesn't support it.
+
+ Returns:
+ None.
"""
+ if isinstance(self._vector_db, VectorDB):
+ if not self._collection or not self._get_or_create:
+ print("Trying to create collection.")
+ self._init_db()
+ self._collection = True
+ self._get_or_create = True
+
+ kwargs = {}
+ if hasattr(self._vector_db, "type") and self._vector_db.type == "chroma":
+ kwargs["where_document"] = {"$contains": search_string} if search_string else None
+ results = self._vector_db.retrieve_docs(
+ queries=[problem],
+ n_results=n_results,
+ collection_name=self._collection_name,
+ distance_threshold=self._distance_threshold,
+ **kwargs,
+ )
+ self._search_string = search_string
+ self._results = results
+ print("VectorDB returns doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
+ return
+
if not self._collection or not self._get_or_create:
print("Trying to create collection.")
self._client = create_vector_db_from_dir(
@@ -460,9 +637,13 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
embedding_model=self._embedding_model,
embedding_function=self._embedding_function,
)
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results, "distances")
+ results = filter_results_by_distance(results, self._distance_threshold)
+
self._search_string = search_string
self._results = results
- print("doc_ids: ", results["ids"])
+ print("doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
@staticmethod
def message_generator(sender, recipient, context):
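Putting the new `retrieve_config` keys together (`vector_db`, `db_config`, `new_docs`, `overwrite`, `get_or_create`, `collection_name`, `distance_threshold`), a hedged configuration sketch with placeholder paths and credentials:

```python
# Sketch: RetrieveUserProxyAgent driven by the new vector_db-based options
# instead of the deprecated `client`/`extra_docs` flow. Values are placeholders.
from autogen import AssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

assistant = AssistantAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]},  # placeholder
)

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "default",
        "docs_path": ["./docs"],        # placeholder
        "vector_db": "chroma",          # string type resolved by VectorDBFactory
        "collection_name": "autogen-docs",
        "new_docs": True,               # only insert chunks whose content hash is new
        "overwrite": False,             # keep an existing collection
        "get_or_create": True,          # reuse the collection if it already exists
        "distance_threshold": 0.8,      # assumed scale; results farther than this are dropped
    },
)

ragproxyagent.initiate_chat(
    assistant, message=ragproxyagent.message_generator, problem="How do I install AutoGen?"
)
```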
diff --git a/autogen/agentchat/contrib/society_of_mind_agent.py b/autogen/agentchat/contrib/society_of_mind_agent.py
index 97cf6aee1a5..e76768187c9 100644
--- a/autogen/agentchat/contrib/society_of_mind_agent.py
+++ b/autogen/agentchat/contrib/society_of_mind_agent.py
@@ -1,8 +1,6 @@
# ruff: noqa: E722
import copy
-import json
import traceback
-from dataclasses import dataclass
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from autogen import Agent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper
@@ -36,11 +34,12 @@ def __init__(
response_preparer: Optional[Union[str, Callable]] = None,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "TERMINATE",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = False,
default_auto_reply: Optional[Union[str, Dict, None]] = "",
+ **kwargs,
):
super().__init__(
name=name,
@@ -52,6 +51,7 @@ def __init__(
code_execution_config=code_execution_config,
llm_config=llm_config,
default_auto_reply=default_auto_reply,
+ **kwargs,
)
self.update_chat_manager(chat_manager)
diff --git a/autogen/agentchat/contrib/text_analyzer_agent.py b/autogen/agentchat/contrib/text_analyzer_agent.py
index e917cca574f..62345156a53 100644
--- a/autogen/agentchat/contrib/text_analyzer_agent.py
+++ b/autogen/agentchat/contrib/text_analyzer_agent.py
@@ -1,6 +1,5 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
-from autogen import oai
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
@@ -17,7 +16,7 @@ def __init__(
self,
name="analyzer",
system_message: Optional[str] = system_message,
- human_input_mode: Optional[str] = "NEVER",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
llm_config: Optional[Union[Dict, bool]] = None,
**kwargs,
):
diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py
index 187d0d6acbb..d7d49d6200c 100644
--- a/autogen/agentchat/contrib/vectordb/base.py
+++ b/autogen/agentchat/contrib/vectordb/base.py
@@ -1,4 +1,16 @@
-from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable
+from typing import (
+ Any,
+ Callable,
+ List,
+ Mapping,
+ Optional,
+ Protocol,
+ Sequence,
+ Tuple,
+ TypedDict,
+ Union,
+ runtime_checkable,
+)
Metadata = Union[Mapping[str, Any], None]
Vector = Union[Sequence[float], Sequence[int]]
@@ -49,6 +61,9 @@ class VectorDB(Protocol):
active_collection: Any = None
type: str = ""
+ embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
+ None # embeddings = embedding_function(sentences)
+ )
def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
"""
@@ -171,7 +186,8 @@ def get_docs_by_ids(
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include. Default is None.
- If None, will include ["metadatas", "documents"], ids will always be included.
+ If None, will include ["metadatas", "documents"], ids will always be included. This may differ
+ depending on the implementation.
kwargs: dict | Additional keyword arguments.
Returns:
@@ -185,7 +201,7 @@ class VectorDBFactory:
Factory class for creating vector databases.
"""
- PREDEFINED_VECTOR_DB = ["chroma"]
+ PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"]
@staticmethod
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -203,6 +219,18 @@ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
from .chromadb import ChromaVectorDB
return ChromaVectorDB(**kwargs)
+ if db_type.lower() in ["pgvector", "pgvectordb"]:
+ from .pgvectordb import PGVectorDB
+
+ return PGVectorDB(**kwargs)
+ if db_type.lower() in ["mdb", "mongodb", "atlas"]:
+ from .mongodb import MongoDBAtlasVectorDB
+
+ return MongoDBAtlasVectorDB(**kwargs)
+ if db_type.lower() in ["qdrant", "qdrantdb"]:
+ from .qdrant import QdrantVectorDB
+
+ return QdrantVectorDB(**kwargs)
else:
raise ValueError(
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py
index 6e571d58abc..1ed8708409d 100644
--- a/autogen/agentchat/contrib/vectordb/chromadb.py
+++ b/autogen/agentchat/contrib/vectordb/chromadb.py
@@ -24,7 +24,7 @@ class ChromaVectorDB(VectorDB):
"""
def __init__(
- self, *, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs
+ self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs
) -> None:
"""
Initialize the vector database.
@@ -32,7 +32,7 @@ def __init__(
Args:
client: chromadb.Client | The client object of the vector database. Default is None.
If provided, it will use the client object directly and ignore other arguments.
- path: str | The path to the vector database. Default is None.
+ path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24.
embedding_function: Callable | The embedding function used to generate the vector representation
of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
@@ -83,7 +83,7 @@ def create_collection(
if self.active_collection and self.active_collection.name == collection_name:
collection = self.active_collection
else:
- collection = self.client.get_collection(collection_name)
+ collection = self.client.get_collection(collection_name, embedding_function=self.embedding_function)
except ValueError:
collection = None
if collection is None:
@@ -126,7 +126,9 @@ def get_collection(self, collection_name: str = None) -> Collection:
)
else:
if not (self.active_collection and self.active_collection.name == collection_name):
- self.active_collection = self.client.get_collection(collection_name)
+ self.active_collection = self.client.get_collection(
+ collection_name, embedding_function=self.embedding_function
+ )
return self.active_collection
def delete_collection(self, collection_name: str) -> None:
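Because `get_collection` now forwards `self.embedding_function`, a collection created with a custom embedding function is reopened with that same function instead of Chroma's default. A hedged sketch using the SentenceTransformer embedding function mentioned in the docstring (model name is illustrative):

```python
# Sketch: ChromaVectorDB with an explicit embedding function; get_collection now
# reopens the collection with the same function rather than Chroma's default.
from chromadb.utils import embedding_functions

from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB

ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

db = ChromaVectorDB(path="tmp/db", embedding_function=ef)
db.create_collection("autogen-docs", get_or_create=True)
collection = db.get_collection("autogen-docs")  # reopened with the same embedding function
```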
diff --git a/autogen/agentchat/contrib/vectordb/mongodb.py b/autogen/agentchat/contrib/vectordb/mongodb.py
new file mode 100644
index 00000000000..2e0580fe826
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/mongodb.py
@@ -0,0 +1,553 @@
+from copy import deepcopy
+from time import monotonic, sleep
+from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union
+
+import numpy as np
+from pymongo import MongoClient, UpdateOne, errors
+from pymongo.collection import Collection
+from pymongo.driver_info import DriverInfo
+from pymongo.operations import SearchIndexModel
+from sentence_transformers import SentenceTransformer
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+logger = get_logger(__name__)
+
+DEFAULT_INSERT_BATCH_SIZE = 100_000
+_SAMPLE_SENTENCE = ["The weather is lovely today in paradise."]
+_DELAY = 0.5
+
+
+def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]:
+ """Utility changes _id field from Collection into id for Document."""
+ return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs]
+
+
+class MongoDBAtlasVectorDB(VectorDB):
+ """
+ A Collection object for MongoDB.
+ """
+
+ def __init__(
+ self,
+ connection_string: str = "",
+ database_name: str = "vector_db",
+ embedding_function: Callable = SentenceTransformer("all-MiniLM-L6-v2").encode,
+ collection_name: str = None,
+ index_name: str = "vector_index",
+ overwrite: bool = False,
+ wait_until_index_ready: float = None,
+ wait_until_document_ready: float = None,
+ ):
+ """
+ Initialize the vector database.
+
+ Args:
+ connection_string: str | The MongoDB connection string to connect to. Default is ''.
+ database_name: str | The name of the database. Default is 'vector_db'.
+ embedding_function: Callable | The embedding function used to generate the vector representation.
+ collection_name: str | The name of the collection to create for this vector database
+ Defaults to None
+ index_name: str | Index name for the vector database, defaults to 'vector_index'
+ overwrite: bool | Whether to overwrite an existing collection with the same name. Default is False.
+ wait_until_index_ready: float | None | Blocking call to wait until the
+ database indexes are ready. None, the default, means no wait.
+ wait_until_document_ready: float | None | Blocking call to wait until the
+ inserted documents are queryable. None, the default, means no wait.
+ """
+ self.embedding_function = embedding_function
+ self.index_name = index_name
+ self._wait_until_index_ready = wait_until_index_ready
+ self._wait_until_document_ready = wait_until_document_ready
+
+ # This will get the model dimension size by computing the embeddings dimensions
+ self.dimensions = self._get_embedding_size()
+
+ try:
+ self.client = MongoClient(connection_string, driver=DriverInfo(name="autogen"))
+ self.client.admin.command("ping")
+ logger.debug("Successfully created MongoClient")
+ except errors.ServerSelectionTimeoutError as err:
+ raise ConnectionError("Could not connect to MongoDB server") from err
+
+ self.db = self.client[database_name]
+ logger.debug(f"Atlas Database name: {self.db.name}")
+ if collection_name:
+ self.active_collection = self.create_collection(collection_name, overwrite)
+ else:
+ self.active_collection = None
+
+ def _is_index_ready(self, collection: Collection, index_name: str):
+ """Check for the index name in the list of available search indexes to see if the
+ specified index is of status READY
+
+ Args:
+            collection (Collection): MongoDB Collection to check for the search index
+ index_name (str): Vector Search Index name
+
+ Returns:
+            bool : True if the index is present and READY, False otherwise
+ """
+ for index in collection.list_search_indexes(index_name):
+ if index["type"] == "vectorSearch" and index["status"] == "READY":
+ return True
+ return False
+
+ def _wait_for_index(self, collection: Collection, index_name: str, action: str = "create"):
+ """Waits for the index action to be completed. Otherwise throws a TimeoutError.
+
+ Timeout set on instantiation.
+ action: "create" or "delete"
+ """
+ assert action in ["create", "delete"], f"{action=} must be create or delete."
+ start = monotonic()
+ while monotonic() - start < self._wait_until_index_ready:
+ if action == "create" and self._is_index_ready(collection, index_name):
+ return
+ elif action == "delete" and len(list(collection.list_search_indexes())) == 0:
+ return
+ sleep(_DELAY)
+
+ raise TimeoutError(f"Index {self.index_name} is not ready!")
+
+ def _wait_for_document(self, collection: Collection, index_name: str, doc: Document):
+ start = monotonic()
+ while monotonic() - start < self._wait_until_document_ready:
+ query_result = _vector_search(
+ embedding_vector=np.array(self.embedding_function(doc["content"])).tolist(),
+ n_results=1,
+ collection=collection,
+ index_name=index_name,
+ )
+ if query_result and query_result[0][0]["_id"] == doc["id"]:
+ return
+ sleep(_DELAY)
+
+        raise TimeoutError(f"Document with id {doc['id']} is not ready!")
+
+ def _get_embedding_size(self):
+ return len(self.embedding_function(_SAMPLE_SENTENCE)[0])
+
+ def list_collections(self):
+ """
+ List the collections in the vector database.
+
+ Returns:
+ List[str] | The list of collections.
+ """
+ return self.db.list_collection_names()
+
+ def create_collection(
+ self,
+ collection_name: str,
+ overwrite: bool = False,
+ get_or_create: bool = True,
+ ) -> Collection:
+ """
+ Create a collection in the vector database and create a vector search index in the collection.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get or create the collection. Default is True
+ """
+ if overwrite:
+ self.delete_collection(collection_name)
+
+ if collection_name not in self.db.list_collection_names():
+ # Create a new collection
+ coll = self.db.create_collection(collection_name)
+ self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
+ return coll
+
+ if get_or_create:
+ # The collection already exists, return it.
+ coll = self.db[collection_name]
+ self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
+ return coll
+ else:
+ # get_or_create is False and the collection already exists, raise an error.
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def create_index_if_not_exists(self, index_name: str = "vector_index", collection: Collection = None) -> None:
+ """
+ Creates a vector search index on the specified collection in MongoDB.
+
+ Args:
+            index_name (str, optional): The name of the vector search index to create. Defaults to "vector_index".
+ collection (Collection, optional): The MongoDB collection to create the index on. Defaults to None.
+ """
+ if not self._is_index_ready(collection, index_name):
+ self.create_vector_search_index(collection, index_name)
+
+ def get_collection(self, collection_name: str = None) -> Collection:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Collection | The collection object.
+ """
+ if collection_name is None:
+ if self.active_collection is None:
+ raise ValueError("No collection is specified.")
+ else:
+ logger.debug(
+ f"No collection is specified. Using current active collection {self.active_collection.name}."
+ )
+ else:
+ self.active_collection = self.db[collection_name]
+
+ return self.active_collection
+
+ def delete_collection(self, collection_name: str) -> None:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+ """
+ for index in self.db[collection_name].list_search_indexes():
+ self.db[collection_name].drop_search_index(index["name"])
+ if self._wait_until_index_ready:
+ self._wait_for_index(self.db[collection_name], index["name"], "delete")
+ return self.db[collection_name].drop()
+
+ def create_vector_search_index(
+ self,
+ collection: Collection,
+ index_name: Union[str, None] = "vector_index",
+ similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
+ ) -> None:
+ """Create a vector search index in the collection.
+
+ Args:
+ collection: An existing Collection in the Atlas Database.
+ index_name: Vector Search Index name.
+ similarity: Algorithm used for measuring vector similarity.
+ kwargs: Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ search_index_model = SearchIndexModel(
+ definition={
+ "fields": [
+ {
+ "type": "vector",
+ "numDimensions": self.dimensions,
+ "path": "embedding",
+ "similarity": similarity,
+ },
+ ]
+ },
+ name=index_name,
+ type="vectorSearch",
+ )
+ # Create the search index
+ try:
+ collection.create_search_index(model=search_index_model)
+ if self._wait_until_index_ready:
+ self._wait_for_index(collection, index_name, "create")
+ logger.debug(f"Search index {index_name} created successfully.")
+ except Exception as e:
+ logger.error(
+ f"Error creating search index: {e}. \n"
+ f"Your client must be connected to an Atlas cluster. "
+ f"You may have to manually create a Collection and Search Index "
+ f"if you are on a free/shared cluster."
+ )
+ raise e
+
+ def insert_docs(
+ self,
+ docs: List[Document],
+ collection_name: str = None,
+ upsert: bool = False,
+ batch_size=DEFAULT_INSERT_BATCH_SIZE,
+ **kwargs,
+ ) -> None:
+ """Insert Documents and Vector Embeddings into the collection of the vector database.
+
+ For large numbers of Documents, insertion is performed in batches.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ batch_size: Number of documents to be inserted in each batch
+ """
+ if not docs:
+ logger.info("No documents to insert.")
+ return
+
+ collection = self.get_collection(collection_name)
+ if upsert:
+ self.update_docs(docs, collection.name, upsert=True)
+ else:
+ # Sanity checking the first document
+ if docs[0].get("content") is None:
+ raise ValueError("The document content is required.")
+ if docs[0].get("id") is None:
+ raise ValueError("The document id is required.")
+
+ input_ids = set()
+ result_ids = set()
+ id_batch = []
+ text_batch = []
+ metadata_batch = []
+ size = 0
+ i = 0
+ for doc in docs:
+ id = doc["id"]
+ text = doc["content"]
+ metadata = doc.get("metadata", {})
+ id_batch.append(id)
+ text_batch.append(text)
+ metadata_batch.append(metadata)
+ id_size = 1 if isinstance(id, int) else len(id)
+ size += len(text) + len(metadata) + id_size
+ if (i + 1) % batch_size == 0 or size >= 47_000_000:
+ result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
+ input_ids.update(id_batch)
+ id_batch = []
+ text_batch = []
+ metadata_batch = []
+ size = 0
+ i += 1
+ if text_batch:
+ result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch)) # type: ignore
+ input_ids.update(id_batch)
+
+ if result_ids != input_ids:
+ logger.warning(
+ "Possible data corruption. "
+ "input_ids not in result_ids: {in_diff}.\n"
+ "result_ids not in input_ids: {out_diff}".format(
+ in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids)
+ )
+ )
+ if self._wait_until_document_ready and docs:
+ self._wait_for_document(collection, self.index_name, docs[-1])
+
+ def _insert_batch(
+ self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID]
+ ) -> Set[ItemID]:
+ """Compute embeddings for and insert a batch of Documents into the Collection.
+
+        For performance reasons, self.embedding_function is called just once,
+        at the small cost of rebuilding the Document dicts afterwards.
+
+ Args:
+ collection: MongoDB Collection
+ texts: List of the main contents of each document
+ metadatas: List of metadata mappings
+ ids: List of ids. Note that these are stored as _id in Collection.
+
+ Returns:
+ List of ids inserted.
+ """
+ n_texts = len(texts)
+ if n_texts == 0:
+ return []
+ # Embed and create the documents
+ embeddings = self.embedding_function(texts).tolist()
+ assert (
+ len(embeddings) == n_texts
+        ), f"The number of embeddings produced by self.embedding_function ({len(embeddings)}) does not match the number of texts provided to it ({n_texts})."
+ to_insert = [
+ {"_id": i, "content": t, "metadata": m, "embedding": e}
+ for i, t, m, e in zip(ids, texts, metadatas, embeddings)
+ ]
+ # insert the documents in MongoDB Atlas
+ insert_result = collection.insert_many(to_insert) # type: ignore
+ return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs
+
+ def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
+ """Update documents, including their embeddings, in the Collection.
+
+ Optionally allow upsert as kwarg.
+
+ Uses deepcopy to avoid changing docs.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+            kwargs: Any | Use `upsert=True` to insert documents whose ids are not present in the collection.
+ """
+
+ n_docs = len(docs)
+ logger.info(f"Preparing to embed and update {n_docs=}")
+ # Compute the embeddings
+ embeddings: list[list[float]] = self.embedding_function([doc["content"] for doc in docs]).tolist()
+ # Prepare the updates
+ all_updates = []
+ for i in range(n_docs):
+ doc = deepcopy(docs[i])
+ doc["embedding"] = embeddings[i]
+ doc["_id"] = doc.pop("id")
+
+ all_updates.append(UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=kwargs.get("upsert", False)))
+ # Perform update in bulk
+ collection = self.get_collection(collection_name)
+ result = collection.bulk_write(all_updates)
+
+ if self._wait_until_document_ready and docs:
+ self._wait_for_document(collection, self.index_name, docs[-1])
+
+ # Log a result summary
+ logger.info(
+ "Matched: %s, Modified: %s, Upserted: %s",
+ result.matched_count,
+ result.modified_count,
+ result.upserted_count,
+ )
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs):
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ """
+ collection = self.get_collection(collection_name)
+ return collection.delete_many({"_id": {"$in": ids}})
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include: List[str] = None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+            include: List[str] | The fields to include.
+                If None, will include ["metadata", "content"]; ids are always included.
+                Use include to choose whether to return the embedding and metadata fields.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ if include is None:
+ include_fields = {"_id": 1, "content": 1, "metadata": 1}
+ else:
+ include_fields = {k: 1 for k in set(include).union({"_id"})}
+ collection = self.get_collection(collection_name)
+ if ids is not None:
+ docs = collection.find({"_id": {"$in": ids}}, include_fields)
+ # Return with _id field from Collection into id for Document
+ return with_id_rename(docs)
+ else:
+ docs = collection.find({}, include_fields)
+ # Return with _id field from Collection into id for Document
+ return with_id_rename(docs)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
+ returned. Don't filter with it if < 0. Default is -1.
+ kwargs: Dict | Additional keyword arguments. Ones of importance follow:
+ oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
+ It determines the number of nearest neighbor candidates to consider during the search phase.
+ A higher value leads to more accuracy, but is slower. Default is 10
+
+ Returns:
+ QueryResults | For each query string, a list of nearest documents and their scores.
+ """
+ collection = self.get_collection(collection_name)
+ # Trivial case of an empty collection
+ if collection.count_documents({}) == 0:
+ return []
+
+ logger.debug(f"Using index: {self.index_name}")
+ results = []
+ for query_text in queries:
+ # Compute embedding vector from semantic query
+ logger.debug(f"Query: {query_text}")
+ query_vector = np.array(self.embedding_function([query_text])).tolist()[0]
+ # Find documents with similar vectors using the specified index
+            query_result = _vector_search(
+                query_vector,
+                n_results,
+                collection,
+                self.index_name,
+                distance_threshold,
+                oversampling_factor=kwargs.get("oversampling_factor", 10),
+                include_embedding=kwargs.get("include_embedding", False),
+            )
+ # Change each _id key to id. with_id_rename, but with (doc, score) tuples
+ results.append(
+ [({**{k: v for k, v in d[0].items() if k != "_id"}, "id": d[0]["_id"]}, d[1]) for d in query_result]
+ )
+ return results
+
+
+def _vector_search(
+ embedding_vector: List[float],
+ n_results: int,
+ collection: Collection,
+ index_name: str,
+ distance_threshold: float = -1.0,
+ oversampling_factor=10,
+ include_embedding=False,
+) -> List[Tuple[Dict, float]]:
+ """Core $vectorSearch Aggregation pipeline.
+
+ Args:
+ embedding_vector: Embedding vector of semantic query
+        n_results: Number of documents to return.
+ collection: MongoDB Collection with vector index
+ index_name: Name of the vector index
+ distance_threshold: Only distance measures smaller than this will be returned.
+            Don't filter with it if < 0. Default is -1.
+ oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
+ It determines the number of nearest neighbor candidates to consider during the search phase.
+ A higher value leads to more accuracy, but is slower. Default = 10
+
+ Returns:
+ List of tuples of length n_results from Collection.
+ Each tuple contains a document dict and a score.
+ """
+
+ pipeline = [
+ {
+ "$vectorSearch": {
+ "index": index_name,
+ "limit": n_results,
+ "numCandidates": n_results * oversampling_factor,
+ "queryVector": embedding_vector,
+ "path": "embedding",
+ }
+ },
+ {"$set": {"score": {"$meta": "vectorSearchScore"}}},
+ ]
+ if distance_threshold >= 0.0:
+ similarity_threshold = 1.0 - distance_threshold
+ pipeline.append({"$match": {"score": {"$gte": similarity_threshold}}})
+
+ if not include_embedding:
+ pipeline.append({"$project": {"embedding": 0}})
+
+ logger.debug("pipeline: %s", pipeline)
+ agg = collection.aggregate(pipeline)
+ return [(doc, doc.pop("score")) for doc in agg]
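+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the library API: it assumes a reachable Atlas
+    # cluster and uses a placeholder connection string; adjust both to your own setup.
+    db = MongoDBAtlasVectorDB(
+        connection_string="mongodb+srv://<user>:<password>@<cluster>/?retryWrites=true&w=majority",
+        database_name="vector_db",
+        collection_name="demo_docs",
+        wait_until_index_ready=120,
+        wait_until_document_ready=120,
+    )
+    db.insert_docs([{"id": "doc1", "content": "AutoGen enables multi-agent LLM workflows."}])
+    print(db.retrieve_docs(["multi-agent workflows"], collection_name="demo_docs", n_results=1))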
diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py
new file mode 100644
index 00000000000..ac86802b672
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py
@@ -0,0 +1,952 @@
+import os
+import re
+import urllib.parse
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+try:
+ import pgvector
+ from pgvector.psycopg import register_vector
+except ImportError:
+ raise ImportError("Please install pgvector: `pip install pgvector`")
+
+try:
+ import psycopg
+except ImportError:
+    raise ImportError("Please install psycopg: `pip install psycopg`")
+
+PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000)
+logger = get_logger(__name__)
+
+
+class Collection:
+ """
+ A Collection object for PGVector.
+
+ Attributes:
+ client: The PGVector client.
+        collection_name (str): The name of the collection. Default is "autogen-docs".
+ embedding_function (Callable): The embedding function used to generate the vector representation.
+ Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
+ Models can be chosen from:
+ https://huggingface.co/models?library=sentence-transformers
+ metadata (Optional[dict]): The metadata of the collection.
+ get_or_create (Optional): The flag indicating whether to get or create the collection.
+ """
+
+ def __init__(
+ self,
+ client=None,
+ collection_name: str = "autogen-docs",
+ embedding_function: Callable = None,
+ metadata=None,
+ get_or_create=None,
+ ):
+ """
+ Initialize the Collection object.
+
+ Args:
+ client: The PostgreSQL client.
+            collection_name: The name of the collection. Default is "autogen-docs".
+ embedding_function: The embedding function used to generate the vector representation.
+ metadata: The metadata of the collection.
+ get_or_create: The flag indicating whether to get or create the collection.
+ Returns:
+ None
+ """
+ self.client = client
+ self.name = self.set_collection_name(collection_name)
+ self.require_embeddings_or_documents = False
+ self.ids = []
+ if embedding_function:
+ self.embedding_function = embedding_function
+ else:
+ self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
+ self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
+ self.documents = ""
+ self.get_or_create = get_or_create
+ # This will get the model dimension size by computing the embeddings dimensions
+ sentences = [
+ "The weather is lovely today in paradise.",
+ ]
+ embeddings = self.embedding_function(sentences)
+ self.dimension = len(embeddings[0])
+
+ def set_collection_name(self, collection_name) -> str:
+ name = re.sub("-", "_", collection_name)
+ self.name = name
+ return self.name
+
+ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
+ """
+ Add documents to the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs.
+ embeddings (List): A list of document embeddings. Optional
+ metadatas (List): A list of document metadatas. Optional
+ documents (List): A list of documents.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ if embeddings is not None and metadatas is not None:
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ sql_values.append((doc_id, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
+ )
+ elif embeddings is not None:
+ for doc_id, embedding, document in zip(ids, embeddings, documents):
+ sql_values.append((doc_id, embedding, document))
+ sql_string = f"INSERT INTO {self.name} (id, embedding, documents) " f"VALUES (%s, %s, %s);\n"
+ elif metadatas is not None:
+ for doc_id, metadata, document in zip(ids, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, metadata, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
+ )
+ else:
+ for doc_id, document in zip(ids, documents):
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, document, embedding))
+ sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n"
+ logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
+ """
+ Upsert documents into the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs.
+ documents (List): A list of documents.
+ embeddings (List): A list of document embeddings.
+ metadatas (List): A list of document metadatas.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ if embeddings is not None and metadatas is not None:
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ sql_values.append((doc_id, embedding, metadata, document, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n"
+ f"VALUES (%s, %s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET embedding = %s,\n"
+ f"metadatas = %s, documents = %s;\n"
+ )
+ elif embeddings is not None:
+ for doc_id, embedding, document in zip(ids, embeddings, documents):
+ sql_values.append((doc_id, embedding, document, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, documents) "
+ f"VALUES (%s, %s, %s) ON CONFLICT (id)\n"
+ f"DO UPDATE SET embedding = %s, documents = %s;\n"
+ )
+ elif metadatas is not None:
+ for doc_id, metadata, document in zip(ids, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n"
+ f"VALUES (%s, %s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET metadatas = %s, documents = %s, embedding = %s;\n"
+ )
+ else:
+ for doc_id, document in zip(ids, documents):
+ embedding = self.embedding_function(document)
+ sql_values.append((doc_id, document, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, documents, embedding)\n"
+ f"VALUES (%s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET documents = %s;\n"
+ )
+ logger.debug(f"Upsert SQL String:\n{sql_string}\n{sql_values}")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ def count(self) -> int:
+ """
+ Get the total number of documents in the collection.
+
+ Returns:
+ int: The total number of documents.
+ """
+ cursor = self.client.cursor()
+ query = f"SELECT COUNT(*) FROM {self.name}"
+ cursor.execute(query)
+ total = cursor.fetchone()[0]
+ cursor.close()
+ try:
+ total = int(total)
+ except (TypeError, ValueError):
+ total = None
+ return total
+
+ def table_exists(self, table_name: str) -> bool:
+ """
+ Check if a table exists in the PostgreSQL database.
+
+ Args:
+ table_name (str): The name of the table to check.
+
+ Returns:
+ bool: True if the table exists, False otherwise.
+ """
+
+ cursor = self.client.cursor()
+ cursor.execute(
+ """
+ SELECT EXISTS (
+ SELECT 1
+ FROM information_schema.tables
+ WHERE table_name = %s
+ )
+ """,
+ (table_name,),
+ )
+        exists = cursor.fetchone()[0]
+        cursor.close()
+        return exists
+
+ def get(
+ self,
+        ids: Optional[List[ItemID]] = None,
+        include: Optional[List[str]] = None,
+ where: Optional[str] = None,
+ limit: Optional[Union[int, str]] = None,
+ offset: Optional[Union[int, str]] = None,
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection.
+
+ Args:
+ ids (Optional[List]): A list of document IDs.
+ include (Optional): The fields to include.
+ where (Optional): Additional filtering criteria.
+ limit (Optional): The maximum number of documents to retrieve.
+ offset (Optional): The offset for pagination.
+
+ Returns:
+ List: The retrieved documents.
+ """
+ cursor = self.client.cursor()
+
+ # Initialize variables for query components
+ select_clause = "SELECT id, metadatas, documents, embedding"
+ from_clause = f"FROM {self.name}"
+ where_clause = ""
+ limit_clause = ""
+ offset_clause = ""
+
+ # Handle include clause
+ if include:
+ select_clause = f"SELECT id, {', '.join(include)}, embedding"
+
+ # Handle where clause
+ if ids:
+ where_clause = f"WHERE id IN ({', '.join(['%s' for _ in ids])})"
+ elif where:
+ where_clause = f"WHERE {where}"
+
+ # Handle limit and offset clauses
+ if limit:
+ limit_clause = "LIMIT %s"
+ if offset:
+ offset_clause = "OFFSET %s"
+
+ # Construct the full query
+ query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"
+ retrieved_documents = []
+        try:
+            # Execute the query, combining id placeholders with limit/offset parameters
+            query_params = list(ids) if ids is not None else []
+            if limit:
+                query_params.append(limit)
+            if offset:
+                query_params.append(offset)
+            cursor.execute(query, query_params)
+
+ retrieval = cursor.fetchall()
+ for retrieved_document in retrieval:
+ retrieved_documents.append(
+ Document(
+ id=retrieved_document[0].strip(),
+ metadata=retrieved_document[1],
+ content=retrieved_document[2],
+ embedding=retrieved_document[3],
+ )
+ )
+ except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
+ logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
+ self.create_collection(collection_name=self.name, dimension=self.dimension)
+ logger.info(f"Created table {self.name}")
+
+ cursor.close()
+ return retrieved_documents
+
+ def update(self, ids: List, embeddings: List, metadatas: List, documents: List) -> None:
+ """
+ Update documents in the collection.
+
+ Args:
+ ids (List): A list of document IDs.
+ embeddings (List): A list of document embeddings.
+ metadatas (List): A list of document metadatas.
+ documents (List): A list of documents.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ sql_values.append((doc_id, embedding, metadata, document, doc_id, embedding, metadata, document))
+        sql_string = (
+            f"INSERT INTO {self.name} (id, embedding, metadatas, documents) "
+            f"VALUES (%s, %s, %s, %s) "
+            f"ON CONFLICT (id) "
+            f"DO UPDATE SET id = %s, embedding = %s, "
+            f"metadatas = %s, documents = %s;\n"
+        )
+        logger.debug(f"Update SQL String:\n{sql_string}\n")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ @staticmethod
+ def euclidean_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+ Calculate the Euclidean distance between two vectors.
+
+ Parameters:
+ - arr1 (List[float]): The first vector.
+ - arr2 (List[float]): The second vector.
+
+ Returns:
+ - float: The Euclidean distance between arr1 and arr2.
+ """
+ dist = np.linalg.norm(arr1 - arr2)
+ return dist
+
+ @staticmethod
+ def cosine_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+ Calculate the cosine distance between two vectors.
+
+ Parameters:
+ - arr1 (List[float]): The first vector.
+ - arr2 (List[float]): The second vector.
+
+ Returns:
+ - float: The cosine distance between arr1 and arr2.
+ """
+        # Cosine distance is 1 - cosine similarity
+        dist = 1 - np.dot(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))
+ return dist
+
+    @staticmethod
+    def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
+        """
+        Calculate the inner-product (dot-product) distance between two vectors.
+
+        Parameters:
+        - arr1 (List[float]): The first vector.
+        - arr2 (List[float]): The second vector.
+
+        Returns:
+        - float: The inner product of arr1 and arr2.
+        """
+        dist = np.dot(arr1, arr2)
+        return dist
+
+ def query(
+ self,
+ query_texts: List[str],
+ collection_name: Optional[str] = None,
+ n_results: Optional[int] = 10,
+ distance_type: Optional[str] = "euclidean",
+ distance_threshold: Optional[float] = -1,
+ include_embedding: Optional[bool] = False,
+ ) -> QueryResults:
+ """
+ Query documents in the collection.
+
+ Args:
+ query_texts (List[str]): A list of query texts.
+ collection_name (Optional[str]): The name of the collection.
+ n_results (int): The maximum number of results to return.
+ distance_type (Optional[str]): Distance search type - euclidean or cosine
+ distance_threshold (Optional[float]): Distance threshold to limit searches
+ include_embedding (Optional[bool]): Include embedding values in QueryResults
+ Returns:
+ QueryResults: The query results.
+ """
+ if collection_name:
+ self.name = collection_name
+
+        if distance_threshold > 0:
+            distance_threshold = f"< {distance_threshold}"
+            clause = "WHERE"
+        else:
+            # No distance filtering for non-positive thresholds; order by distance instead
+            distance_threshold = ""
+            clause = "ORDER BY"
+
+ cursor = self.client.cursor()
+ results = []
+ for query_text in query_texts:
+ vector = self.embedding_function(query_text, convert_to_tensor=False).tolist()
+ if distance_type.lower() == "cosine":
+ index_function = "<=>"
+ elif distance_type.lower() == "euclidean":
+ index_function = "<->"
+ elif distance_type.lower() == "inner-product":
+ index_function = "<#>"
+ else:
+ index_function = "<->"
+ query = (
+ f"SELECT id, documents, embedding, metadatas "
+ f"FROM {self.name} "
+ f"{clause} embedding {index_function} '{str(vector)}' {distance_threshold} "
+ f"LIMIT {n_results}"
+ )
+ cursor.execute(query)
+ result = []
+ for row in cursor.fetchall():
+ fetched_document = Document(id=row[0].strip(), content=row[1], embedding=row[2], metadata=row[3])
+ fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
+ if distance_type.lower() == "cosine":
+ distance = self.cosine_distance(fetched_document_array, vector)
+ elif distance_type.lower() == "euclidean":
+ distance = self.euclidean_distance(fetched_document_array, vector)
+ elif distance_type.lower() == "inner-product":
+ distance = self.inner_product_distance(fetched_document_array, vector)
+ else:
+ distance = self.euclidean_distance(fetched_document_array, vector)
+ if not include_embedding:
+ fetched_document = Document(id=row[0].strip(), content=row[1], metadata=row[3])
+ result.append((fetched_document, distance))
+ results.append(result)
+ cursor.close()
+ logger.debug(f"Query Results: {results}")
+ return results
+
+ @staticmethod
+ def convert_string_to_array(array_string: str) -> List[float]:
+ """
+ Convert a string representation of an array to a list of floats.
+
+ Parameters:
+ - array_string (str): The string representation of the array.
+
+ Returns:
+ - list: A list of floats parsed from the input string. If the input is
+ not a string, it returns the input itself.
+ """
+ if not isinstance(array_string, str):
+ return array_string
+ array_string = array_string.strip("[]")
+ array = [float(num) for num in array_string.split()]
+ return array
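+    # Illustrative example (not executed): convert_string_to_array("[0.1 0.2 0.3]")
+    # returns [0.1, 0.2, 0.3]; non-string inputs are passed through unchanged.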
+
+ def modify(self, metadata, collection_name: Optional[str] = None) -> None:
+ """
+ Modify metadata for the collection.
+
+ Args:
+ collection_name: The name of the collection.
+ metadata: The new metadata.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ cursor.execute(
+            "UPDATE collections SET metadata = %s WHERE collection_name = %s;", (metadata, self.name)
+ )
+ cursor.close()
+
+ def delete(self, ids: List[ItemID], collection_name: Optional[str] = None) -> None:
+ """
+ Delete documents from the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs to delete.
+ collection_name (str): The name of the collection to delete.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ id_placeholders = ", ".join(["%s" for _ in ids])
+ cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
+ cursor.close()
+
+ def delete_collection(self, collection_name: Optional[str] = None) -> None:
+ """
+ Delete the entire collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the collection to delete.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
+ cursor.close()
+
+ def create_collection(
+ self, collection_name: Optional[str] = None, dimension: Optional[Union[str, int]] = None
+ ) -> None:
+ """
+ Create a new collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the new collection.
+ dimension (Optional[Union[str, int]]): The dimension size of the sentence embedding model
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+
+ if dimension:
+ self.dimension = dimension
+ elif self.dimension is None:
+ self.dimension = 384
+
+ cursor = self.client.cursor()
+ cursor.execute(
+ f"CREATE TABLE {self.name} ("
+ f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({self.dimension}));"
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_cosine_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_ip_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ )
+ cursor.close()
+
+
+class PGVectorDB(VectorDB):
+ """
+ A vector database that uses PGVector as the backend.
+ """
+
+ def __init__(
+ self,
+ *,
+ conn: Optional[psycopg.Connection] = None,
+ connection_string: Optional[str] = None,
+ host: Optional[str] = None,
+ port: Optional[Union[int, str]] = None,
+ dbname: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ connect_timeout: Optional[int] = 10,
+ embedding_function: Callable = None,
+ metadata: Optional[dict] = None,
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Note: connection_string or host + port + dbname must be specified
+
+ Args:
+ conn: psycopg.Connection | A customer connection object to connect to the database.
+ A connection object may include additional key/values:
+ https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
+ connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
+ host: str | The host to connect to. Default is None.
+ port: int | The port to connect to. Default is None.
+ dbname: str | The database name to connect to. Default is None.
+ username: str | The database username to use. Default is None.
+ password: str | The database user password to use. Default is None.
+ connect_timeout: int | The timeout to set for the connection. Default is 10.
+ embedding_function: Callable | The embedding function used to generate the vector representation.
+ Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
+ Models can be chosen from:
+ https://huggingface.co/models?library=sentence-transformers
+ metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
+                setting: {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}. Creates Index on table
+ using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
+ For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
+ Returns:
+ None
+ """
+ self.client = self.establish_connection(
+ conn=conn,
+ connection_string=connection_string,
+ host=host,
+ port=port,
+ dbname=dbname,
+ username=username,
+ password=password,
+ connect_timeout=connect_timeout,
+ )
+ if embedding_function:
+ self.embedding_function = embedding_function
+ else:
+ self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
+ self.metadata = metadata
+ register_vector(self.client)
+ self.active_collection = None
+
+ def establish_connection(
+ self,
+ conn: Optional[psycopg.Connection] = None,
+ connection_string: Optional[str] = None,
+ host: Optional[str] = None,
+ port: Optional[Union[int, str]] = None,
+ dbname: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ connect_timeout: Optional[int] = 10,
+ ) -> psycopg.Connection:
+ """
+ Establishes a connection to a PostgreSQL database using psycopg.
+
+ Args:
+ conn: An existing psycopg connection object. If provided, this connection will be used.
+ connection_string: A string containing the connection information. If provided, a new connection will be established using this string.
+ host: The hostname of the PostgreSQL server. Used if connection_string is not provided.
+ port: The port number to connect to at the server host. Used if connection_string is not provided.
+ dbname: The database name. Used if connection_string is not provided.
+ username: The username to connect as. Used if connection_string is not provided.
+ password: The user's password. Used if connection_string is not provided.
+ connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds.
+
+ Returns:
+ A psycopg.Connection object representing the established connection.
+
+ Raises:
+ PermissionError if no credentials are supplied
+ psycopg.Error: If an error occurs while trying to connect to the database.
+ """
+ try:
+ if conn:
+ self.client = conn
+ elif connection_string:
+ parsed_connection = urllib.parse.urlparse(connection_string)
+ encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
+ encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
+ encoded_password = f":{encoded_password}@"
+ encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
+ encoded_port = f":{parsed_connection.port}"
+ encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
+ connection_string_encoded = (
+ f"{parsed_connection.scheme}://{encoded_username}{encoded_password}"
+ f"{encoded_host}{encoded_port}/{encoded_database}"
+ )
+ self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
+ elif host:
+ connection_string = ""
+ if host:
+ encoded_host = urllib.parse.quote(host, safe="")
+ connection_string += f"host={encoded_host} "
+ if port:
+ connection_string += f"port={port} "
+ if dbname:
+ encoded_database = urllib.parse.quote(dbname, safe="")
+ connection_string += f"dbname={encoded_database} "
+ if username:
+ encoded_username = urllib.parse.quote(username, safe="")
+ connection_string += f"user={encoded_username} "
+ if password:
+ encoded_password = urllib.parse.quote(password, safe="")
+ connection_string += f"password={encoded_password} "
+
+ self.client = psycopg.connect(
+ conninfo=connection_string,
+ connect_timeout=connect_timeout,
+ autocommit=True,
+ )
+ else:
+ logger.error("Credentials were not supplied...")
+ raise PermissionError
+ self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
+ except psycopg.Error as e:
+            logger.error("Error connecting to the database: %s", e)
+ raise e
+ return self.client
+
+ def create_collection(
+ self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
+ ) -> Collection:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raise a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Collection | The collection object.
+ """
+ try:
+ if self.active_collection and self.active_collection.name == collection_name:
+ collection = self.active_collection
+ else:
+ collection = self.get_collection(collection_name)
+ except ValueError:
+ collection = None
+ if collection is None:
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ elif overwrite:
+ self.delete_collection(collection_name)
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ elif get_or_create:
+ return collection
+ elif not collection.table_exists(table_name=collection_name):
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ else:
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def get_collection(self, collection_name: str = None) -> Collection:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Collection | The collection object.
+ """
+ if collection_name is None:
+ if self.active_collection is None:
+ raise ValueError("No collection is specified.")
+ else:
+ logger.debug(
+ f"No collection is specified. Using current active collection {self.active_collection.name}."
+ )
+ else:
+ if not (self.active_collection and self.active_collection.name == collection_name):
+ self.active_collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ )
+ return self.active_collection
+
+ def delete_collection(self, collection_name: str) -> None:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ None
+ """
+ if self.active_collection:
+ self.active_collection.delete_collection(collection_name)
+ else:
+ collection = self.get_collection(collection_name)
+ collection.delete_collection(collection_name)
+ if self.active_collection and self.active_collection.name == collection_name:
+ self.active_collection = None
+
+ def _batch_insert(
+ self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
+ ) -> None:
+ batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
+ default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
+ default_metadatas = [default_metadata] * min(batch_size, len(documents))
+ for i in range(0, len(documents), min(batch_size, len(documents))):
+ end_idx = i + min(batch_size, len(documents) - i)
+ collection_kwargs = {
+ "documents": documents[i:end_idx],
+ "ids": ids[i:end_idx],
+ "metadatas": metadatas[i:end_idx] if metadatas else default_metadatas,
+ "embeddings": embeddings[i:end_idx] if embeddings else None,
+ }
+ if upsert:
+ collection.upsert(**collection_kwargs)
+ else:
+ collection.add(**collection_kwargs)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ if not docs:
+ return
+ if docs[0].get("content") is None:
+ raise ValueError("The document content is required.")
+ if docs[0].get("id") is None:
+ raise ValueError("The document id is required.")
+ documents = [doc.get("content") for doc in docs]
+ ids = [doc.get("id") for doc in docs]
+
+ collection = self.get_collection(collection_name)
+ if docs[0].get("embedding") is None:
+ logger.debug(
+ "No content embedding is provided. "
+ "Will use the VectorDB's embedding function to generate the content embedding."
+ )
+ embeddings = None
+ else:
+ embeddings = [doc.get("embedding") for doc in docs]
+ if docs[0].get("metadata") is None:
+ metadatas = None
+ else:
+ metadatas = [doc.get("metadata") for doc in docs]
+
+ self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ """
+ Update documents in the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+
+ Returns:
+ None
+ """
+ self.insert_docs(docs, collection_name, upsert=True)
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ collection = self.get_collection(collection_name)
+ collection.delete(ids=ids, collection_name=collection_name)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
+ returned. Don't filter with it if < 0. Default is -1.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ QueryResults | The query results. Each query result is a list of list of tuples containing the document and
+ the distance.
+ """
+ collection = self.get_collection(collection_name)
+ if isinstance(queries, str):
+ queries = [queries]
+ results = collection.query(
+ query_texts=queries,
+ n_results=n_results,
+ distance_threshold=distance_threshold,
+ )
+ logger.debug(f"Retrieve Docs Results:\n{results}")
+ return results
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include. Default is None.
+ If None, will include ["metadatas", "documents"], ids will always be included.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ collection = self.get_collection(collection_name)
+ include = include if include else ["metadatas", "documents"]
+ results = collection.get(ids, include=include, **kwargs)
+ logger.debug(f"Retrieve Documents by ID Results:\n{results}")
+ return results
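+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, illustrative only: it assumes a local PostgreSQL server with
+    # the pgvector extension installed and uses placeholder credentials.
+    db = PGVectorDB(connection_string="postgresql://postgres:postgres@localhost:5432/postgres")
+    db.create_collection("demo_docs", overwrite=True)
+    # Ids are stored in a CHAR(8) column, so 8-character ids are used here.
+    db.insert_docs(
+        [{"id": "doc00001", "content": "AutoGen enables multi-agent LLM workflows."}],
+        collection_name="demo_docs",
+    )
+    print(db.retrieve_docs(["multi-agent workflows"], collection_name="demo_docs", n_results=1))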
diff --git a/autogen/agentchat/contrib/vectordb/qdrant.py b/autogen/agentchat/contrib/vectordb/qdrant.py
new file mode 100644
index 00000000000..d9c4ee1d2e5
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/qdrant.py
@@ -0,0 +1,328 @@
+import abc
+import logging
+import os
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+try:
+ from qdrant_client import QdrantClient, models
+except ImportError:
+ raise ImportError("Please install qdrant-client: `pip install qdrant-client`")
+
+logger = get_logger(__name__)
+
+Embeddings = Union[Sequence[float], Sequence[int]]
+
+
+class EmbeddingFunction(abc.ABC):
+ @abc.abstractmethod
+ def __call__(self, inputs: List[str]) -> List[Embeddings]:
+ raise NotImplementedError
+
+
+class FastEmbedEmbeddingFunction(EmbeddingFunction):
+ """Embedding function implementation using FastEmbed - https://qdrant.github.io/fastembed."""
+
+ def __init__(
+ self,
+ model_name: str = "BAAI/bge-small-en-v1.5",
+ batch_size: int = 256,
+ cache_dir: Optional[str] = None,
+ threads: Optional[int] = None,
+ parallel: Optional[int] = None,
+ **kwargs,
+ ):
+ """Initialize fastembed.TextEmbedding.
+
+ Args:
+ model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
+ batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
+ Defaults to 256.
+ cache_dir (str, optional): The path to the model cache directory.\
+ Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
+ threads (int, optional): The number of threads single onnxruntime session can use.
+ parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for large datasets.\
+ If `0`, use all available cores.\
+ If `None`, don't use data-parallel processing, use default onnxruntime threading.\
+ Defaults to None.
+ **kwargs: Additional options to pass to fastembed.TextEmbedding
+ Raises:
+ ValueError: If the model_name is not in the format / e.g. BAAI/bge-small-en-v1.5.
+ """
+ try:
+ from fastembed import TextEmbedding
+ except ImportError as e:
+ raise ValueError(
+ "The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
+ ) from e
+ self._batch_size = batch_size
+ self._parallel = parallel
+ self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)
+
+ def __call__(self, inputs: List[str]) -> List[Embeddings]:
+ embeddings = self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel)
+
+ return [embedding.tolist() for embedding in embeddings]
+
+
+class QdrantVectorDB(VectorDB):
+ """
+ A vector database implementation that uses Qdrant as the backend.
+ """
+
+ def __init__(
+ self,
+ *,
+ client=None,
+ embedding_function: EmbeddingFunction = None,
+ content_payload_key: str = "_content",
+ metadata_payload_key: str = "_metadata",
+ collection_options: dict = {},
+ **kwargs,
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Args:
+ client: qdrant_client.QdrantClient | An instance of QdrantClient.
+ embedding_function: Callable | The embedding function used to generate the vector representation
+ of the documents. Defaults to FastEmbedEmbeddingFunction.
+ collection_options: dict | The options for creating the collection.
+ kwargs: dict | Additional keyword arguments.
+ """
+ self.client: QdrantClient = client or QdrantClient(location=":memory:")
+ self.embedding_function = embedding_function or FastEmbedEmbeddingFunction()
+ self.collection_options = collection_options
+ self.content_payload_key = content_payload_key
+ self.metadata_payload_key = metadata_payload_key
+ self.type = "qdrant"
+
+ def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> None:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raise a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Any | The collection object.
+ """
+ embeddings_size = len(self.embedding_function(["test"])[0])
+
+ if self.client.collection_exists(collection_name) and overwrite:
+ self.client.delete_collection(collection_name)
+
+ if not self.client.collection_exists(collection_name):
+ self.client.create_collection(
+ collection_name,
+ vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
+ **self.collection_options,
+ )
+ elif not get_or_create:
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def get_collection(self, collection_name: str = None):
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any | The collection object.
+ """
+ if collection_name is None:
+ raise ValueError("The collection name is required.")
+
+ return self.client.get_collection(collection_name)
+
+ def delete_collection(self, collection_name: str) -> None:
+ """Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any
+ """
+ return self.client.delete_collection(collection_name)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ if not docs:
+ return
+ if any(doc.get("content") is None for doc in docs):
+ raise ValueError("The document content is required.")
+ if any(doc.get("id") is None for doc in docs):
+ raise ValueError("The document id is required.")
+
+ if not upsert and not self._validate_upsert_ids(collection_name, [doc["id"] for doc in docs]):
+            logger.warning("Some IDs already exist. Skipping insert")
+            return
+
+ self.client.upsert(collection_name, points=self._documents_to_points(docs))
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ if not docs:
+ return
+ if any(doc.get("id") is None for doc in docs):
+ raise ValueError("The document id is required.")
+ if any(doc.get("content") is None for doc in docs):
+ raise ValueError("The document content is required.")
+ if self._validate_update_ids(collection_name, [doc["id"] for doc in docs]):
+ return self.client.upsert(collection_name, points=self._documents_to_points(docs))
+
+ raise ValueError("Some IDs do not exist. Skipping update")
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ self.client.delete(collection_name, ids)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = 0,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+            distance_threshold: float | Passed to Qdrant as `score_threshold`; results scoring below it are
+                excluded. Default is 0.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+            QueryResults | The query results. Each query result is a list of (Document, distance) tuples, one list
+                per query.
+ """
+ embeddings = self.embedding_function(queries)
+ requests = [
+ models.SearchRequest(
+ vector=embedding,
+ limit=n_results,
+ score_threshold=distance_threshold,
+ with_payload=True,
+ with_vector=False,
+ )
+ for embedding in embeddings
+ ]
+
+ batch_results = self.client.search_batch(collection_name, requests)
+ return [self._scored_points_to_documents(results) for results in batch_results]
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+            include: bool or List[str] | The payload fields to include, or True to include the whole payload.
+                Ids and vectors are always returned. Default is True.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ if ids is None:
+ results = self.client.scroll(collection_name=collection_name, with_payload=include, with_vectors=True)[0]
+ else:
+ results = self.client.retrieve(collection_name, ids=ids, with_payload=include, with_vectors=True)
+ return [self._point_to_document(result) for result in results]
+
+ def _point_to_document(self, point) -> Document:
+ return {
+ "id": point.id,
+ "content": point.payload.get(self.content_payload_key, ""),
+ "metadata": point.payload.get(self.metadata_payload_key, {}),
+ "embedding": point.vector,
+ }
+
+ def _points_to_documents(self, points) -> List[Document]:
+ return [self._point_to_document(point) for point in points]
+
+ def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
+ return self._point_to_document(scored_point), scored_point.score
+
+ def _documents_to_points(self, documents: List[Document]):
+ contents = [document["content"] for document in documents]
+ embeddings = self.embedding_function(contents)
+ points = [
+ models.PointStruct(
+ id=documents[i]["id"],
+ vector=embeddings[i],
+ payload={
+ self.content_payload_key: documents[i].get("content"),
+ self.metadata_payload_key: documents[i].get("metadata"),
+ },
+ )
+ for i in range(len(documents))
+ ]
+ return points
+
+ def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
+ return [self._scored_point_to_document(scored_point) for scored_point in scored_points]
+
+ def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
+ """
+ Validates all the IDs exist in the collection
+ """
+ retrieved_ids = [
+ point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
+ ]
+
+ if missing_ids := set(ids) - set(retrieved_ids):
+ logger.log(f"Missing IDs: {missing_ids}. Skipping update", level=logging.WARN)
+ return False
+
+ return True
+
+ def _validate_upsert_ids(self, collection_name: str, ids: List[str]) -> bool:
+ """
+ Validate none of the IDs exist in the collection
+ """
+ retrieved_ids = [
+ point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
+ ]
+
+ if existing_ids := set(ids) & set(retrieved_ids):
+ logger.log(f"Existing IDs: {existing_ids}.", level=logging.WARN)
+ return False
+
+ return True
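The methods above complete the collection and document lifecycle for the Qdrant-backed vector DB. As a rough usage sketch (not taken from this patch): it assumes the class is exported as `QdrantVectorDB`, that it accepts a `QdrantClient` via a `client` argument, and that `create_collection` follows the base-class signature; the document contents and query are placeholders.

```python
# Hypothetical usage sketch of the Qdrant-backed VectorDB shown above.
# Assumptions: the class is exported as QdrantVectorDB, accepts a QdrantClient via
# `client`, and uses its default embedding function.
from qdrant_client import QdrantClient

from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB  # assumed import path

db = QdrantVectorDB(client=QdrantClient(location=":memory:"))
db.create_collection("demo", overwrite=False, get_or_create=True)  # assumed base-class signature

# Each Document is a TypedDict; `id` and `content` are required by insert_docs.
docs = [
    {"id": 1, "content": "AutoGen enables multi-agent LLM applications."},
    {"id": 2, "content": "Qdrant stores vectors alongside JSON payloads."},
]
db.insert_docs(docs, collection_name="demo")

# retrieve_docs returns QueryResults: one list of (Document, distance) tuples per query.
for doc, distance in db.retrieve_docs(["What is AutoGen?"], collection_name="demo", n_results=1)[0]:
    print(doc["id"], doc["content"], distance)
```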
diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py
index ae1ef125251..7812f218654 100644
--- a/autogen/agentchat/contrib/vectordb/utils.py
+++ b/autogen/agentchat/contrib/vectordb/utils.py
@@ -25,6 +25,9 @@ def error(self, msg, *args, color="light_red", **kwargs):
def critical(self, msg, *args, color="red", **kwargs):
super().critical(colored(msg, color), *args, **kwargs)
+ def fatal(self, msg, *args, color="red", **kwargs):
+ super().fatal(colored(msg, color), *args, **kwargs)
+
def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger:
logger = ColoredLogger(name, level)
@@ -96,15 +99,20 @@ def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], speci
]
"""
- keys = [key for key in data_dict if key != special_key]
+ keys = [
+ key
+ for key in data_dict
+ if key != special_key and data_dict[key] is not None and isinstance(data_dict[key][0], list)
+ ]
result = []
+ data_special_key = data_dict[special_key]
- for i in range(len(data_dict[special_key])):
+ for i in range(len(data_special_key)):
sub_result = []
- for j, distance in enumerate(data_dict[special_key][i]):
+ for j, distance in enumerate(data_special_key[i]):
sub_dict = {}
for key in keys:
- if data_dict[key] is not None and len(data_dict[key]) > i:
+ if len(data_dict[key]) > i:
sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key
sub_result.append((sub_dict, distance))
result.append(sub_result)
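For readers unfamiliar with Chroma's batched result shape, the following self-contained snippet mirrors the reshaping that `chroma_results_to_query_results` performs after this change; the sample values are invented for illustration.

```python
# Illustration of the reshaping above: a Chroma-style dict of parallel, per-query
# lists becomes a list of (record, distance) tuples per query. Sample data is invented.
data_dict = {
    "ids": [["id1", "id2"]],
    "documents": [["doc one", "doc two"]],
    "distances": [[0.1, 0.4]],
    "embeddings": None,  # dropped: not a list of lists
}
special_key = "distances"

keys = [
    key
    for key in data_dict
    if key != special_key and data_dict[key] is not None and isinstance(data_dict[key][0], list)
]
result = []
for i, distances in enumerate(data_dict[special_key]):
    sub_result = []
    for j, distance in enumerate(distances):
        # strip the trailing 's' from each key, e.g. "documents" -> "document"
        sub_dict = {key[:-1]: data_dict[key][i][j] for key in keys if len(data_dict[key]) > i}
        sub_result.append((sub_dict, distance))
    result.append(sub_result)

assert result == [[({"id": "id1", "document": "doc one"}, 0.1), ({"id": "id2", "document": "doc two"}, 0.4)]]
```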
diff --git a/autogen/agentchat/contrib/web_surfer.py b/autogen/agentchat/contrib/web_surfer.py
index 1a54aeebe15..f74915a9b40 100644
--- a/autogen/agentchat/contrib/web_surfer.py
+++ b/autogen/agentchat/contrib/web_surfer.py
@@ -34,13 +34,14 @@ def __init__(
description: Optional[str] = DEFAULT_DESCRIPTION,
is_termination_msg: Optional[Callable[[Dict[str, Any]], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
- human_input_mode: Optional[str] = "TERMINATE",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = None,
summarizer_llm_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Optional[Union[str, Dict, None]] = "",
browser_config: Optional[Union[Dict, None]] = None,
+ **kwargs,
):
super().__init__(
name=name,
@@ -53,6 +54,7 @@ def __init__(
code_execution_config=code_execution_config,
llm_config=llm_config,
default_auto_reply=default_auto_reply,
+ **kwargs,
)
self._create_summarizer_client(summarizer_llm_config, llm_config)
@@ -111,7 +113,9 @@ def _create_summarizer_client(self, summarizer_llm_config: Dict[str, Any], llm_c
self.summarizer_llm_config = summarizer_llm_config # type: ignore[assignment]
# Create the summarizer client
- self.summarization_client = None if self.summarizer_llm_config is False else OpenAIWrapper(**self.summarizer_llm_config) # type: ignore[arg-type]
+ self.summarization_client = (
+ None if self.summarizer_llm_config is False else OpenAIWrapper(**self.summarizer_llm_config)
+ ) # type: ignore[arg-type]
def _register_functions(self) -> None:
"""Register the functions for the inner assistant and user proxy."""
@@ -250,7 +254,7 @@ def _answer_from_page(
def _summarize_page(
url: Annotated[
Optional[str], "[Optional] The url of the page to summarize. (Defaults to current page)"
- ] = None
+ ] = None,
) -> str:
return _answer_from_page(url=url, question=None)
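A hedged construction sketch for the updated `WebSurferAgent` signature: `human_input_mode` is now restricted to the three literal values, and extra `ConversableAgent` keyword arguments (such as the new `silent` flag) are forwarded through `**kwargs`. The model name and `browser_config` keys below are assumptions, not values from this patch.

```python
# Hedged sketch of the updated WebSurferAgent signature (model and browser options are assumptions).
from autogen.agentchat.contrib.web_surfer import WebSurferAgent

llm_config = {"model": "gpt-4o-mini"}  # placeholder

surfer = WebSurferAgent(
    "web_surfer",
    llm_config=llm_config,
    summarizer_llm_config=llm_config,
    browser_config={"viewport_size": 4096},  # assumed browser option
    human_input_mode="NEVER",  # must now be one of "ALWAYS", "NEVER", "TERMINATE"
    silent=True,  # extra ConversableAgent kwargs are forwarded via **kwargs
)
```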
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 4ff1a9d051b..ed550128780 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -12,11 +12,11 @@
import sys
from collections import defaultdict
-from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from openai import BadRequestError
+from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
from .._pydantic import model_dump
@@ -37,7 +37,7 @@
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from ..io.base import IOStream
from ..oai.client import ModelClient, OpenAIWrapper
-from ..runtime_logging import log_new_agent, logging_enabled
+from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
from .agent import Agent, LLMAgent
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary
@@ -82,6 +82,8 @@ def __init__(
llm_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Union[str, Dict] = "",
description: Optional[str] = None,
+ chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
+ silent: Optional[bool] = None,
):
"""
Args:
@@ -127,6 +129,11 @@ def __init__(
default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated.
description (str): a short description of the agent. This description is used by other agents
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
+ chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents.
+ Can be used to give the agent a memory by providing the chat history. This will allow the agent to
+                resume previously held conversations. Defaults to an empty chat history.
+ silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of
+ silent in each function.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
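A minimal sketch of the new `chat_messages` parameter, seeding an agent with prior history so it can continue an earlier exchange. Agent names, the model, and the message contents are placeholders.

```python
# Hypothetical sketch: giving an agent memory of a previous exchange via chat_messages.
from autogen import ConversableAgent

llm_config = {"model": "gpt-4o-mini", "api_key": "sk-..."}  # placeholder configuration

alice = ConversableAgent("alice", llm_config=llm_config)

# Prior messages bob exchanged with alice, keyed by the Agent instance.
history = {
    alice: [
        {"role": "user", "content": "Let's plan a blog post about multi-agent chat.", "name": "alice"},
        {"role": "assistant", "content": "Sure: I suggest three sections, intro, demo, conclusion.", "name": "bob"},
    ]
}

# bob resumes the earlier conversation instead of starting from an empty history.
bob = ConversableAgent("bob", llm_config=llm_config, chat_messages=history, silent=False)
```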
@@ -136,7 +143,11 @@ def __init__(
self._name = name
# a dictionary of conversations, default value is list
- self._oai_messages = defaultdict(list)
+ if chat_messages is None:
+ self._oai_messages = defaultdict(list)
+ else:
+ self._oai_messages = chat_messages
+
self._oai_system_message = [{"content": system_message, "role": "system"}]
self._description = description if description is not None else system_message
self._is_termination_msg = (
@@ -144,9 +155,16 @@ def __init__(
if is_termination_msg is not None
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
+ self.silent = silent
# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
- llm_config = copy.deepcopy(llm_config)
+ try:
+ llm_config = copy.deepcopy(llm_config)
+ except TypeError as e:
+ raise TypeError(
+ "Please implement __deepcopy__ method for each value class in llm_config to support deepcopy."
+ " Refer to the docs for more details: https://microsoft.github.io/autogen/docs/topics/llm_configuration#adding-http-client-in-llm_config-for-proxy"
+ ) from e
self._validate_llm_config(llm_config)
@@ -234,7 +252,7 @@ def __init__(
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
- self.hook_lists = {
+ self.hook_lists: Dict[str, List[Callable]] = {
"process_last_received_message": [],
"process_all_messages_before_reply": [],
"process_message_before_send": [],
@@ -254,6 +272,10 @@ def _validate_llm_config(self, llm_config):
)
self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)
+ @staticmethod
+ def _is_silent(agent: Agent, silent: Optional[bool] = False) -> bool:
+ return agent.silent if agent.silent is not None else silent
+
@property
def name(self) -> str:
"""Get the name of the agent."""
@@ -360,9 +382,9 @@ def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable)
f["reply_func"] = new_reply_func
@staticmethod
- def _summary_from_nested_chats(
+ def _get_chats_to_run(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
- ) -> Tuple[bool, str]:
+ ) -> List[Dict[str, Any]]:
"""A simple chat reply function.
This function initiate one or a sequence of chats between the "recipient" and the agents in the
chat_queue.
@@ -389,22 +411,59 @@ def _summary_from_nested_chats(
if message:
current_c["message"] = message
chat_to_run.append(current_c)
+ return chat_to_run
+
+ @staticmethod
+ def _summary_from_nested_chats(
+ chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+ ) -> Tuple[bool, Union[str, None]]:
+ """A simple chat reply function.
+        This function initiates one or a sequence of chats between the "recipient" and the agents in the
+ chat_queue.
+
+ It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+ Returns:
+ Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+ """
+ chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
return True, None
res = initiate_chats(chat_to_run)
return True, res[-1].summary
+ @staticmethod
+ async def _a_summary_from_nested_chats(
+ chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
+ ) -> Tuple[bool, Union[str, None]]:
+ """A simple chat reply function.
+        This function initiates one or a sequence of chats between the "recipient" and the agents in the
+ chat_queue.
+
+ It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
+
+ Returns:
+ Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
+ """
+ chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
+ if not chat_to_run:
+ return True, None
+ res = await a_initiate_chats(chat_to_run)
+ index_of_last_chat = chat_to_run[-1]["chat_id"]
+ return True, res[index_of_last_chat].summary
+
def register_nested_chats(
self,
chat_queue: List[Dict[str, Any]],
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
position: int = 2,
+ use_async: Union[bool, None] = None,
**kwargs,
) -> None:
"""Register a nested chat reply function.
Args:
- chat_queue (list): a list of chat objects to be initiated.
+            chat_queue (list): a list of chat objects to be initiated. If use_async is True, then every chat in chat_queue must have a chat_id associated with it.
trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
@@ -419,20 +478,45 @@ def reply_func_from_nested_chats(
) -> Tuple[bool, Union[str, Dict, None]]:
```
position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply.
+            use_async: Uses a_initiate_chats internally to start the nested chats. If the outer chat is initiated with a_initiate_chats, set this to True so that the nested chats also run asynchronously rather than synchronously.
kwargs: Ref to `register_reply` for details.
"""
- if reply_func_from_nested_chats == "summary_from_nested_chats":
- reply_func_from_nested_chats = self._summary_from_nested_chats
- if not callable(reply_func_from_nested_chats):
- raise ValueError("reply_func_from_nested_chats must be a callable")
- reply_func = partial(reply_func_from_nested_chats, chat_queue)
+ if use_async:
+ for chat in chat_queue:
+ if chat.get("chat_id") is None:
+ raise ValueError("chat_id is required for async nested chats")
+
+ if use_async:
+ if reply_func_from_nested_chats == "summary_from_nested_chats":
+ reply_func_from_nested_chats = self._a_summary_from_nested_chats
+ if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
+ reply_func_from_nested_chats
+ ):
+ raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")
+
+ async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+ return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+
+ else:
+ if reply_func_from_nested_chats == "summary_from_nested_chats":
+ reply_func_from_nested_chats = self._summary_from_nested_chats
+ if not callable(reply_func_from_nested_chats):
+ raise ValueError("reply_func_from_nested_chats must be a callable")
+
+ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+ return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+
+ functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
+
self.register_reply(
trigger,
- reply_func,
+ wrapped_reply_func,
position,
kwargs.get("config"),
kwargs.get("reset_config"),
- ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"),
+ ignore_async_in_sync_chat=(
+ not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat")
+ ),
)
@property
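A hedged sketch of registering an asynchronous nested chat with the new `use_async` flag; note that every entry in the queue carries a `chat_id`, as required. Agent names, the model, and the task strings are placeholders.

```python
# Hedged sketch: async nested chats with use_async=True (names and model are placeholders).
import asyncio

from autogen import ConversableAgent, UserProxyAgent

llm_config = {"model": "gpt-4o-mini"}  # placeholder

writer = ConversableAgent("writer", llm_config=llm_config)
critic = ConversableAgent("critic", llm_config=llm_config)
user = UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)

# Every chat in the queue needs a chat_id when use_async=True.
nested_chats = [
    {"chat_id": 1, "recipient": critic, "message": "Review the draft and suggest edits.", "summary_method": "last_msg", "max_turns": 1},
]
writer.register_nested_chats(nested_chats, trigger=user, use_async=True)

async def main():
    # The outer chat is started asynchronously, so the nested chat also runs asynchronously.
    await user.a_initiate_chat(writer, message="Draft a two-sentence product blurb.", max_turns=2)

# asyncio.run(main())
```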
@@ -542,7 +626,7 @@ def _assert_valid_name(name):
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
return name
- def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent) -> bool:
+ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent, is_sending: bool) -> bool:
"""Append a message to the ChatCompletion conversation.
If the message received is a string, it will be put in the "content" field of the new dictionary.
@@ -554,6 +638,7 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
message (dict or str): message to be appended to the ChatCompletion conversation.
role (str): role of the message, can be "assistant" or "function".
conversation_id (Agent): id of the conversation, should be the recipient or sender.
+            is_sending (bool): Whether the agent (i.e. self) is sending to the conversation_id agent; otherwise it is receiving.
Returns:
bool: whether the message is appended to the ChatCompletion conversation.
@@ -573,12 +658,25 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
if message.get("role") in ["function", "tool"]:
oai_message["role"] = message.get("role")
+ elif "override_role" in message:
+ # If we have a direction to override the role then set the
+ # role accordingly. Used to customise the role for the
+ # select speaker prompt.
+ oai_message["role"] = message.get("override_role")
else:
oai_message["role"] = role
if oai_message.get("function_call", False) or oai_message.get("tool_calls", False):
oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call.
+ elif "name" not in oai_message:
+ # If we don't have a name field, append it
+ if is_sending:
+ oai_message["name"] = self.name
+ else:
+ oai_message["name"] = conversation_id.name
+
self._oai_messages[conversation_id].append(oai_message)
+
return True
def _process_message_before_send(
@@ -587,7 +685,9 @@ def _process_message_before_send(
"""Process the message before sending it to the recipient."""
hook_list = self.hook_lists["process_message_before_send"]
for hook in hook_list:
- message = hook(sender=self, message=message, recipient=recipient, silent=silent)
+ message = hook(
+ sender=self, message=message, recipient=recipient, silent=ConversableAgent._is_silent(self, silent)
+ )
return message
def send(
@@ -629,10 +729,10 @@ def send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
"""
- message = self._process_message_before_send(message, recipient, silent)
+ message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent))
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
- valid = self._append_oai_message(message, "assistant", recipient)
+ valid = self._append_oai_message(message, "assistant", recipient, is_sending=True)
if valid:
recipient.receive(message, self, request_reply, silent)
else:
@@ -679,10 +779,10 @@ async def a_send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
"""
- message = self._process_message_before_send(message, recipient, silent)
+ message = self._process_message_before_send(message, recipient, ConversableAgent._is_silent(self, silent))
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
- valid = self._append_oai_message(message, "assistant", recipient)
+ valid = self._append_oai_message(message, "assistant", recipient, is_sending=True)
if valid:
await recipient.a_receive(message, self, request_reply, silent)
else:
@@ -753,12 +853,16 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
- valid = self._append_oai_message(message, "user", sender)
+ valid = self._append_oai_message(message, "user", sender, is_sending=False)
+ if logging_enabled():
+ log_event(self, "received_message", message=message, sender=sender.name, valid=valid)
+
if not valid:
raise ValueError(
"Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
)
- if not silent:
+
+ if not ConversableAgent._is_silent(sender, silent):
self._print_received_message(message, sender)
def receive(
@@ -929,6 +1033,7 @@ def my_summary_method(
One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect
on the conversation and extract a summary when summary_method is "reflection_with_llm".
The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+ Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message.
- If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).
@@ -1138,11 +1243,18 @@ def my_summary_method(
@staticmethod
def _last_msg_as_summary(sender, recipient, summary_args) -> str:
"""Get a chat summary from the last message of the recipient."""
+ summary = ""
try:
- summary = recipient.last_message(sender)["content"].replace("TERMINATE", "")
+ content = recipient.last_message(sender)["content"]
+ if isinstance(content, str):
+ summary = content.replace("TERMINATE", "")
+ elif isinstance(content, list):
+ # Remove the `TERMINATE` word in the content list.
+ summary = "\n".join(
+ x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
+ )
except (IndexError, AttributeError) as e:
warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
- summary = ""
return summary
@staticmethod
@@ -1153,8 +1265,13 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
raise ValueError("The summary_prompt must be a string.")
msg_list = recipient.chat_messages_for_summary(sender)
agent = sender if recipient is None else recipient
+ role = summary_args.get("summary_role", None)
+ if role and not isinstance(role, str):
+ raise ValueError("The summary_role in summary_arg must be a string.")
try:
- summary = sender._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"))
+ summary = sender._reflection_with_llm(
+ prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role
+ )
except BadRequestError as e:
warnings.warn(
f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning
@@ -1163,7 +1280,12 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
return summary
def _reflection_with_llm(
- self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None
+ self,
+ prompt,
+ messages,
+ llm_agent: Optional[Agent] = None,
+ cache: Optional[AbstractCache] = None,
+ role: Union[str, None] = None,
) -> str:
"""Get a chat summary using reflection with an llm client based on the conversation history.
@@ -1172,10 +1294,14 @@ def _reflection_with_llm(
messages (list): The messages generated as part of a chat conversation.
llm_agent: the agent with an llm client.
cache (AbstractCache or None): the cache client to be used for this conversation.
+ role (str): the role of the message, usually "system" or "user". Default is "system".
"""
+ if not role:
+ role = "system"
+
system_msg = [
{
- "role": "system",
+ "role": role,
"content": prompt,
}
]
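A small sketch of the new `summary_role` option in `summary_args`, which controls the role of the reflection prompt message (defaulting to "system"). The agents and model below are placeholders.

```python
# Hedged sketch: requesting an LLM-reflection summary whose prompt is sent with a custom role.
from autogen import ConversableAgent

llm_config = {"model": "gpt-4o-mini"}  # placeholder
assistant = ConversableAgent("assistant", llm_config=llm_config)
user = ConversableAgent("user", llm_config=False, human_input_mode="NEVER")

result = user.initiate_chat(
    assistant,
    message="List three uses of multi-agent chat.",
    max_turns=1,
    summary_method="reflection_with_llm",
    summary_args={
        "summary_prompt": "Summarize the key takeaways in one sentence.",
        "summary_role": "user",  # defaults to "system" when omitted
    },
)
print(result.summary)
```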
@@ -1190,6 +1316,23 @@ def _reflection_with_llm(
response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
return response
+ def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Check the chat queue and add the "sender" key if it's missing.
+
+ Args:
+ chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information.
+
+ Returns:
+ List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing.
+ """
+ chat_queue_with_sender = []
+ for chat_info in chat_queue:
+ if chat_info.get("sender") is None:
+ chat_info["sender"] = self
+ chat_queue_with_sender.append(chat_info)
+ return chat_queue_with_sender
+
def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""(Experimental) Initiate chats with multiple agents.
@@ -1199,16 +1342,12 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue.
"""
- _chat_queue = chat_queue.copy()
- for chat_info in _chat_queue:
- chat_info["sender"] = self
+ _chat_queue = self._check_chat_queue_for_sender(chat_queue)
self._finished_chats = initiate_chats(_chat_queue)
return self._finished_chats
async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
- _chat_queue = chat_queue.copy()
- for chat_info in _chat_queue:
- chat_info["sender"] = self
+ _chat_queue = self._check_chat_queue_for_sender(chat_queue)
self._finished_chats = await a_initiate_chats(_chat_queue)
return self._finished_chats
@@ -1314,14 +1453,12 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
# TODO: #1143 handle token limit exceeded error
response = llm_client.create(
- context=messages[-1].pop("context", None),
- messages=all_messages,
- cache=cache,
+ context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self
)
extracted_response = llm_client.extract_text_or_completion_object(response)[0]
if extracted_response is None:
- warnings.warn("Extracted_response from {response} is None.", UserWarning)
+ warnings.warn(f"Extracted_response from {response} is None.", UserWarning)
return None
# ensure function and tool calls will be accepted when sent back to the LLM
if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"):
@@ -1681,7 +1818,7 @@ def check_termination_and_human_reply(
sender_name = "the sender" if sender is None else sender.name
if self.human_input_mode == "ALWAYS":
reply = self.get_human_input(
- f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
+ f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1794,7 +1931,7 @@ async def a_check_termination_and_human_reply(
sender_name = "the sender" if sender is None else sender.name
if self.human_input_mode == "ALWAYS":
reply = await self.a_get_human_input(
- f"Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
+ f"Replying as {self.name}. Provide feedback to {sender_name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
@@ -1929,6 +2066,15 @@ def generate_reply(
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
+ if logging_enabled():
+ log_event(
+ self,
+ "reply_func_executed",
+ reply_func_module=reply_func.__module__,
+ reply_func_name=reply_func.__name__,
+ final=final,
+ reply=reply,
+ )
if final:
return reply
return self._default_auto_reply
@@ -2134,7 +2280,7 @@ def _format_json_str(jstr):
Ex 2:
"{\n \"location\": \"Boston, MA\"\n}" -> "{"location": "Boston, MA"}"
- 2. this function also handles JSON escape sequences inside quotes,
+ 2. this function also handles JSON escape sequences inside quotes.
Ex 1:
'{"args": "a\na\na\ta"}' -> '{"args": "a\\na\\na\\ta"}'
"""
@@ -2183,7 +2329,7 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict
arguments = json.loads(input_string)
except json.JSONDecodeError as e:
arguments = None
- content = f"Error: {e}\n You argument should follow json format."
+ content = f"Error: {e}\n The argument must be in JSON format."
# Try to execute the function
if arguments is not None:
@@ -2240,7 +2386,7 @@ async def a_execute_function(self, func_call):
arguments = json.loads(input_string)
except json.JSONDecodeError as e:
arguments = None
- content = f"Error: {e}\n You argument should follow json format."
+ content = f"Error: {e}\n The argument must be in JSON format."
# Try to execute the function
if arguments is not None:
@@ -2314,7 +2460,7 @@ def _process_carryover(self, content: str, kwargs: dict) -> str:
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
- content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
+ content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
@@ -2354,6 +2500,8 @@ def register_function(self, function_map: Dict[str, Union[Callable, None]]):
self._assert_valid_name(name)
if func is None and name not in self._function_map.keys():
warnings.warn(f"The function {name} to remove doesn't exist", name)
+ if name in self._function_map:
+ warnings.warn(f"Function '{name}' is being overridden.", UserWarning)
self._function_map.update(function_map)
self._function_map = {k: v for k, v in self._function_map.items() if v is not None}
@@ -2390,6 +2538,9 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)
self._assert_valid_name(func_sig["name"])
if "functions" in self.llm_config.keys():
+ if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]):
+ warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning)
+
self.llm_config["functions"] = [
func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"]
] + [func_sig]
@@ -2429,7 +2580,9 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}"
)
self._assert_valid_name(tool_sig["function"]["name"])
- if "tools" in self.llm_config.keys():
+ if "tools" in self.llm_config:
+ if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]):
+ warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning)
self.llm_config["tools"] = [
tool
for tool in self.llm_config["tools"]
@@ -2469,13 +2622,16 @@ def _wrap_function(self, func: F) -> F:
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
retval = func(*args, **kwargs)
-
+ if logging_enabled():
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
retval = await func(*args, **kwargs)
+ if logging_enabled():
+ log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func
@@ -2665,7 +2821,7 @@ def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
processed_messages = hook(processed_messages)
return processed_messages
- def process_last_received_message(self, messages):
+ def process_last_received_message(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
@@ -2699,6 +2855,7 @@ def process_last_received_message(self, messages):
processed_user_content = user_content
for hook in hook_list:
processed_user_content = hook(processed_user_content)
+
if processed_user_content == user_content:
return messages # No hooks actually modified the user's message.
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index f5b6106863a..2ebdf95b7d3 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -1,18 +1,28 @@
+import copy
+import json
import logging
import random
import re
import sys
from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from ..code_utils import content_str
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
+from ..formatting_utils import colored
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
from ..io.base import IOStream
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent
+from .chat import ChatResult
from .conversable_agent import ConversableAgent
+try:
+ # Non-core module
+ from .contrib.capabilities import transform_messages
+except ImportError:
+ transform_messages = None
+
logger = logging.getLogger(__name__)
@@ -28,13 +38,29 @@ class GroupChat:
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
- - select_speaker_message_template: customize the select speaker message (used in "auto" speaker selection), which appears first in the message context and generally includes the agent descriptions and list of agents. The string value will be converted to an f-string, use "{roles}" to output the agent's and their role descriptions and "{agentlist}" for a comma-separated list of agent names in square brackets. The default value is:
+    - select_speaker_message_template: customize the select speaker message (used in "auto" speaker selection), which appears first in the message context and generally includes the agent descriptions and list of agents. If the string contains "{roles}" it will be replaced with the agents' names and their role descriptions. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
"You are in a role play game. The following roles are available:
{roles}.
Read the following conversation.
Then select the next role from {agentlist} to play. Only return the role."
- - select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. The string value will be converted to an f-string, use "{agentlist}" for a comma-separated list of agent names in square brackets. The default value is:
+ - select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
"Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
+        To disable this prompt, set it to None. If set to None, ensure your instructions for selecting a speaker are in the select_speaker_message_template string.
+ - select_speaker_auto_multiple_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains multiple agent names. This prompt guides the LLM to return just one agent name. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."
+ - select_speaker_auto_none_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains no agent names. This prompt guides the LLM to return an agent name and provides a list of agent names. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ The only names that are accepted are {agentlist}.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
Could be any of the following (case insensitive), will raise ValueError if not recognized:
- "auto": the next speaker is selected automatically by LLM.
@@ -51,6 +77,17 @@ def custom_speaker_selection_func(
last_speaker: Agent, groupchat: GroupChat
) -> Union[Agent, str, None]:
```
+ - max_retries_for_selecting_speaker: the maximum number of times the speaker selection requery process will run.
+ If, during speaker selection, multiple agent names or no agent names are returned by the LLM as the next agent, it will be queried again up to the maximum number
+ of times until a single agent is returned or it exhausts the maximum attempts.
+ Applies only to "auto" speaker selection method.
+ Default is 2.
+ - select_speaker_transform_messages: (optional) the message transformations to apply to the nested select speaker agent-to-agent chat messages.
+ Takes a TransformMessages object, defaults to None and is only utilised when the speaker selection method is "auto".
+    - select_speaker_auto_verbose: whether to output the select speaker responses and selections.
+        If set to True, the outputs from the two agents in the nested select speaker chat will be output, along with
+        whether the responses were successful, or not, in selecting an agent.
+        Applies only to "auto" speaker selection method.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
Default is True, in which case all speakers are allowed to speak consecutively.
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
@@ -73,14 +110,15 @@ def custom_speaker_selection_func(
agents: List[Agent]
messages: List[Dict]
- max_round: Optional[int] = 10
- admin_name: Optional[str] = "Admin"
- func_call_filter: Optional[bool] = True
+ max_round: int = 10
+ admin_name: str = "Admin"
+ func_call_filter: bool = True
speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto"
+ max_retries_for_selecting_speaker: int = 2
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Literal["allowed", "disallowed", None] = None
- enable_clear_history: Optional[bool] = False
+ enable_clear_history: bool = False
send_introductions: bool = False
select_speaker_message_template: str = """You are in a role play game. The following roles are available:
{roles}.
@@ -89,6 +127,21 @@ def custom_speaker_selection_func(
select_speaker_prompt_template: str = (
"Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
)
+ select_speaker_auto_multiple_template: str = """You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."""
+ select_speaker_auto_none_template: str = """You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ The only names that are accepted are {agentlist}.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."""
+ select_speaker_transform_messages: Optional[Any] = None
+ select_speaker_auto_verbose: Optional[bool] = False
role_for_select_speaker_messages: Optional[str] = "system"
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
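Putting the new options together, here is a hedged `GroupChat` construction sketch using `max_retries_for_selecting_speaker`, `select_speaker_auto_verbose`, and a message transform for the nested selection chat; the agents, model, and transform settings are placeholders.

```python
# Hedged sketch: a GroupChat using the new auto speaker selection options.
from autogen import ConversableAgent, GroupChat, GroupChatManager
from autogen.agentchat.contrib.capabilities import transform_messages, transforms

llm_config = {"model": "gpt-4o-mini"}  # placeholder
planner = ConversableAgent("planner", llm_config=llm_config)
coder = ConversableAgent("coder", llm_config=llm_config)

# Keep the nested speaker-selection chat short by truncating its message history.
select_speaker_transforms = transform_messages.TransformMessages(
    transforms=[transforms.MessageHistoryLimiter(max_messages=10)]
)

groupchat = GroupChat(
    agents=[planner, coder],
    messages=[],
    max_round=6,
    max_retries_for_selecting_speaker=2,       # requery the LLM up to 2 extra times
    select_speaker_auto_verbose=True,          # print each selection attempt and outcome
    select_speaker_transform_messages=select_speaker_transforms,
)
manager = GroupChatManager(groupchat=groupchat, llm_config=llm_config)
```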
@@ -178,16 +231,51 @@ def __post_init__(self):
agents=self.agents,
)
- # Check select_speaker_message_template and select_speaker_prompt_template have values
+ # Check select speaker messages, prompts, roles, and retries have values
if self.select_speaker_message_template is None or len(self.select_speaker_message_template) == 0:
raise ValueError("select_speaker_message_template cannot be empty or None.")
- if self.select_speaker_prompt_template is None or len(self.select_speaker_prompt_template) == 0:
- raise ValueError("select_speaker_prompt_template cannot be empty or None.")
+ if self.select_speaker_prompt_template is not None and len(self.select_speaker_prompt_template) == 0:
+ self.select_speaker_prompt_template = None
if self.role_for_select_speaker_messages is None or len(self.role_for_select_speaker_messages) == 0:
raise ValueError("role_for_select_speaker_messages cannot be empty or None.")
+ if self.select_speaker_auto_multiple_template is None or len(self.select_speaker_auto_multiple_template) == 0:
+ raise ValueError("select_speaker_auto_multiple_template cannot be empty or None.")
+
+ if self.select_speaker_auto_none_template is None or len(self.select_speaker_auto_none_template) == 0:
+ raise ValueError("select_speaker_auto_none_template cannot be empty or None.")
+
+ # Validate max select speakers retries
+ if self.max_retries_for_selecting_speaker is None or not isinstance(
+ self.max_retries_for_selecting_speaker, int
+ ):
+ raise ValueError("max_retries_for_selecting_speaker cannot be None or non-int")
+ elif self.max_retries_for_selecting_speaker < 0:
+ raise ValueError("max_retries_for_selecting_speaker must be greater than or equal to zero")
+
+        # Load message transforms here (load them once for the group chat so we don't have to re-create them and they maintain their cache across subsequent select speaker calls)
+ self._speaker_selection_transforms = None
+ if self.select_speaker_transform_messages is not None:
+ if transform_messages is not None:
+ if isinstance(self.select_speaker_transform_messages, transform_messages.TransformMessages):
+ self._speaker_selection_transforms = self.select_speaker_transform_messages
+ else:
+ raise ValueError("select_speaker_transform_messages must be None or MessageTransforms.")
+ else:
+ logger.warning(
+ "TransformMessages could not be loaded, the 'select_speaker_transform_messages' transform"
+                    " will not apply."
+ )
+
+ # Validate select_speaker_auto_verbose
+ if self.select_speaker_auto_verbose is None or not isinstance(self.select_speaker_auto_verbose, bool):
+ raise ValueError("select_speaker_auto_verbose cannot be None or non-bool")
+
@property
def agent_names(self) -> List[str]:
"""Return the names of the agents in the group chat."""
@@ -266,7 +354,13 @@ def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
return return_msg
def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str:
- """Return the floating system prompt selecting the next speaker. This is always the *last* message in the context."""
+ """Return the floating system prompt selecting the next speaker.
+ This is always the *last* message in the context.
+ Will return None if the select_speaker_prompt_template is None."""
+
+ if self.select_speaker_prompt_template is None:
+ return None
+
if agents is None:
agents = self.agents
@@ -450,33 +544,34 @@ def _prepare_and_select_agents(
select_speaker_messages[-1] = dict(select_speaker_messages[-1], function_call=None)
if select_speaker_messages[-1].get("tool_calls", False):
select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None)
- select_speaker_messages = select_speaker_messages + [
- {
- "role": self.role_for_select_speaker_messages,
- "content": self.select_speaker_prompt(graph_eligible_agents),
- }
- ]
return selected_agent, graph_eligible_agents, select_speaker_messages
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
- """Select the next speaker."""
+ """Select the next speaker (with requery)."""
+
+ # Prepare the list of available agents and select an agent if selection method allows (non-auto)
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
- # auto speaker selection
- selector.update_system_message(self.select_speaker_msg(agents))
- final, name = selector.generate_oai_reply(messages)
- return self._finalize_speaker(last_speaker, final, name, agents)
+ elif self.speaker_selection_method == "manual":
+ # An agent has not been selected while in manual mode, so move to the next agent
+ return self.next_agent(last_speaker)
+
+ # auto speaker selection with 2-agent chat
+ return self._auto_select_speaker(last_speaker, selector, messages, agents)
async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
- """Select the next speaker."""
+ """Select the next speaker (with requery), asynchronously."""
+
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
- # auto speaker selection
- selector.update_system_message(self.select_speaker_msg(agents))
- final, name = await selector.a_generate_oai_reply(messages)
- return self._finalize_speaker(last_speaker, final, name, agents)
+ elif self.speaker_selection_method == "manual":
+ # An agent has not been selected while in manual mode, so move to the next agent
+ return self.next_agent(last_speaker)
+
+ # auto speaker selection with 2-agent chat
+ return await self.a_auto_select_speaker(last_speaker, selector, messages, agents)
def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: Optional[List[Agent]]) -> Agent:
if not final:
@@ -496,6 +591,324 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
agent = self.agent_by_name(name)
return agent if agent else self.next_agent(last_speaker, agents)
+ def _auto_select_speaker(
+ self,
+ last_speaker: Agent,
+ selector: ConversableAgent,
+ messages: Optional[List[Dict]],
+ agents: Optional[List[Agent]],
+ ) -> Agent:
+ """Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+ Speaker selection for "auto" speaker selection method:
+ 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+ 2. Inject the group messages into the new chat
+        3. Run the two-agent chat, evaluating the result of the response from the speaker selector agent:
+ - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+ 4. Chat continues until a single agent is nominated or there are no more attempts left
+ 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+ Args:
+ last_speaker Agent: The previous speaker in the group chat
+            selector ConversableAgent: The ConversableAgent whose llm_config is used for the nested speaker selection chat
+ messages Optional[List[Dict]]: Current chat messages
+ agents Optional[List[Agent]]: Valid list of agents for speaker selection
+
+ Returns:
+            Agent: the agent selected to speak next in the group chat.
+ """
+
+ # If no agents are passed in, assign all the group chat's agents
+ if agents is None:
+ agents = self.agents
+
+ # The maximum number of speaker selection attempts (including requeries)
+ # is the initial speaker selection attempt plus the maximum number of retries.
+ # We track these and use them in the validation function as we can't
+ # access the max_turns from within validate_speaker_name.
+ max_attempts = 1 + self.max_retries_for_selecting_speaker
+ attempts_left = max_attempts
+ attempt = 0
+
+ # Registered reply function for checking_agent, checks the result of the response for agent names
+ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]:
+ # The number of retries left, starting at max_retries_for_selecting_speaker
+ nonlocal attempts_left
+ nonlocal attempt
+
+ attempt = attempt + 1
+ attempts_left = attempts_left - 1
+
+ return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents)
+
+ # Two-agent chat for speaker selection
+
+ # Agent for checking the response from the speaker_select_agent
+ checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)
+
+ # Register the speaker validation function with the checking agent
+ checking_agent.register_reply(
+ [ConversableAgent, None],
+ reply_func=validate_speaker_name, # Validate each response
+ remove_other_reply_funcs=True,
+ )
+
+ # NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat
+
+ # Agent for selecting a single agent name from the response
+ speaker_selection_agent = ConversableAgent(
+ "speaker_selection_agent",
+ system_message=self.select_speaker_msg(agents),
+ chat_messages=(
+ {checking_agent: messages}
+ if self.select_speaker_prompt_template is not None
+ else {checking_agent: messages[:-1]}
+ ),
+ llm_config=selector.llm_config,
+ human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
+ )
+
+ # Create the starting message
+ if self.select_speaker_prompt_template is not None:
+ start_message = {
+ "content": self.select_speaker_prompt(agents),
+ "name": "checking_agent",
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ start_message = messages[-1]
+
+ # Add the message transforms, if any, to the speaker selection agent
+ if self._speaker_selection_transforms is not None:
+ self._speaker_selection_transforms.add_to_agent(speaker_selection_agent)
+
+ # Run the speaker selection chat
+ result = checking_agent.initiate_chat(
+ speaker_selection_agent,
+ cache=None, # don't use caching for the speaker selection chat
+ message=start_message,
+ max_turns=2
+ * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
+ clear_history=False,
+ silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute
+ )
+
+ return self._process_speaker_selection_result(result, last_speaker, agents)
+
+ async def a_auto_select_speaker(
+ self,
+ last_speaker: Agent,
+ selector: ConversableAgent,
+ messages: Optional[List[Dict]],
+ agents: Optional[List[Agent]],
+ ) -> Agent:
+ """(Asynchronous) Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+ Speaker selection for "auto" speaker selection method:
+ 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+ 2. Inject the group messages into the new chat
+        3. Run the two-agent chat, evaluating the result of the response from the speaker selector agent:
+ - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+ 4. Chat continues until a single agent is nominated or there are no more attempts left
+ 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+ Args:
+ last_speaker Agent: The previous speaker in the group chat
+            selector ConversableAgent: The ConversableAgent whose llm_config is used for the nested speaker selection chat
+ messages Optional[List[Dict]]: Current chat messages
+ agents Optional[List[Agent]]: Valid list of agents for speaker selection
+
+ Returns:
+            Agent: the agent selected to speak next in the group chat.
+ """
+
+ # If no agents are passed in, assign all the group chat's agents
+ if agents is None:
+ agents = self.agents
+
+ # The maximum number of speaker selection attempts (including requeries)
+ # We track these and use them in the validation function as we can't
+ # access the max_turns from within validate_speaker_name
+ max_attempts = 1 + self.max_retries_for_selecting_speaker
+ attempts_left = max_attempts
+ attempt = 0
+
+ # Registered reply function for checking_agent, checks the result of the response for agent names
+ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]:
+ # The number of retries left, starting at max_retries_for_selecting_speaker
+ nonlocal attempts_left
+ nonlocal attempt
+
+ attempt = attempt + 1
+ attempts_left = attempts_left - 1
+
+ return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents)
+
+ # Two-agent chat for speaker selection
+
+ # Agent for checking the response from the speaker_select_agent
+ checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)
+
+ # Register the speaker validation function with the checking agent
+ checking_agent.register_reply(
+ [ConversableAgent, None],
+ reply_func=validate_speaker_name, # Validate each response
+ remove_other_reply_funcs=True,
+ )
+
+ # NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat
+
+ # Agent for selecting a single agent name from the response
+ speaker_selection_agent = ConversableAgent(
+ "speaker_selection_agent",
+ system_message=self.select_speaker_msg(agents),
+ chat_messages={checking_agent: messages},
+ llm_config=selector.llm_config,
+ human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
+ )
+
+ # Create the starting message
+ if self.select_speaker_prompt_template is not None:
+ start_message = {
+ "content": self.select_speaker_prompt(agents),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ start_message = messages[-1]
+
+ # Add the message transforms, if any, to the speaker selection agent
+ if self._speaker_selection_transforms is not None:
+ self._speaker_selection_transforms.add_to_agent(speaker_selection_agent)
+
+ # Run the speaker selection chat
+ result = await checking_agent.a_initiate_chat(
+ speaker_selection_agent,
+ cache=None, # don't use caching for the speaker selection chat
+ message=start_message,
+ max_turns=2
+ * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
+ clear_history=False,
+ silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute
+ )
+
+ return self._process_speaker_selection_result(result, last_speaker, agents)
+
+ def _validate_speaker_name(
+ self, recipient, messages, sender, config, attempts_left, attempt, agents
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Validates the speaker response for each round in the internal 2-agent
+ chat within the auto select speaker method.
+
+ Used by auto_select_speaker and a_auto_select_speaker.
+ """
+
+ # Output the query and requery results
+ if self.select_speaker_auto_verbose:
+ iostream = IOStream.get_default()
+
+ # Validate the speaker name selected
+ select_name = messages[-1]["content"].strip()
+
+ mentions = self._mentioned_agents(select_name, agents)
+
+ if len(mentions) == 1:
+ # Success on retry, we have just one name mentioned
+ selected_agent_name = next(iter(mentions))
+
+ # Add the selected agent to the response so we can return it
+ messages.append({"role": "user", "content": f"[AGENT SELECTED]{selected_agent_name}"})
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} successfully selected: {selected_agent_name}",
+ "green",
+ ),
+ flush=True,
+ )
+
+ elif len(mentions) > 1:
+ # More than one name on requery so add additional reminder prompt for next retry
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} failed as it included multiple agent names.",
+ "red",
+ ),
+ flush=True,
+ )
+
+ if attempts_left:
+ # Message to return to the chat for the next attempt
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return True, {
+ "content": self.select_speaker_auto_multiple_template.format(agentlist=agentlist),
+ "name": "checking_agent",
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ # Final failure, no attempts left
+ messages.append(
+ {
+ "role": "user",
+ "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it returned multiple names.",
+ }
+ )
+
+ else:
+ # No names at all on requery so add additional reminder prompt for next retry
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt #{attempt} failed as it did not include any agent names.",
+ "red",
+ ),
+ flush=True,
+ )
+
+ if attempts_left:
+ # Message to return to the chat for the next attempt
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return True, {
+ "content": self.select_speaker_auto_none_template.format(agentlist=agentlist),
+ "name": "checking_agent",
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ # Final failure, no attempts left
+ messages.append(
+ {
+ "role": "user",
+ "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it did not include any agent names.",
+ }
+ )
+
+ return True, None
+
+ def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: Optional[List[Agent]]):
+ """Checks the result of the auto_select_speaker function, returning the
+ agent to speak.
+
+ Used by auto_select_speaker and a_auto_select_speaker."""
+ if len(result.chat_history) > 0:
+ # Use the final message, which will have the selected agent or reason for failure
+ final_message = result.chat_history[-1]["content"]
+
+ if "[AGENT SELECTED]" in final_message:
+ # Have successfully selected an agent, return it
+ return self.agent_by_name(final_message.replace("[AGENT SELECTED]", ""))
+
+ else: # "[AGENT SELECTION FAILED]"
+ # Failed to select an agent, so we'll select the next agent in the list
+ next_agent = self.next_agent(last_speaker, agents)
+
+ # Fall back to the next agent in the list
+ return next_agent
+
def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
@@ -560,8 +973,9 @@ def __init__(
name: Optional[str] = "chat_manager",
# unlimited consecutive auto reply by default
max_consecutive_auto_reply: Optional[int] = sys.maxsize,
- human_input_mode: Optional[str] = "NEVER",
+ human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
system_message: Optional[Union[str, List]] = "Group chat manager.",
+ silent: bool = False,
**kwargs,
):
if (
@@ -585,6 +999,9 @@ def __init__(
# Store groupchat
self._groupchat = groupchat
+ self._last_speaker = None
+ self._silent = silent
+
# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
@@ -624,6 +1041,53 @@ def _prepare_chat(
if (recipient != agent or prepare_recipient) and isinstance(agent, ConversableAgent):
agent._prepare_chat(self, clear_history, False, reply_at_receive)
+ @property
+ def last_speaker(self) -> Agent:
+ """Return the agent who sent the last message to group chat manager.
+
+ In a group chat, an agent will always send a message to the group chat manager, and the group chat manager will
+ send the message to all other agents in the group chat. So, when an agent receives a message, it will always be
+ from the group chat manager. With this property, the agent receiving the message can know who actually sent the
+ message.
+
+ Example:
+ ```python
+ from autogen import ConversableAgent
+ from autogen import GroupChat, GroupChatManager
+
+
+ def print_messages(recipient, messages, sender, config):
+ # Print the message immediately
+ print(
+ f"Sender: {sender.name} | Recipient: {recipient.name} | Message: {messages[-1].get('content')}"
+ )
+ print(f"Real Sender: {sender.last_speaker.name}")
+ assert sender.last_speaker.name in messages[-1].get("content")
+ return False, None # Required to ensure the agent communication flow continues
+
+
+ agent_a = ConversableAgent("agent A", default_auto_reply="I'm agent A.")
+ agent_b = ConversableAgent("agent B", default_auto_reply="I'm agent B.")
+ agent_c = ConversableAgent("agent C", default_auto_reply="I'm agent C.")
+ for agent in [agent_a, agent_b, agent_c]:
+ agent.register_reply(
+ [ConversableAgent, None], reply_func=print_messages, config=None
+ )
+ group_chat = GroupChat(
+ [agent_a, agent_b, agent_c],
+ messages=[],
+ max_round=6,
+ speaker_selection_method="random",
+ allow_repeat_speaker=True,
+ )
+ chat_manager = GroupChatManager(group_chat)
+ groupchat_result = agent_a.initiate_chat(
+ chat_manager, message="Hi, there, I'm agent A."
+ )
+ ```
+ """
+ return self._last_speaker
+
def run_chat(
self,
messages: Optional[List[Dict]] = None,
@@ -637,6 +1101,7 @@ def run_chat(
speaker = sender
groupchat = config
send_introductions = getattr(groupchat, "send_introductions", False)
+ silent = getattr(self, "_silent", False)
if send_introductions:
# Broadcast the intro
@@ -651,6 +1116,7 @@ def run_chat(
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
+ self._last_speaker = speaker
groupchat.append(message, speaker)
# broadcast the message to all agents except the speaker
for agent in groupchat.agents:
@@ -662,6 +1128,9 @@ def run_chat(
try:
# select the next speaker
speaker = groupchat.select_speaker(speaker, self)
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(colored(f"\nNext speaker: {speaker.name}\n", "green"), flush=True)
# let the speaker speak
reply = speaker.generate_reply(sender=self)
except KeyboardInterrupt:
@@ -691,7 +1160,7 @@ def run_chat(
reply["content"] = self.clear_agents_history(reply, groupchat)
# The speaker sends the message without requesting a reply
- speaker.send(reply, self, request_reply=False)
+ speaker.send(reply, self, request_reply=False, silent=silent)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
@@ -712,6 +1181,7 @@ async def a_run_chat(
speaker = sender
groupchat = config
send_introductions = getattr(groupchat, "send_introductions", False)
+ silent = getattr(self, "_silent", False)
if send_introductions:
# Broadcast the intro
@@ -756,7 +1226,7 @@ async def a_run_chat(
if reply is None:
break
# The speaker sends the message without requesting a reply
- await speaker.a_send(reply, self, request_reply=False)
+ await speaker.a_send(reply, self, request_reply=False, silent=silent)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
@@ -764,6 +1234,303 @@ async def a_run_chat(
a.previous_cache = None
return True, None
+ def resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: Union[str, Callable[[str], str]] = None,
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages (Union[List[Dict], str]): The content of the previous chat's messages, either as a JSON string or a list of message dictionaries.
+ - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination.
+ If a string is provided, this string will be removed from the last message.
+ If a function is provided, the content of the last message will be passed to it, and the returned string will replace that content.
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
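+
+ Example (a minimal sketch; `manager` refers to this GroupChatManager, the original agents and
+ GroupChat are assumed to have been re-created as in the chat being resumed, and `previous_state`
+ is assumed to be a JSON string produced earlier by `messages_to_string`):
+ ```python
+ # Load the previously saved state and prime the agents and group chat with it
+ last_agent, last_message = manager.resume(messages=previous_state)
+
+ # Continue the conversation from where it left off
+ last_agent.initiate_chat(manager, message=last_message, clear_history=False)
+ ```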
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+ self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+ # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True)
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ async def a_resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: Union[str, Callable[[str], str]] = None,
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages (Union[List[Dict], str]): The content of the previous chat's messages, either as a JSON string or a list of message dictionaries.
+ - remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination.
+ If a string is provided, this string will be removed from the last message.
+ If a function is provided, the content of the last message will be passed to it, and the returned string will replace that content.
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+ self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+ # If there's no name, assign the group chat manager (this is an indication the ChatResult messages was used instead of groupchat.messages as state)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ if agent.name != message["name"]:
+ await self.a_send(
+ message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True
+ )
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ def _valid_resume_messages(self, messages: List[Dict]):
+ """Validates the messages used for resuming
+
+ args:
+ messages (List[Dict]): list of messages to resume with
+
+ returns:
+ - bool: Whether they are valid for resuming
+ """
+ # Must have messages to start with, otherwise they should run run_chat
+ if not messages:
+ raise Exception(
+ "Cannot resume group chat as no messages were provided. Use GroupChatManager.run_chat or ConversableAgent.initiate_chat to start a new chat."
+ )
+
+ # Check that all agents in the chat messages exist in the group chat
+ for message in messages:
+ if message.get("name"):
+ if (
+ not self._groupchat.agent_by_name(message["name"])
+ and not message["name"] == self._groupchat.admin_name # ignore group chat's name
+ and not message["name"] == self.name # ignore group chat manager's name
+ ):
+ raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}")
+
+ def _process_resume_termination(
+ self, remove_termination_string: Union[str, Callable[[str], str]], messages: List[Dict]
+ ):
+ """Removes termination string, if required, and checks if termination may occur.
+
+ args:
+ remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
+ If a string is provided, this string will be removed from last message.
+ If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
+
+ returns:
+ None
+ """
+
+ last_message = messages[-1]
+
+ # Replace any given termination string in the last message
+ if isinstance(remove_termination_string, str):
+
+ def _remove_termination_string(content: str) -> str:
+ return content.replace(remove_termination_string, "")
+
+ else:
+ _remove_termination_string = remove_termination_string
+
+ if _remove_termination_string:
+ if messages[-1].get("content"):
+ messages[-1]["content"] = _remove_termination_string(messages[-1]["content"])
+
+ # Check if the last message meets termination (if it has one)
+ if self._is_termination_msg:
+ if self._is_termination_msg(last_message):
+ logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.")
+
+ def messages_from_string(self, message_string: str) -> List[Dict]:
+ """Reads the saved state of messages in Json format for resume and returns as a messages list
+
+ args:
+ - message_string: Json string, the saved state
+
+ returns:
+ - List[Dict]: List of messages
+ """
+ try:
+ state = json.loads(message_string)
+ except json.JSONDecodeError:
+ raise Exception("Messages string is not a valid JSON string")
+
+ return state
+
+ def messages_to_string(self, messages: List[Dict]) -> str:
+ """Converts the provided messages into a Json string that can be used for resuming the chat.
+ The state is made up of a list of messages
+
+ args:
+ - messages (List[Dict]): set of messages to convert to a string
+
+ returns:
+ - str: Json representation of the messages which can be persisted for resuming later
+ """
+
+ return json.dumps(messages)
+
def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
diff --git a/autogen/agentchat/user_proxy_agent.py b/autogen/agentchat/user_proxy_agent.py
index a80296a8355..d50e4d8b89c 100644
--- a/autogen/agentchat/user_proxy_agent.py
+++ b/autogen/agentchat/user_proxy_agent.py
@@ -35,6 +35,7 @@ def __init__(
llm_config: Optional[Union[Dict, Literal[False]]] = False,
system_message: Optional[Union[str, List]] = "",
description: Optional[str] = None,
+ **kwargs,
):
"""
Args:
@@ -79,6 +80,8 @@ def __init__(
Only used when llm_config is not False. Use it to reprogram the agent.
description (str): a short description of the agent. This description is used by other agents
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](conversable_agent#__init__).
"""
super().__init__(
name=name,
@@ -93,6 +96,7 @@ def __init__(
description=(
description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode]
),
+ **kwargs,
)
if logging_enabled():
diff --git a/autogen/agentchat/utils.py b/autogen/agentchat/utils.py
index eef3741605d..b32c2f5f0a0 100644
--- a/autogen/agentchat/utils.py
+++ b/autogen/agentchat/utils.py
@@ -1,5 +1,5 @@
import re
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Union
from .agent import Agent
@@ -26,33 +26,46 @@ def consolidate_chat_info(chat_info, uniform_sender=None) -> None:
), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."
-def gather_usage_summary(agents: List[Agent]) -> Tuple[Dict[str, any], Dict[str, any]]:
+def gather_usage_summary(agents: List[Agent]) -> Dict[str, Dict[str, Any]]:
r"""Gather usage summary from all agents.
Args:
agents: (list): List of agents.
Returns:
- tuple: (total_usage_summary, actual_usage_summary)
+ dictionary: A dictionary containing two keys:
+ - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
+ - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
Example:
```python
- total_usage_summary = {
- "total_cost": 0.0006090000000000001,
- "gpt-35-turbo": {
- "cost": 0.0006090000000000001,
- "prompt_tokens": 242,
- "completion_tokens": 123,
- "total_tokens": 365
+ {
+ "usage_including_cached_inference" : {
+ "total_cost": 0.0006090000000000001,
+ "gpt-35-turbo": {
+ "cost": 0.0006090000000000001,
+ "prompt_tokens": 242,
+ "completion_tokens": 123,
+ "total_tokens": 365
+ },
+ },
+
+ "usage_excluding_cached_inference" : {
+ "total_cost": 0.0006090000000000001,
+ "gpt-35-turbo": {
+ "cost": 0.0006090000000000001,
+ "prompt_tokens": 242,
+ "completion_tokens": 123,
+ "total_tokens": 365
+ },
}
}
```
Note:
- `actual_usage_summary` follows the same format.
- If none of the agents incurred any cost (not having a client), then the total_usage_summary and actual_usage_summary will be `{'total_cost': 0}`.
+ If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`.
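+
+ Usage (a sketch; `assistant` and `user_proxy` are placeholder agents):
+ ```python
+ usage = gather_usage_summary([assistant, user_proxy])
+ total_cost = usage["usage_including_cached_inference"]["total_cost"]
+ ```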
"""
def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, Any]) -> None:
@@ -69,15 +82,18 @@ def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, An
usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0)
usage_summary[model]["total_tokens"] += data.get("total_tokens", 0)
- total_usage_summary = {"total_cost": 0}
- actual_usage_summary = {"total_cost": 0}
+ usage_including_cached_inference = {"total_cost": 0}
+ usage_excluding_cached_inference = {"total_cost": 0}
for agent in agents:
if getattr(agent, "client", None):
- aggregate_summary(total_usage_summary, agent.client.total_usage_summary)
- aggregate_summary(actual_usage_summary, agent.client.actual_usage_summary)
+ aggregate_summary(usage_including_cached_inference, agent.client.total_usage_summary)
+ aggregate_summary(usage_excluding_cached_inference, agent.client.actual_usage_summary)
- return total_usage_summary, actual_usage_summary
+ return {
+ "usage_including_cached_inference": usage_including_cached_inference,
+ "usage_excluding_cached_inference": usage_excluding_cached_inference,
+ }
def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Dict[str, str]]]:
diff --git a/autogen/browser_utils.py b/autogen/browser_utils.py
index c6ccbba38e1..99e51fcd4ca 100644
--- a/autogen/browser_utils.py
+++ b/autogen/browser_utils.py
@@ -36,6 +36,7 @@ def __init__(
start_page: Optional[str] = None,
viewport_size: Optional[int] = 1024 * 8,
downloads_folder: Optional[Union[str, None]] = None,
+ bing_base_url: str = "https://api.bing.microsoft.com/v7.0/search",
bing_api_key: Optional[Union[str, None]] = None,
request_kwargs: Optional[Union[Dict[str, Any], None]] = None,
):
@@ -47,6 +48,7 @@ def __init__(
self.viewport_current_page = 0
self.viewport_pages: List[Tuple[int, int]] = list()
self.set_address(self.start_page)
+ self.bing_base_url = bing_base_url
self.bing_api_key = bing_api_key
self.request_kwargs = request_kwargs
@@ -145,7 +147,7 @@ def _bing_api_call(self, query: str) -> Dict[str, Dict[str, List[Dict[str, Union
request_kwargs["stream"] = False
# Make the request
- response = requests.get("https://api.bing.microsoft.com/v7.0/search", **request_kwargs)
+ response = requests.get(self.bing_base_url, **request_kwargs)
response.raise_for_status()
results = response.json()
diff --git a/autogen/cache/cache.py b/autogen/cache/cache.py
index 0770079f295..6a15d993ff6 100644
--- a/autogen/cache/cache.py
+++ b/autogen/cache/cache.py
@@ -2,7 +2,7 @@
import sys
from types import TracebackType
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, TypedDict, Union
from .abstract_cache_base import AbstractCache
from .cache_factory import CacheFactory
@@ -26,7 +26,12 @@ class Cache(AbstractCache):
cache: The cache instance created based on the provided configuration.
"""
- ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]
+ ALLOWED_CONFIG_KEYS = [
+ "cache_seed",
+ "redis_url",
+ "cache_path_root",
+ "cosmos_db_config",
+ ]
@staticmethod
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache":
@@ -56,6 +61,32 @@ def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> "
"""
return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root})
+ @staticmethod
+ def cosmos_db(
+ connection_string: Optional[str] = None,
+ container_id: Optional[str] = None,
+ cache_seed: Union[str, int] = 42,
+ client: Optional[Any] = None,
+ ) -> "Cache":
+ """
+ Create a Cosmos DB cache instance with 'autogen_cache' as database ID.
+
+ Args:
+ connection_string (str, optional): Connection string to the Cosmos DB account.
+ container_id (str, optional): The container ID for the Cosmos DB account.
+ cache_seed (Union[str, int], optional): A seed for the cache.
+ client (CosmosClient, optional): Pass an existing Cosmos DB client.
+ Returns:
+ Cache: A Cache instance configured for Cosmos DB.
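+
+ Example (a minimal sketch; the connection string and container ID are placeholders, as are the
+ `assistant` and `user_proxy` agents):
+ ```python
+ from autogen import Cache
+
+ with Cache.cosmos_db(
+     connection_string="AccountEndpoint=...;AccountKey=...;",
+     container_id="autogen_cache",
+     cache_seed=42,
+ ) as cache:
+     user_proxy.initiate_chat(assistant, message="What is 2+2?", cache=cache)
+ ```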
+ """
+ cosmos_db_config = {
+ "connection_string": connection_string,
+ "database_id": "autogen_cache",
+ "container_id": container_id,
+ "client": client,
+ }
+ return Cache({"cache_seed": str(cache_seed), "cosmos_db_config": cosmos_db_config})
+
def __init__(self, config: Dict[str, Any]):
"""
Initialize the Cache with the given configuration.
@@ -69,15 +100,19 @@ def __init__(self, config: Dict[str, Any]):
ValueError: If an invalid configuration key is provided.
"""
self.config = config
+ # Ensure that the seed is always treated as a string before being passed to any cache factory or stored.
+ self.config["cache_seed"] = str(self.config.get("cache_seed", 42))
+
# validate config
for key in self.config.keys():
if key not in self.ALLOWED_CONFIG_KEYS:
raise ValueError(f"Invalid config key: {key}")
# create cache instance
self.cache = CacheFactory.cache_factory(
- self.config.get("cache_seed", "42"),
- self.config.get("redis_url", None),
- self.config.get("cache_path_root", None),
+ seed=self.config["cache_seed"],
+ redis_url=self.config.get("redis_url"),
+ cache_path_root=self.config.get("cache_path_root"),
+ cosmosdb_config=self.config.get("cosmos_db_config"),
)
def __enter__(self) -> "Cache":
diff --git a/autogen/cache/cache_factory.py b/autogen/cache/cache_factory.py
index 8fc4713f06e..7c9d71884cb 100644
--- a/autogen/cache/cache_factory.py
+++ b/autogen/cache/cache_factory.py
@@ -1,5 +1,6 @@
import logging
-from typing import Optional, Union
+import os
+from typing import Any, Dict, Optional, Union
from .abstract_cache_base import AbstractCache
from .disk_cache import DiskCache
@@ -8,25 +9,28 @@
class CacheFactory:
@staticmethod
def cache_factory(
- seed: Union[str, int], redis_url: Optional[str] = None, cache_path_root: str = ".cache"
+ seed: Union[str, int],
+ redis_url: Optional[str] = None,
+ cache_path_root: str = ".cache",
+ cosmosdb_config: Optional[Dict[str, Any]] = None,
) -> AbstractCache:
"""
Factory function for creating cache instances.
- Based on the provided redis_url, this function decides whether to create a RedisCache
- or DiskCache instance. If RedisCache is available and redis_url is provided,
- a RedisCache instance is created. Otherwise, a DiskCache instance is used.
+ This function decides whether to create a RedisCache, DiskCache, or CosmosDBCache instance
+ based on the provided parameters. If RedisCache is available and a redis_url is provided,
+ a RedisCache instance is created. If connection_string, database_id, and container_id
+ are provided, a CosmosDBCache is created. Otherwise, a DiskCache instance is used.
Args:
- seed (Union[str, int]): A string or int used as a seed or namespace for the cache.
- This could be useful for creating distinct cache instances
- or for namespacing keys in the cache.
- redis_url (str or None): The URL for the Redis server. If this is None
- or if RedisCache is not available, a DiskCache instance is created.
+ seed (Union[str, int]): Used as a seed or namespace for the cache.
+ redis_url (Optional[str]): URL for the Redis server.
+ cache_path_root (str): Root path for the disk cache.
+ cosmosdb_config (Optional[Dict[str, str]]): Dictionary containing 'connection_string',
+ 'database_id', and 'container_id' for Cosmos DB cache.
Returns:
- An instance of either RedisCache or DiskCache, depending on the availability of RedisCache
- and the provided redis_url.
+ An instance of RedisCache, DiskCache, or CosmosDBCache.
Examples:
@@ -40,14 +44,36 @@ def cache_factory(
```python
disk_cache = cache_factory("myseed", None)
```
+
+ Creating a Cosmos DB cache:
+ ```python
+ cosmos_cache = cache_factory("myseed", cosmosdb_config={
+ "connection_string": "your_connection_string",
+ "database_id": "your_database_id",
+ "container_id": "your_container_id"}
+ )
+ ```
+
"""
- if redis_url is not None:
+ if redis_url:
try:
from .redis_cache import RedisCache
return RedisCache(seed, redis_url)
except ImportError:
- logging.warning("RedisCache is not available. Creating a DiskCache instance instead.")
- return DiskCache(f"./{cache_path_root}/{seed}")
- else:
- return DiskCache(f"./{cache_path_root}/{seed}")
+ logging.warning(
+ "RedisCache is not available. Checking other cache options. The last fallback is DiskCache."
+ )
+
+ if cosmosdb_config:
+ try:
+ from .cosmos_db_cache import CosmosDBCache
+
+ return CosmosDBCache.create_cache(seed, cosmosdb_config)
+
+ except ImportError:
+ logging.warning("CosmosDBCache is not available. Fallback to DiskCache.")
+
+ # Default to DiskCache if neither Redis nor Cosmos DB configurations are provided
+ path = os.path.join(cache_path_root, str(seed))
+ return DiskCache(os.path.join(".", path))
diff --git a/autogen/cache/cosmos_db_cache.py b/autogen/cache/cosmos_db_cache.py
new file mode 100644
index 00000000000..b85be923c2f
--- /dev/null
+++ b/autogen/cache/cosmos_db_cache.py
@@ -0,0 +1,144 @@
+# Requires the Azure Cosmos DB SDK (azure-cosmos) to be installed
+
+import pickle
+from typing import Any, Optional, TypedDict, Union
+
+from azure.cosmos import CosmosClient, PartitionKey, exceptions
+from azure.cosmos.exceptions import CosmosResourceNotFoundError
+
+from autogen.cache.abstract_cache_base import AbstractCache
+
+
+class CosmosDBConfig(TypedDict, total=False):
+ connection_string: str
+ database_id: str
+ container_id: str
+ cache_seed: Optional[Union[str, int]]
+ client: Optional[CosmosClient]
+
+
+class CosmosDBCache(AbstractCache):
+ """
+ Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.
+
+ This class provides a concrete implementation of the AbstractCache
+ interface using Azure Cosmos DB for caching data, with synchronous operations.
+
+ Attributes:
+ seed (Union[str, int]): A seed or namespace used as a partition key.
+ client (CosmosClient): The Cosmos DB client used for caching.
+ container: The container instance used for caching.
+ """
+
+ def __init__(self, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
+ """
+ Initialize the CosmosDBCache instance.
+
+ Args:
+ seed (Union[str, int]): A seed or namespace for the cache, used as a partition key.
+ cosmosdb_config (CosmosDBConfig): The configuration for the Cosmos DB cache, including the
+ connection_string (or an existing client), the database_id, and the container_id.
+ """
+ self.seed = str(seed)
+ self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string(
+ cosmosdb_config["connection_string"]
+ )
+ database_id = cosmosdb_config.get("database_id", "autogen_cache")
+ self.database = self.client.get_database_client(database_id)
+ container_id = cosmosdb_config.get("container_id")
+ self.container = self.database.create_container_if_not_exists(
+ id=container_id, partition_key=PartitionKey(path="/partitionKey")
+ )
+
+ @classmethod
+ def create_cache(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
+ """
+ Factory method to create a CosmosDBCache instance based on the provided configuration.
+ This method decides whether to use an existing CosmosClient or create a new one.
+ """
+ if "client" in cosmosdb_config and isinstance(cosmosdb_config["client"], CosmosClient):
+ return cls.from_existing_client(seed, **cosmosdb_config)
+ else:
+ return cls.from_config(seed, cosmosdb_config)
+
+ @classmethod
+ def from_config(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
+ return cls(str(seed), cosmosdb_config)
+
+ @classmethod
+ def from_connection_string(cls, seed: Union[str, int], connection_string: str, database_id: str, container_id: str):
+ config = {"connection_string": connection_string, "database_id": database_id, "container_id": container_id}
+ return cls(str(seed), config)
+
+ @classmethod
+ def from_existing_client(cls, seed: Union[str, int], client: CosmosClient, database_id: str, container_id: str):
+ config = {"client": client, "database_id": database_id, "container_id": container_id}
+ return cls(str(seed), config)
+
+ def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
+ """
+ Retrieve an item from the Cosmos DB cache.
+
+ Args:
+ key (str): The key identifying the item in the cache.
+ default (optional): The default value to return if the key is not found.
+
+ Returns:
+ The deserialized value associated with the key if found, else the default value.
+ """
+ try:
+ response = self.container.read_item(item=key, partition_key=str(self.seed))
+ return pickle.loads(response["data"])
+ except CosmosResourceNotFoundError:
+ return default
+ except Exception as e:
+ # Unexpected errors are re-raised to the caller; add logging here if needed
+ raise e
+
+ def set(self, key: str, value: Any) -> None:
+ """
+ Set an item in the Cosmos DB cache.
+
+ Args:
+ key (str): The key under which the item is to be stored.
+ value: The value to be stored in the cache.
+
+ Notes:
+ The value is serialized using pickle before being stored.
+ """
+ try:
+ serialized_value = pickle.dumps(value)
+ item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value}
+ self.container.upsert_item(item)
+ except Exception as e:
+ # Log or handle exception
+ raise e
+
+ def close(self) -> None:
+ """
+ Close the Cosmos DB client.
+
+ Perform any necessary cleanup, such as closing network connections.
+ """
+ # CosmosClient doesn"t require explicit close in the current SDK
+ # If you created the client inside this class, you should close it if necessary
+ pass
+
+ def __enter__(self):
+ """
+ Context management entry.
+
+ Returns:
+ self: The instance itself.
+ """
+ return self
+
+ def __exit__(self, exc_type: Optional[type], exc_value: Optional[Exception], traceback: Optional[Any]) -> None:
+ """
+ Context management exit.
+
+ Perform cleanup actions such as closing the Cosmos DB client.
+ """
+ self.close()
diff --git a/autogen/code_utils.py b/autogen/code_utils.py
index aa75756e04a..98ed6067066 100644
--- a/autogen/code_utils.py
+++ b/autogen/code_utils.py
@@ -6,8 +6,10 @@
import subprocess
import sys
import time
+import venv
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from hashlib import md5
+from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import docker
@@ -41,7 +43,7 @@
def content_str(content: Union[str, List[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]) -> str:
- """Converts the `content` field of an OpenAI merssage into a string format.
+ """Converts the `content` field of an OpenAI message into a string format.
This function processes content that may be a string, a list of mixed text and image URLs, or None,
and converts it into a string. Text is directly appended to the result string, while image URLs are
@@ -251,6 +253,8 @@ def _cmd(lang: str) -> str:
return lang
if lang in ["shell"]:
return "sh"
+ if lang == "javascript":
+ return "node"
if lang in ["ps1", "pwsh", "powershell"]:
powershell_command = get_powershell_command()
return powershell_command
@@ -281,7 +285,7 @@ def in_docker_container() -> bool:
return os.path.exists("/.dockerenv")
-def decide_use_docker(use_docker) -> bool:
+def decide_use_docker(use_docker: Optional[bool]) -> Optional[bool]:
if use_docker is None:
env_var_use_docker = os.environ.get("AUTOGEN_USE_DOCKER", "True")
@@ -717,3 +721,19 @@ def implement(
# cost += metrics["gen_cost"]
# if metrics["succeed_assertions"] or i == len(configs) - 1:
# return responses[metrics["index_selected"]], cost, i
+
+
+def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace:
+ """Creates a python virtual environment and returns the context.
+
+ Args:
+ dir_path (str): Directory path where the env will be created.
+ **env_args: Any extra args to pass to the `EnvBuilder`
+
+ Returns:
+ SimpleNamespace: the virtual env context object."""
+ if not env_args:
+ env_args = {"with_pip": True}
+ env_builder = venv.EnvBuilder(**env_args)
+ env_builder.create(dir_path)
+ return env_builder.ensure_directories(dir_path)
diff --git a/autogen/coding/base.py b/autogen/coding/base.py
index ccbfe6b9293..7c9e19d73f3 100644
--- a/autogen/coding/base.py
+++ b/autogen/coding/base.py
@@ -4,7 +4,6 @@
from pydantic import BaseModel, Field
-from ..agentchat.agent import LLMAgent
from ..types import UserMessageImageContentPart, UserMessageTextContentPart
__all__ = ("CodeBlock", "CodeResult", "CodeExtractor", "CodeExecutor", "CodeExecutionConfig")
diff --git a/autogen/coding/docker_commandline_code_executor.py b/autogen/coding/docker_commandline_code_executor.py
index 143b241c2cf..6d8f4e309c8 100644
--- a/autogen/coding/docker_commandline_code_executor.py
+++ b/autogen/coding/docker_commandline_code_executor.py
@@ -8,7 +8,7 @@
from pathlib import Path
from time import sleep
from types import TracebackType
-from typing import Any, List, Optional, Type, Union
+from typing import Any, ClassVar, Dict, List, Optional, Type, Union
import docker
from docker.errors import ImageNotFound
@@ -39,14 +39,30 @@ def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -
class DockerCommandLineCodeExecutor(CodeExecutor):
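+ # Default per-language policy: True means code blocks in that language are executed,
+ # False means they are only written to the working directory. LANGUAGE_ALIASES maps
+ # shorthand tags such as "py" and "js" to their canonical language names.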
+ DEFAULT_EXECUTION_POLICY: ClassVar[Dict[str, bool]] = {
+ "bash": True,
+ "shell": True,
+ "sh": True,
+ "pwsh": True,
+ "powershell": True,
+ "ps1": True,
+ "python": True,
+ "javascript": False,
+ "html": False,
+ "css": False,
+ }
+ LANGUAGE_ALIASES: ClassVar[Dict[str, str]] = {"py": "python", "js": "javascript"}
+
def __init__(
self,
image: str = "python:3-slim",
container_name: Optional[str] = None,
timeout: int = 60,
work_dir: Union[Path, str] = Path("."),
+ bind_dir: Optional[Union[Path, str]] = None,
auto_remove: bool = True,
stop_container: bool = True,
+ execution_policies: Optional[Dict[str, bool]] = None,
):
"""(Experimental) A code executor class that executes code through
a command line environment in a Docker container.
@@ -67,6 +83,9 @@ def __init__(
timeout (int, optional): The timeout for code execution. Defaults to 60.
work_dir (Union[Path, str], optional): The working directory for the code
execution. Defaults to Path(".").
+ bind_dir (Union[Path, str], optional): The directory that will be bound
+ to the code executor container. Useful for cases where you want to spawn
+ the container from within a container. Defaults to work_dir.
auto_remove (bool, optional): If true, will automatically remove the Docker
container when it is stopped. Defaults to True.
stop_container (bool, optional): If true, will automatically stop the
@@ -76,17 +95,19 @@ def __init__(
Raises:
ValueError: On argument error, or if the container fails to start.
"""
-
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
if isinstance(work_dir, str):
work_dir = Path(work_dir)
-
work_dir.mkdir(exist_ok=True)
- client = docker.from_env()
+ if bind_dir is None:
+ bind_dir = work_dir
+ elif isinstance(bind_dir, str):
+ bind_dir = Path(bind_dir)
+ client = docker.from_env()
# Check if the image exists
try:
client.images.get(image)
@@ -105,7 +126,7 @@ def __init__(
entrypoint="/bin/sh",
tty=True,
auto_remove=auto_remove,
- volumes={str(work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}},
+ volumes={str(bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}},
working_dir="/workspace",
)
self._container.start()
@@ -118,7 +139,6 @@ def cleanup() -> None:
container.stop()
except docker.errors.NotFound:
pass
-
atexit.unregister(cleanup)
if stop_container:
@@ -132,6 +152,10 @@ def cleanup() -> None:
self._timeout = timeout
self._work_dir: Path = work_dir
+ self._bind_dir: Path = bind_dir
+ self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
+ if execution_policies is not None:
+ self.execution_policies.update(execution_policies)
@property
def timeout(self) -> int:
@@ -143,6 +167,11 @@ def work_dir(self) -> Path:
"""(Experimental) The working directory for the code execution."""
return self._work_dir
+ @property
+ def bind_dir(self) -> Path:
+ """(Experimental) The binding directory for the code execution container."""
+ return self._bind_dir
+
@property
def code_extractor(self) -> CodeExtractor:
"""(Experimental) Export a code extractor that can be used by an agent."""
@@ -164,35 +193,42 @@ def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeRe
files = []
last_exit_code = 0
for code_block in code_blocks:
- lang = code_block.language
+ lang = self.LANGUAGE_ALIASES.get(code_block.language.lower(), code_block.language.lower())
+ if lang not in self.DEFAULT_EXECUTION_POLICY:
+ outputs.append(f"Unsupported language {lang}\n")
+ last_exit_code = 1
+ break
+
+ execute_code = self.execution_policies.get(lang, False)
code = silence_pip(code_block.code, lang)
+ # Check if there is a filename comment
try:
- # Check if there is a filename comment
- filename = _get_file_name_from_content(code, Path("/workspace"))
+ filename = _get_file_name_from_content(code, self._work_dir)
except ValueError:
- return CommandLineCodeResult(exit_code=1, output="Filename is not in the workspace")
+ outputs.append("Filename is not in the workspace")
+ last_exit_code = 1
+ break
- if filename is None:
- # create a file with an automatically generated name
- code_hash = md5(code.encode()).hexdigest()
- filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
+ if not filename:
+ filename = f"tmp_code_{md5(code.encode()).hexdigest()}.{lang}"
code_path = self._work_dir / filename
with code_path.open("w", encoding="utf-8") as fout:
fout.write(code)
+ files.append(code_path)
- command = ["timeout", str(self._timeout), _cmd(lang), filename]
+ if not execute_code:
+ outputs.append(f"Code saved to {str(code_path)}\n")
+ continue
+ command = ["timeout", str(self._timeout), _cmd(lang), filename]
result = self._container.exec_run(command)
exit_code = result.exit_code
output = result.output.decode("utf-8")
if exit_code == 124:
- output += "\n"
- output += TIMEOUT_MSG
-
+ output += "\n" + TIMEOUT_MSG
outputs.append(output)
- files.append(code_path)
last_exit_code = exit_code
if exit_code != 0:
diff --git a/autogen/coding/func_with_reqs.py b/autogen/coding/func_with_reqs.py
index 6f199573822..f255f1df017 100644
--- a/autogen/coding/func_with_reqs.py
+++ b/autogen/coding/func_with_reqs.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from importlib.abc import SourceLoader
from textwrap import dedent, indent
-from typing import Any, Callable, Generic, List, TypeVar, Union
+from typing import Any, Callable, Generic, List, Set, TypeVar, Union
from typing_extensions import ParamSpec
@@ -159,12 +159,12 @@ def _build_python_functions_file(
funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]]
) -> str:
# First collect all global imports
- global_imports = set()
+ global_imports: Set[str] = set()
for func in funcs:
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
- global_imports.update(func.global_imports)
+ global_imports.update(map(_import_to_str, func.global_imports))
- content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
+ content = "\n".join(global_imports) + "\n\n"
for func in funcs:
content += _to_code(func) + "\n\n"
diff --git a/autogen/coding/jupyter/base.py b/autogen/coding/jupyter/base.py
index d896b6ac3cc..0e7acaf1e87 100644
--- a/autogen/coding/jupyter/base.py
+++ b/autogen/coding/jupyter/base.py
@@ -10,9 +10,9 @@ class JupyterConnectionInfo:
"""`str` - Host of the Jupyter gateway server"""
use_https: bool
"""`bool` - Whether to use HTTPS"""
- port: int
- """`int` - Port of the Jupyter gateway server"""
- token: Optional[str]
+ port: Optional[int] = None
+ """`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used"""
+ token: Optional[str] = None
"""`Optional[str]` - Token for authentication. If None, no token is used"""
diff --git a/autogen/coding/jupyter/jupyter_client.py b/autogen/coding/jupyter/jupyter_client.py
index 44aafd8f5b0..b3de374fce9 100644
--- a/autogen/coding/jupyter/jupyter_client.py
+++ b/autogen/coding/jupyter/jupyter_client.py
@@ -41,10 +41,12 @@ def _get_headers(self) -> Dict[str, str]:
def _get_api_base_url(self) -> str:
protocol = "https" if self._connection_info.use_https else "http"
- return f"{protocol}://{self._connection_info.host}:{self._connection_info.port}"
+ port = f":{self._connection_info.port}" if self._connection_info.port else ""
+ return f"{protocol}://{self._connection_info.host}{port}"
def _get_ws_base_url(self) -> str:
- return f"ws://{self._connection_info.host}:{self._connection_info.port}"
+ port = f":{self._connection_info.port}" if self._connection_info.port else ""
+ return f"ws://{self._connection_info.host}{port}"
def list_kernel_specs(self) -> Dict[str, Dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
diff --git a/autogen/coding/local_commandline_code_executor.py b/autogen/coding/local_commandline_code_executor.py
index 68ef76b7e7f..620b359a4ae 100644
--- a/autogen/coding/local_commandline_code_executor.py
+++ b/autogen/coding/local_commandline_code_executor.py
@@ -1,4 +1,5 @@
import logging
+import os
import re
import subprocess
import sys
@@ -6,7 +7,8 @@
from hashlib import md5
from pathlib import Path
from string import Template
-from typing import Any, Callable, ClassVar, List, TypeVar, Union, cast
+from types import SimpleNamespace
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
from typing_extensions import ParamSpec
@@ -28,7 +30,31 @@
class LocalCommandLineCodeExecutor(CodeExecutor):
- SUPPORTED_LANGUAGES: ClassVar[List[str]] = ["bash", "shell", "sh", "pwsh", "powershell", "ps1", "python"]
+ SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
+ "bash",
+ "shell",
+ "sh",
+ "pwsh",
+ "powershell",
+ "ps1",
+ "python",
+ "javascript",
+ "html",
+ "css",
+ ]
+ DEFAULT_EXECUTION_POLICY: ClassVar[Dict[str, bool]] = {
+ "bash": True,
+ "shell": True,
+ "sh": True,
+ "pwsh": True,
+ "powershell": True,
+ "ps1": True,
+ "python": True,
+ "javascript": False,
+ "html": False,
+ "css": False,
+ }
+
FUNCTION_PROMPT_TEMPLATE: ClassVar[
str
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
@@ -40,32 +66,45 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
def __init__(
self,
timeout: int = 60,
+ virtual_env_context: Optional[SimpleNamespace] = None,
work_dir: Union[Path, str] = Path("."),
functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [],
functions_module: str = "functions",
+ execution_policies: Optional[Dict[str, bool]] = None,
):
- """(Experimental) A code executor class that executes code through a local command line
+ """(Experimental) A code executor class that executes or saves LLM generated code a local command line
environment.
- **This will execute LLM generated code on the local machine.**
+ **This will execute or save LLM generated code on the local machine.**
+
+ Each code block is saved as a file in the working directory. Depending on the execution policy,
+ the code may be executed in a separate process.
+ The code blocks are executed or saved in the order they are received.
+ Command line code is sanitized against a list of dangerous commands to prevent self-destructive commands from being executed,
+ which could potentially affect the user's environment. Supported languages include Python, shell scripts (bash, shell, sh),
+ PowerShell (pwsh, powershell, ps1), HTML, CSS, and JavaScript.
+ Execution policies determine whether each language's code blocks are executed or saved only.
+
+ ## Execution with a Python virtual environment
+ A Python virtual env can be used to execute code and install dependencies. This has the added benefit of not polluting the
+ base environment with unwanted modules.
+ ```python
+ from autogen.code_utils import create_virtual_env
+ from autogen.coding import LocalCommandLineCodeExecutor
- Each code block is saved as a file and executed in a separate process in
- the working directory, and a unique file is generated and saved in the
- working directory for each code block.
- The code blocks are executed in the order they are received.
- Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive
- commands from being executed which may potentially affect the users environment.
- Currently the only supported languages is Python and shell scripts.
- For Python code, use the language "python" for the code block.
- For shell scripts, use the language "bash", "shell", or "sh" for the code
- block.
+ venv_dir = ".venv"
+ venv_context = create_virtual_env(venv_dir)
+
+ executor = LocalCommandLineCodeExecutor(virtual_env_context=venv_context)
+ ```
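+
+ ## Saving code without executing it
+ Per-language execution policies control whether code blocks are run or only written to the working
+ directory (a configuration sketch; valid keys are the languages listed in DEFAULT_EXECUTION_POLICY):
+ ```python
+ executor = LocalCommandLineCodeExecutor(
+     work_dir="coding",
+     execution_policies={"bash": False},  # bash blocks are saved to disk but not executed
+ )
+ ```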
Args:
- timeout (int): The timeout for code execution. Default is 60.
- work_dir (str): The working directory for the code execution. If None,
- a default working directory will be used. The default working
- directory is the current directory ".".
- functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
+ timeout (int): The timeout for code execution, default is 60 seconds.
+ virtual_env_context (Optional[SimpleNamespace]): The virtual environment context to use.
+ work_dir (Union[Path, str]): The working directory for code execution, defaults to the current directory.
+ functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]): A list of callable functions available to the executor.
+ functions_module (str): The module name under which functions are accessible.
+ execution_policies (Optional[Dict[str, bool]]): A dictionary mapping languages to execution policies (True for execution, False for saving only). Defaults to class-wide DEFAULT_EXECUTION_POLICY.
"""
if timeout < 1:
@@ -83,6 +122,7 @@ def __init__(
self._timeout = timeout
self._work_dir: Path = work_dir
+ self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context
self._functions = functions
# Setup could take some time so we intentionally wait for the first code block to do it.
@@ -91,6 +131,10 @@ def __init__(
else:
self._setup_functions_complete = True
+ self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
+ if execution_policies is not None:
+ self.execution_policies.update(execution_policies)
+
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
"""(Experimental) Format the functions for a prompt.
@@ -104,7 +148,6 @@ def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEM
Returns:
str: The formatted prompt.
"""
-
template = Template(prompt_template)
return template.substitute(
module_name=self._functions_module,
@@ -171,26 +214,23 @@ def _setup_functions(self) -> None:
required_packages = list(set(flattened_packages))
if len(required_packages) > 0:
logging.info("Ensuring packages are installed in executor.")
-
- cmd = [sys.executable, "-m", "pip", "install"]
- cmd.extend(required_packages)
-
+ if self._virtual_env_context:
+ py_executable = self._virtual_env_context.env_exe
+ else:
+ py_executable = sys.executable
+ cmd = [py_executable, "-m", "pip", "install"] + required_packages
try:
result = subprocess.run(
cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
)
except subprocess.TimeoutExpired as e:
raise ValueError("Pip install timed out") from e
-
if result.returncode != 0:
raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")
-
# Attempt to load the function file to check for syntax errors, imports etc.
exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])
-
if exec_result.exit_code != 0:
raise ValueError(f"Functions failed to load: {exec_result.output}")
-
self._setup_functions_complete = True
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
@@ -201,10 +241,8 @@ def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeRe
Returns:
CommandLineCodeResult: The result of the code execution."""
-
if not self._setup_functions_complete:
self._setup_functions()
-
return self._execute_code_dont_check_setup(code_blocks)
def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
@@ -229,6 +267,7 @@ def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> Comman
logs_all += "\n" + f"unknown language {lang}"
break
+ execute_code = self.execution_policies.get(lang, False)
try:
# Check if there is a filename comment
filename = _get_file_name_from_content(code, self._work_dir)
@@ -239,18 +278,32 @@ def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> Comman
# create a file with an automatically generated name
code_hash = md5(code.encode()).hexdigest()
filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
-
written_file = (self._work_dir / filename).resolve()
with written_file.open("w", encoding="utf-8") as f:
f.write(code)
file_names.append(written_file)
- program = sys.executable if lang.startswith("python") else _cmd(lang)
+ if not execute_code:
+ # Just return a message that the file is saved.
+ logs_all += f"Code saved to {str(written_file)}\n"
+ exitcode = 0
+ continue
+
+ program = _cmd(lang)
cmd = [program, str(written_file.absolute())]
+ env = os.environ.copy()
+
+ if self._virtual_env_context:
+ virtual_env_abs_path = os.path.abspath(self._virtual_env_context.bin_path)
+ path_with_virtualenv = rf"{virtual_env_abs_path}{os.pathsep}{env['PATH']}"
+ env["PATH"] = path_with_virtualenv
+ if WIN32:
+ activation_script = os.path.join(virtual_env_abs_path, "activate.bat")
+ cmd = [activation_script, "&&", *cmd]
try:
result = subprocess.run(
- cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
+ cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout), env=env
)
except subprocess.TimeoutExpired:
logs_all += "\n" + TIMEOUT_MSG
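Taken together, the new executor options above can be exercised roughly as follows. This is a sketch, not part of the diff: the virtual-environment context built with the standard-library `venv.EnvBuilder` exposes the `env_exe` and `bin_path` attributes the executor reads, and the `coding` working directory is an arbitrary choice.

```
import venv
from pathlib import Path

from autogen.coding import CodeBlock, LocalCommandLineCodeExecutor

# Build a virtual environment; ensure_directories() returns a context object
# with env_exe and bin_path, which is what virtual_env_context expects.
builder = venv.EnvBuilder(with_pip=True)
builder.create(".venv-exec")
venv_context = builder.ensure_directories(".venv-exec")

Path("coding").mkdir(exist_ok=True)  # the executor works inside an existing directory

executor = LocalCommandLineCodeExecutor(
    timeout=60,
    work_dir="coding",
    virtual_env_context=venv_context,
    # Execute python blocks, but only save bash blocks to disk.
    execution_policies={"python": True, "bash": False},
)
result = executor.execute_code_blocks(
    [CodeBlock(language="python", code="print('hello from the venv')")]
)
print(result.exit_code, result.output)
```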
diff --git a/autogen/coding/utils.py b/autogen/coding/utils.py
index 0a7c5a7785d..d692bfe35b9 100644
--- a/autogen/coding/utils.py
+++ b/autogen/coding/utils.py
@@ -3,23 +3,31 @@
from pathlib import Path
from typing import Optional
+filename_patterns = [
+ re.compile(r"^", re.DOTALL),
+ re.compile(r"^/\* (filename:)?(.+?) \*/", re.DOTALL),
+ re.compile(r"^// (filename:)?(.+?)$", re.DOTALL),
+ re.compile(r"^# (filename:)?(.+?)$", re.DOTALL),
+]
+
# Raises ValueError if the file is not in the workspace
def _get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
- first_line = code.split("\n")[0]
+ first_line = code.split("\n")[0].strip()
# TODO - support other languages
- if first_line.startswith("# filename:"):
- filename = first_line.split(":")[1].strip()
-
- # Handle relative paths in the filename
- path = Path(filename)
- if not path.is_absolute():
- path = workspace_path / path
- path = path.resolve()
- # Throws an error if the file is not in the workspace
- relative = path.relative_to(workspace_path.resolve())
- return str(relative)
+ for pattern in filename_patterns:
+ matches = pattern.match(first_line)
+ if matches is not None:
+ filename = matches.group(2).strip()
+ # Handle relative paths in the filename
+ path = Path(filename)
+ if not path.is_absolute():
+ path = workspace_path / path
+ path = path.resolve()
+ # Throws an error if the file is not in the workspace
+ relative = path.relative_to(workspace_path.resolve())
+ return str(relative)
return None
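For reference, the filename-comment styles the updated helper now recognizes can be checked like this. A sketch only: `_get_file_name_from_content` is a private helper and the sample snippets are hypothetical.

```
from pathlib import Path

from autogen.coding.utils import _get_file_name_from_content

workspace = Path.cwd()
samples = [
    "# filename: app.py\nprint('hi')",
    "// filename: main.js\nconsole.log('hi')",
    "/* filename: style.css */\nbody {}",
    "<!-- filename: index.html -->\n<html></html>",
]
for code in samples:
    # Returns the path relative to the workspace, e.g. "app.py".
    print(_get_file_name_from_content(code, workspace))
```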
diff --git a/autogen/function_utils.py b/autogen/function_utils.py
index dd225fd4719..6b9b6f5b129 100644
--- a/autogen/function_utils.py
+++ b/autogen/function_utils.py
@@ -353,4 +353,4 @@ def serialize_to_str(x: Any) -> str:
elif isinstance(x, BaseModel):
return model_dump_json(x)
else:
- return json.dumps(x)
+ return json.dumps(x, ensure_ascii=False)
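The `ensure_ascii=False` change matters for non-ASCII payloads: the serialized string keeps the original characters instead of escape sequences, e.g.:

```
import json

print(json.dumps("こんにちは"))                      # "\u3053\u3093\u306b\u3061\u306f"
print(json.dumps("こんにちは", ensure_ascii=False))  # "こんにちは"
```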
diff --git a/autogen/graph_utils.py b/autogen/graph_utils.py
index 88c218fde5e..d36b47a12ed 100644
--- a/autogen/graph_utils.py
+++ b/autogen/graph_utils.py
@@ -1,5 +1,5 @@
import logging
-from typing import Dict, List
+from typing import Dict, List, Optional
from autogen.agentchat import Agent
@@ -110,7 +110,9 @@ def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agen
return allowed_speaker_transitions_dict
-def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: List[Agent]):
+def visualize_speaker_transitions_dict(
+ speaker_transitions_dict: dict, agents: List[Agent], export_path: Optional[str] = None
+):
"""
Visualize the speaker_transitions_dict using networkx.
"""
@@ -133,4 +135,8 @@ def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: L
# Visualize
nx.draw(G, with_labels=True, font_weight="bold")
- plt.show()
+
+ if export_path is not None:
+ plt.savefig(export_path)
+ else:
+ plt.show()
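The new `export_path` argument lets callers save the graph instead of opening a window. A sketch assuming three throwaway agents; networkx and matplotlib must be installed.

```
from autogen import ConversableAgent
from autogen.graph_utils import visualize_speaker_transitions_dict

agents = [
    ConversableAgent(name, llm_config=False, human_input_mode="NEVER")
    for name in ("alice", "bob", "carol")
]
# Fully connected transitions: every agent may hand over to every other agent.
transitions = {a: [b for b in agents if b is not a] for a in agents}

# With export_path set, the figure is written via plt.savefig() instead of plt.show().
visualize_speaker_transitions_dict(transitions, agents, export_path="speaker_transitions.png")
```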
diff --git a/autogen/logger/__init__.py b/autogen/logger/__init__.py
index 6561cab4360..c30711940c9 100644
--- a/autogen/logger/__init__.py
+++ b/autogen/logger/__init__.py
@@ -1,4 +1,5 @@
+from .file_logger import FileLogger
from .logger_factory import LoggerFactory
from .sqlite_logger import SqliteLogger
-__all__ = ("LoggerFactory", "SqliteLogger")
+__all__ = ("LoggerFactory", "SqliteLogger", "FileLogger")
diff --git a/autogen/logger/base_logger.py b/autogen/logger/base_logger.py
index 24e19c475c5..c5c236fa4ae 100644
--- a/autogen/logger/base_logger.py
+++ b/autogen/logger/base_logger.py
@@ -3,14 +3,15 @@
import sqlite3
import uuid
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeVar, Union
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
if TYPE_CHECKING:
- from autogen import ConversableAgent, OpenAIWrapper
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
+F = TypeVar("F", bound=Callable[..., Any])
ConfigItem = Dict[str, Union[str, List[str]]]
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
@@ -32,6 +33,7 @@ def log_chat_completion(
invocation_id: uuid.UUID,
client_id: int,
wrapper_id: int,
+ source: Union[str, Agent],
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion],
is_cached: int,
@@ -49,9 +51,10 @@ def log_chat_completion(
invocation_id (uuid): A unique identifier for the invocation to the OpenAIWrapper.create method call
client_id (int): A unique identifier for the underlying OpenAI client instance
wrapper_id (int): A unique identifier for the OpenAIWrapper instance
- request (dict): A dictionary representing the the request or call to the OpenAI client endpoint
+ source (str or Agent): The source/creator of the event as a string name or an Agent instance
+ request (dict): A dictionary representing the request or call to the OpenAI client endpoint
response (str or ChatCompletion): The response from OpenAI
- is_chached (int): 1 if the response was a cache hit, 0 otherwise
+ is_cached (int): 1 if the response was a cache hit, 0 otherwise
cost(float): The cost for OpenAI response
start_time (str): A string representing the moment the request was initiated
"""
@@ -68,6 +71,18 @@ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> N
"""
...
+ @abstractmethod
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ """
+ Log an event for an agent.
+
+ Args:
+ source (str or Agent): The source/creator of the event as a string name or an Agent instance
+ name (str): The name of the event
+ kwargs (dict): The event information to log
+ """
+ ...
+
@abstractmethod
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
"""
@@ -92,6 +107,18 @@ def log_new_client(
"""
...
+ @abstractmethod
+ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
+ """
+ Log the use of a registered function (could be a tool)
+
+ Args:
+ source (str or Agent): The source/creator of the event as a string name or an Agent instance
+ function (F): The function information
+ args (dict): The function args to log
+ returns (any): The return
+ """
+
@abstractmethod
def stop(self) -> None:
"""
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
new file mode 100644
index 00000000000..37bbbd25a52
--- /dev/null
+++ b/autogen/logger/file_logger.py
@@ -0,0 +1,277 @@
+from __future__ import annotations
+
+import json
+import logging
+import os
+import threading
+import uuid
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
+
+from openai import AzureOpenAI, OpenAI
+from openai.types.chat import ChatCompletion
+
+from autogen.logger.base_logger import BaseLogger
+from autogen.logger.logger_utils import get_current_ts, to_dict
+
+from .base_logger import LLMConfig
+
+if TYPE_CHECKING:
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
+ from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.bedrock import BedrockClient
+ from autogen.oai.cohere import CohereClient
+ from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
+ from autogen.oai.mistral import MistralAIClient
+ from autogen.oai.together import TogetherClient
+
+logger = logging.getLogger(__name__)
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+__all__ = ("FileLogger",)
+
+
+def safe_serialize(obj: Any) -> str:
+ def default(o: Any) -> str:
+ if hasattr(o, "to_json"):
+ return str(o.to_json())
+ else:
+ return f"<>"
+
+ return json.dumps(obj, default=default)
+
+
+class FileLogger(BaseLogger):
+ def __init__(self, config: Dict[str, Any]):
+ self.config = config
+ self.session_id = str(uuid.uuid4())
+
+ curr_dir = os.getcwd()
+ self.log_dir = os.path.join(curr_dir, "autogen_logs")
+ os.makedirs(self.log_dir, exist_ok=True)
+
+ self.log_file = os.path.join(self.log_dir, self.config.get("filename", "runtime.log"))
+ try:
+ with open(self.log_file, "a"):
+ pass
+ except Exception as e:
+ logger.error(f"[file_logger] Failed to create logging file: {e}")
+
+ self.logger = logging.getLogger(__name__)
+ self.logger.setLevel(logging.INFO)
+ file_handler = logging.FileHandler(self.log_file)
+ self.logger.addHandler(file_handler)
+
+ def start(self) -> str:
+ """Start the logger and return the session_id."""
+ try:
+ self.logger.info(f"Started new session with Session ID: {self.session_id}")
+ except Exception as e:
+ logger.error(f"[file_logger] Failed to create logging file: {e}")
+ finally:
+ return self.session_id
+
+ def log_chat_completion(
+ self,
+ invocation_id: uuid.UUID,
+ client_id: int,
+ wrapper_id: int,
+ source: Union[str, Agent],
+ request: Dict[str, Union[float, str, List[Dict[str, str]]]],
+ response: Union[str, ChatCompletion],
+ is_cached: int,
+ cost: float,
+ start_time: str,
+ ) -> None:
+ """
+ Log a chat completion.
+ """
+ thread_id = threading.get_ident()
+ source_name = None
+ if isinstance(source, str):
+ source_name = source
+ else:
+ source_name = source.name
+ try:
+ log_data = json.dumps(
+ {
+ "invocation_id": str(invocation_id),
+ "client_id": client_id,
+ "wrapper_id": wrapper_id,
+ "request": to_dict(request),
+ "response": str(response),
+ "is_cached": is_cached,
+ "cost": cost,
+ "start_time": start_time,
+ "end_time": get_current_ts(),
+ "thread_id": thread_id,
+ "source_name": source_name,
+ }
+ )
+
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log chat completion: {e}")
+
+ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any] = {}) -> None:
+ """
+ Log a new agent instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "id": id(agent),
+ "agent_name": agent.name if hasattr(agent, "name") and agent.name is not None else "",
+ "wrapper_id": to_dict(
+ agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else ""
+ ),
+ "session_id": self.session_id,
+ "current_time": get_current_ts(),
+ "agent_type": type(agent).__name__,
+ "args": to_dict(init_args),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log new agent: {e}")
+
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ """
+ Log an event from an agent or a string source.
+ """
+ from autogen import Agent
+
+ # This takes an object o as input and returns a string. If the object o cannot be serialized, instead of raising an error,
+ # it returns a string indicating that the object is non-serializable, along with its type's qualified name obtained using __qualname__.
+ json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
+ thread_id = threading.get_ident()
+
+ if isinstance(source, Agent):
+ try:
+ log_data = json.dumps(
+ {
+ "source_id": id(source),
+ "source_name": str(source.name) if hasattr(source, "name") else source,
+ "event_name": name,
+ "agent_module": source.__module__,
+ "agent_class": source.__class__.__name__,
+ "json_state": json_args,
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+ else:
+ try:
+ log_data = json.dumps(
+ {
+ "source_id": id(source),
+ "source_name": str(source.name) if hasattr(source, "name") else source,
+ "event_name": name,
+ "json_state": json_args,
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def log_new_wrapper(
+ self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]] = {}
+ ) -> None:
+ """
+ Log a new wrapper instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "wrapper_id": id(wrapper),
+ "session_id": self.session_id,
+ "json_state": json.dumps(init_args),
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def log_new_client(
+ self,
+ client: (
+ AzureOpenAI
+ | OpenAI
+ | GeminiClient
+ | AnthropicClient
+ | MistralAIClient
+ | TogetherClient
+ | GroqClient
+ | CohereClient
+ | BedrockClient
+ ),
+ wrapper: OpenAIWrapper,
+ init_args: Dict[str, Any],
+ ) -> None:
+ """
+ Log a new client instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "client_id": id(client),
+ "wrapper_id": id(wrapper),
+ "session_id": self.session_id,
+ "class": type(client).__name__,
+ "json_state": json.dumps(init_args),
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
+ """
+ Log a registered function (can be a tool) use from an agent or a string source.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "source_id": id(source),
+ "source_name": str(source.name) if hasattr(source, "name") else source,
+ "agent_module": source.__module__,
+ "agent_class": source.__class__.__name__,
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ "input_args": safe_serialize(args),
+ "returns": safe_serialize(returns),
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def get_connection(self) -> None:
+ """Method is intentionally left blank because there is no specific connection needed for the FileLogger."""
+ pass
+
+ def stop(self) -> None:
+ """Close the file handler and remove it from the logger."""
+ for handler in self.logger.handlers:
+ if isinstance(handler, logging.FileHandler):
+ handler.close()
+ self.logger.removeHandler(handler)
diff --git a/autogen/logger/logger_factory.py b/autogen/logger/logger_factory.py
index 8073c0c07d3..ed9567977bb 100644
--- a/autogen/logger/logger_factory.py
+++ b/autogen/logger/logger_factory.py
@@ -1,6 +1,7 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional
from autogen.logger.base_logger import BaseLogger
+from autogen.logger.file_logger import FileLogger
from autogen.logger.sqlite_logger import SqliteLogger
__all__ = ("LoggerFactory",)
@@ -8,11 +9,15 @@
class LoggerFactory:
@staticmethod
- def get_logger(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> BaseLogger:
+ def get_logger(
+ logger_type: Literal["sqlite", "file"] = "sqlite", config: Optional[Dict[str, Any]] = None
+ ) -> BaseLogger:
if config is None:
config = {}
if logger_type == "sqlite":
return SqliteLogger(config)
+ elif logger_type == "file":
+ return FileLogger(config)
else:
raise ValueError(f"[logger_factory] Unknown logger type: {logger_type}")
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 62f758c51eb..f76d039ce9d 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -6,7 +6,7 @@
import sqlite3
import threading
import uuid
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar, Union
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
@@ -17,13 +17,32 @@
from .base_logger import LLMConfig
if TYPE_CHECKING:
- from autogen import ConversableAgent, OpenAIWrapper
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
+ from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.bedrock import BedrockClient
+ from autogen.oai.cohere import CohereClient
+ from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
+ from autogen.oai.mistral import MistralAIClient
+ from autogen.oai.together import TogetherClient
logger = logging.getLogger(__name__)
lock = threading.Lock()
__all__ = ("SqliteLogger",)
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def safe_serialize(obj: Any) -> str:
+ def default(o: Any) -> str:
+ if hasattr(o, "to_json"):
+ return str(o.to_json())
+ else:
+ return f"<>"
+
+ return json.dumps(obj, default=default)
+
class SqliteLogger(BaseLogger):
schema_version = 1
@@ -48,6 +67,7 @@ def start(self) -> str:
client_id INTEGER,
wrapper_id INTEGER,
session_id TEXT,
+ source_name TEXT,
request TEXT,
response TEXT,
is_cached INEGER,
@@ -103,6 +123,32 @@ class TEXT, -- type or class name of cli
"""
self._run_query(query=query)
+ query = """
+ CREATE TABLE IF NOT EXISTS events (
+ event_name TEXT,
+ source_id INTEGER,
+ source_name TEXT,
+ agent_module TEXT DEFAULT NULL,
+ agent_class_name TEXT DEFAULT NULL,
+ id INTEGER PRIMARY KEY,
+ json_state TEXT,
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
+ );
+ """
+ self._run_query(query=query)
+
+ query = """
+ CREATE TABLE IF NOT EXISTS function_calls (
+ source_id INTEGER,
+ source_name TEXT,
+ function_name TEXT,
+ args TEXT DEFAULT NULL,
+ returns TEXT DEFAULT NULL,
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
+ );
+ """
+ self._run_query(query=query)
+
current_verion = self._get_current_db_version()
if current_verion is None:
self._run_query(
@@ -177,6 +223,7 @@ def log_chat_completion(
invocation_id: uuid.UUID,
client_id: int,
wrapper_id: int,
+ source: Union[str, Agent],
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion],
is_cached: int,
@@ -193,10 +240,16 @@ def log_chat_completion(
else:
response_messages = json.dumps(to_dict(response), indent=4)
+ source_name = None
+ if isinstance(source, str):
+ source_name = source
+ else:
+ source_name = source.name
+
query = """
INSERT INTO chat_completions (
- invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time, source_name
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
args = (
invocation_id,
@@ -209,6 +262,7 @@ def log_chat_completion(
cost,
start_time,
end_time,
+ source_name,
)
self._run_query(query=query, args=args)
@@ -221,7 +275,16 @@ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> N
args = to_dict(
init_args,
- exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
+ exclude=(
+ "self",
+ "__class__",
+ "api_key",
+ "organization",
+ "base_url",
+ "azure_endpoint",
+ "azure_ad_token",
+ "azure_ad_token_provider",
+ ),
no_recursive=(Agent,),
)
@@ -246,12 +309,57 @@ class = excluded.class,
)
self._run_query(query=query, args=args)
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ from autogen import Agent
+
+ if self.con is None:
+ return
+
+ json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
+
+ if isinstance(source, Agent):
+ query = """
+ INSERT INTO events (source_id, source_name, event_name, agent_module, agent_class_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)
+ """
+ args = (
+ id(source),
+ source.name if hasattr(source, "name") else source,
+ name,
+ source.__module__,
+ source.__class__.__name__,
+ json_args,
+ get_current_ts(),
+ )
+ self._run_query(query=query, args=args)
+ else:
+ query = """
+ INSERT INTO events (source_id, source_name, event_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?)
+ """
+ args_str_based = (
+ id(source),
+ source.name if hasattr(source, "name") else source,
+ name,
+ json_args,
+ get_current_ts(),
+ )
+ self._run_query(query=query, args=args_str_based)
+
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if self.con is None:
return
args = to_dict(
- init_args, exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint")
+ init_args,
+ exclude=(
+ "self",
+ "__class__",
+ "api_key",
+ "organization",
+ "base_url",
+ "azure_endpoint",
+ "azure_ad_token",
+ "azure_ad_token_provider",
+ ),
)
query = """
@@ -266,14 +374,55 @@ def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLM
)
self._run_query(query=query, args=args)
+ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
+
+ if self.con is None:
+ return
+
+ query = """
+ INSERT INTO function_calls (source_id, source_name, function_name, args, returns, timestamp) VALUES (?, ?, ?, ?, ?, ?)
+ """
+ query_args: Tuple[Any, ...] = (
+ id(source),
+ source.name if hasattr(source, "name") else source,
+ function.__name__,
+ safe_serialize(args),
+ safe_serialize(returns),
+ get_current_ts(),
+ )
+ self._run_query(query=query, args=query_args)
+
def log_new_client(
- self, client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
+ self,
+ client: Union[
+ AzureOpenAI,
+ OpenAI,
+ GeminiClient,
+ AnthropicClient,
+ MistralAIClient,
+ TogetherClient,
+ GroqClient,
+ CohereClient,
+ BedrockClient,
+ ],
+ wrapper: OpenAIWrapper,
+ init_args: Dict[str, Any],
) -> None:
if self.con is None:
return
args = to_dict(
- init_args, exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint")
+ init_args,
+ exclude=(
+ "self",
+ "__class__",
+ "api_key",
+ "organization",
+ "base_url",
+ "azure_endpoint",
+ "azure_ad_token",
+ "azure_ad_token_provider",
+ ),
)
query = """
diff --git a/autogen/oai/anthropic.py b/autogen/oai/anthropic.py
new file mode 100644
index 00000000000..8ed6f909e6b
--- /dev/null
+++ b/autogen/oai/anthropic.py
@@ -0,0 +1,422 @@
+"""
+Create an OpenAI-compatible client for the Anthropic API.
+
+Example usage:
+Install the `anthropic` package by running `pip install --upgrade anthropic`.
+- https://docs.anthropic.com/en/docs/quickstart-guide
+
+import autogen
+
+config_list = [
+ {
+ "model": "claude-3-sonnet-20240229",
+ "api_key": os.getenv("ANTHROPIC_API_KEY"),
+ "api_type": "anthropic",
+ }
+]
+
+assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
+Example usage for Anthropic Bedrock:
+
+Install the `anthropic` package by running `pip install --upgrade anthropic`.
+- https://docs.anthropic.com/en/docs/quickstart-guide
+
+import autogen
+
+config_list = [
+ {
+ "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "aws_access_key":,
+ "aws_secret_key":,
+ "aws_session_token":,
+ "aws_region":"us-east-1",
+ "api_type": "anthropic",
+ }
+]
+
+assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
+"""
+
+from __future__ import annotations
+
+import copy
+import inspect
+import json
+import os
+import time
+import warnings
+from typing import Any, Dict, List, Tuple, Union
+
+from anthropic import Anthropic, AnthropicBedrock
+from anthropic import __version__ as anthropic_version
+from anthropic.types import Completion, Message, TextBlock, ToolUseBlock
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+from typing_extensions import Annotated
+
+from autogen.oai.client_utils import validate_parameter
+
+TOOL_ENABLED = anthropic_version >= "0.23.1"
+if TOOL_ENABLED:
+ from anthropic.types.tool_use_block_param import (
+ ToolUseBlockParam,
+ )
+
+
+ANTHROPIC_PRICING_1k = {
+ "claude-3-5-sonnet-20240620": (0.003, 0.015),
+ "claude-3-sonnet-20240229": (0.003, 0.015),
+ "claude-3-opus-20240229": (0.015, 0.075),
+ "claude-3-haiku-20240307": (0.00025, 0.00125),
+ "claude-2.1": (0.008, 0.024),
+ "claude-2.0": (0.008, 0.024),
+ "claude-instant-1.2": (0.008, 0.024),
+}
+
+
+class AnthropicClient:
+ def __init__(self, **kwargs: Any):
+ """
+ Initialize the Anthropic API client.
+ Args:
+ api_key (str): The API key for the Anthropic API or set the `ANTHROPIC_API_KEY` environment variable.
+ """
+ self._api_key = kwargs.get("api_key", None)
+ self._aws_access_key = kwargs.get("aws_access_key", None)
+ self._aws_secret_key = kwargs.get("aws_secret_key", None)
+ self._aws_session_token = kwargs.get("aws_session_token", None)
+ self._aws_region = kwargs.get("aws_region", None)
+
+ if not self._api_key:
+ self._api_key = os.getenv("ANTHROPIC_API_KEY")
+
+ if not self._aws_access_key:
+ self._aws_access_key = os.getenv("AWS_ACCESS_KEY")
+
+ if not self._aws_secret_key:
+ self._aws_secret_key = os.getenv("AWS_SECRET_KEY")
+
+ if not self._aws_region:
+ self._aws_region = os.getenv("AWS_REGION")
+
+ if self._api_key is None and (
+ self._aws_access_key is None or self._aws_secret_key is None or self._aws_region is None
+ ):
+ raise ValueError("API key or AWS credentials are required to use the Anthropic API.")
+
+ if self._api_key is not None:
+ self._client = Anthropic(api_key=self._api_key)
+ else:
+ self._client = AnthropicBedrock(
+ aws_access_key=self._aws_access_key,
+ aws_secret_key=self._aws_secret_key,
+ aws_session_token=self._aws_session_token,
+ aws_region=self._aws_region,
+ )
+
+ self._last_tooluse_status = {}
+
+ def load_config(self, params: Dict[str, Any]):
+ """Load the configuration for the Anthropic API client."""
+ anthropic_params = {}
+
+ anthropic_params["model"] = params.get("model", None)
+ assert anthropic_params["model"], "Please provide a `model` in the config_list to use the Anthropic API."
+
+ anthropic_params["temperature"] = validate_parameter(
+ params, "temperature", (float, int), False, 1.0, (0.0, 1.0), None
+ )
+ anthropic_params["max_tokens"] = validate_parameter(params, "max_tokens", int, False, 4096, (1, None), None)
+ anthropic_params["top_k"] = validate_parameter(params, "top_k", int, True, None, (1, None), None)
+ anthropic_params["top_p"] = validate_parameter(params, "top_p", (float, int), True, None, (0.0, 1.0), None)
+ anthropic_params["stop_sequences"] = validate_parameter(params, "stop_sequences", list, True, None, None, None)
+ anthropic_params["stream"] = validate_parameter(params, "stream", bool, False, False, None, None)
+
+ if anthropic_params["stream"]:
+ warnings.warn(
+ "Streaming is not currently supported, streaming will be disabled.",
+ UserWarning,
+ )
+ anthropic_params["stream"] = False
+
+ return anthropic_params
+
+ def cost(self, response) -> float:
+ """Calculate the cost of the completion using the Anthropic pricing."""
+ return response.cost
+
+ @property
+ def api_key(self):
+ return self._api_key
+
+ @property
+ def aws_access_key(self):
+ return self._aws_access_key
+
+ @property
+ def aws_secret_key(self):
+ return self._aws_secret_key
+
+ @property
+ def aws_session_token(self):
+ return self._aws_session_token
+
+ @property
+ def aws_region(self):
+ return self._aws_region
+
+ def create(self, params: Dict[str, Any]) -> Completion:
+ if "tools" in params:
+ converted_functions = self.convert_tools_to_functions(params["tools"])
+ params["functions"] = params.get("functions", []) + converted_functions
+
+ # Convert AutoGen messages to Anthropic messages
+ anthropic_messages = oai_messages_to_anthropic_messages(params)
+ anthropic_params = self.load_config(params)
+
+ # TODO: support stream
+ params = params.copy()
+ if "functions" in params:
+ tools_configs = params.pop("functions")
+ tools_configs = [self.openai_func_to_anthropic(tool) for tool in tools_configs]
+ params["tools"] = tools_configs
+
+ # Anthropic doesn't accept None values, so we need to use keyword argument unpacking instead of setting parameters.
+ # Copy params we need into anthropic_params
+ # Remove any that don't have values
+ anthropic_params["messages"] = anthropic_messages
+ if "system" in params:
+ anthropic_params["system"] = params["system"]
+ if "tools" in params:
+ anthropic_params["tools"] = params["tools"]
+ if anthropic_params["top_k"] is None:
+ del anthropic_params["top_k"]
+ if anthropic_params["top_p"] is None:
+ del anthropic_params["top_p"]
+ if anthropic_params["stop_sequences"] is None:
+ del anthropic_params["stop_sequences"]
+
+ response = self._client.messages.create(**anthropic_params)
+
+ # Calculate and save the cost onto the response
+ prompt_tokens = response.usage.input_tokens
+ completion_tokens = response.usage.output_tokens
+
+ message_text = ""
+ if response is not None:
+ # If we have tool use as the response, populate completed tool calls for our return OAI response
+ if response.stop_reason == "tool_use":
+ anthropic_finish = "tool_calls"
+ tool_calls = []
+ for content in response.content:
+ if type(content) == ToolUseBlock:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=content.id,
+ function={"name": content.name, "arguments": json.dumps(content.input)},
+ type="function",
+ )
+ )
+ else:
+ anthropic_finish = "stop"
+ tool_calls = None
+
+ # Retrieve any text content from the response
+ for content in response.content:
+ if type(content) == TextBlock:
+ message_text = content.text
+ break
+
+ # Convert output back to AutoGen response format
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=message_text,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=anthropic_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response.id,
+ model=anthropic_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
+ cost=_calculate_cost(prompt_tokens, completion_tokens, anthropic_params["model"]),
+ )
+
+ return response_oai
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ @staticmethod
+ def openai_func_to_anthropic(openai_func: dict) -> dict:
+ res = openai_func.copy()
+ res["input_schema"] = res.pop("parameters")
+ return res
+
+ @staticmethod
+ def get_usage(response: ChatCompletion) -> Dict:
+ """Get the usage of tokens and their cost information."""
+ return {
+ "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
+ "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
+ "total_tokens": response.usage.total_tokens if response.usage is not None else 0,
+ "cost": response.cost if hasattr(response, "cost") else 0.0,
+ "model": response.model,
+ }
+
+ @staticmethod
+ def convert_tools_to_functions(tools: List) -> List:
+ functions = []
+ for tool in tools:
+ if tool.get("type") == "function" and "function" in tool:
+ functions.append(tool["function"])
+
+ return functions
+
+
+def oai_messages_to_anthropic_messages(params: Dict[str, Any]) -> list[dict[str, Any]]:
+ """Convert messages from OAI format to Anthropic format.
+ We correct for any specific role orders and types, etc.
+ """
+
+ # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
+ # Anthropic requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
+ # This can occur when we don't need tool calling, such as for group chat speaker selection.
+ has_tools = "tools" in params
+
+ # Convert messages to Anthropic compliant format
+ processed_messages = []
+
+ # Used to interweave user messages to ensure user/assistant alternating
+ user_continue_message = {"content": "Please continue.", "role": "user"}
+ assistant_continue_message = {"content": "Please continue.", "role": "assistant"}
+
+ tool_use_messages = 0
+ tool_result_messages = 0
+ last_tool_use_index = -1
+ last_tool_result_index = -1
+ for message in params["messages"]:
+ if message["role"] == "system":
+ params["system"] = message["content"]
+ else:
+ # New messages will be added here, manage role alternations
+ expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"
+
+ if "tool_calls" in message:
+ # Map the tool call options to Anthropic's ToolUseBlock
+ tool_uses = []
+ tool_names = []
+ for tool_call in message["tool_calls"]:
+ tool_uses.append(
+ ToolUseBlock(
+ type="tool_use",
+ id=tool_call["id"],
+ name=tool_call["function"]["name"],
+ input=json.loads(tool_call["function"]["arguments"]),
+ )
+ )
+ if has_tools:
+ tool_use_messages += 1
+ tool_names.append(tool_call["function"]["name"])
+
+ if expected_role == "user":
+ # Insert an extra user message as we will append an assistant message
+ processed_messages.append(user_continue_message)
+
+ if has_tools:
+ processed_messages.append({"role": "assistant", "content": tool_uses})
+ last_tool_use_index = len(processed_messages) - 1
+ else:
+ # Not using tools, so put in a plain text message
+ processed_messages.append(
+ {
+ "role": "assistant",
+ "content": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]",
+ }
+ )
+ elif "tool_call_id" in message:
+ if has_tools:
+ # Map the tool usage call to tool_result for Anthropic
+ tool_result = {
+ "type": "tool_result",
+ "tool_use_id": message["tool_call_id"],
+ "content": message["content"],
+ }
+
+ # If the previous message also had a tool_result, add it to that
+ # Otherwise append a new message
+ if last_tool_result_index == len(processed_messages) - 1:
+ processed_messages[-1]["content"].append(tool_result)
+ else:
+ if expected_role == "assistant":
+ # Insert an extra assistant message as we will append a user message
+ processed_messages.append(assistant_continue_message)
+
+ processed_messages.append({"role": "user", "content": [tool_result]})
+ last_tool_result_index = len(processed_messages) - 1
+
+ tool_result_messages += 1
+ else:
+ # Not using tools, so put in a plain text message
+ processed_messages.append(
+ {"role": "user", "content": f"Running the function returned: {message['content']}"}
+ )
+ elif message["content"] == "":
+ # Ignoring empty messages
+ pass
+ else:
+ if expected_role != message["role"]:
+ # Inserting the alternating continue message
+ processed_messages.append(
+ user_continue_message if expected_role == "user" else assistant_continue_message
+ )
+
+ processed_messages.append(message)
+
+ # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
+ if has_tools and tool_use_messages != tool_result_messages:
+ processed_messages[last_tool_use_index] = assistant_continue_message
+
+ # name is not a valid field on messages
+ for message in processed_messages:
+ if "name" in message:
+ message.pop("name", None)
+
+ # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
+ # So, if the last role is not user, add a 'user' continue message at the end
+ if processed_messages[-1]["role"] != "user":
+ processed_messages.append(user_continue_message)
+
+ return processed_messages
+
+
+def _calculate_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+ """Calculate the cost of the completion using the Anthropic pricing."""
+ total = 0.0
+
+ if model in ANTHROPIC_PRICING_1k:
+ input_cost_per_1k, output_cost_per_1k = ANTHROPIC_PRICING_1k[model]
+ input_cost = (input_tokens / 1000) * input_cost_per_1k
+ output_cost = (output_tokens / 1000) * output_cost_per_1k
+ total = input_cost + output_cost
+ else:
+ warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
+
+ return total
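A quick sanity check of the pricing table above (a sketch; `_calculate_cost` is module-internal):

```
from autogen.oai.anthropic import _calculate_cost

# claude-3-sonnet-20240229: (10000/1000)*0.003 + (2000/1000)*0.015 = 0.06
print(_calculate_cost(10_000, 2_000, "claude-3-sonnet-20240229"))
```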
diff --git a/autogen/oai/bedrock.py b/autogen/oai/bedrock.py
new file mode 100644
index 00000000000..7894781e3ee
--- /dev/null
+++ b/autogen/oai/bedrock.py
@@ -0,0 +1,606 @@
+"""
+Create a compatible client for the Amazon Bedrock Converse API.
+
+Example usage:
+Install the `boto3` package by running `pip install --upgrade boto3`.
+- https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+
+import autogen
+
+config_list = [
+ {
+ "api_type": "bedrock",
+ "model": "meta.llama3-1-8b-instruct-v1:0",
+ "aws_region": "us-west-2",
+ "aws_access_key": "",
+ "aws_secret_key": "",
+ "price" : [0.003, 0.015]
+ }
+]
+
+assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})
+
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+import os
+import re
+import time
+import warnings
+from typing import Any, Dict, List, Literal, Tuple
+
+import boto3
+import requests
+from botocore.config import Config
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import validate_parameter
+
+
+class BedrockClient:
+ """Client for Amazon's Bedrock Converse API."""
+
+ _retries = 5
+
+ def __init__(self, **kwargs: Any):
+ """
+ Initialises BedrockClient for Amazon's Bedrock Converse API
+ """
+ self._aws_access_key = kwargs.get("aws_access_key", None)
+ self._aws_secret_key = kwargs.get("aws_secret_key", None)
+ self._aws_session_token = kwargs.get("aws_session_token", None)
+ self._aws_region = kwargs.get("aws_region", None)
+ self._aws_profile_name = kwargs.get("aws_profile_name", None)
+
+ if not self._aws_access_key:
+ self._aws_access_key = os.getenv("AWS_ACCESS_KEY")
+
+ if not self._aws_secret_key:
+ self._aws_secret_key = os.getenv("AWS_SECRET_KEY")
+
+ if not self._aws_session_token:
+ self._aws_session_token = os.getenv("AWS_SESSION_TOKEN")
+
+ if not self._aws_region:
+ self._aws_region = os.getenv("AWS_REGION")
+
+ if self._aws_region is None:
+ raise ValueError("Region is required to use the Amazon Bedrock API.")
+
+ # Initialize Bedrock client, session, and runtime
+ bedrock_config = Config(
+ region_name=self._aws_region,
+ signature_version="v4",
+ retries={"max_attempts": self._retries, "mode": "standard"},
+ )
+
+ session = boto3.Session(
+ aws_access_key_id=self._aws_access_key,
+ aws_secret_access_key=self._aws_secret_key,
+ aws_session_token=self._aws_session_token,
+ profile_name=self._aws_profile_name,
+ )
+
+ self.bedrock_runtime = session.client(service_name="bedrock-runtime", config=bedrock_config)
+
+ def message_retrieval(self, response):
+ """Retrieve the messages from the response."""
+ return [choice.message for choice in response.choices]
+
+ def parse_custom_params(self, params: Dict[str, Any]):
+ """
+ Parses custom parameters for logic in this client class
+ """
+
+ # Should we separate system messages into its own request parameter, default is True
+ # This is required because not all models support a system prompt (e.g. Mistral Instruct).
+ self._supports_system_prompts = params.get("supports_system_prompts", True)
+
+ def parse_params(self, params: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ Loads the valid parameters required to invoke Bedrock Converse
+ Returns a tuple of (base_params, additional_params)
+ """
+
+ base_params = {}
+ additional_params = {}
+
+ # Amazon Bedrock base model IDs are here:
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
+ self._model_id = params.get("model", None)
+ assert self._model_id, "Please provide the 'model' in the config_list to use Amazon Bedrock"
+
+ # Parameters vary based on the model used.
+ # As we won't cater for all models and parameters, it's the developer's
+ # responsibility to implement the parameters and they will only be
+ # included if the developer has it in the config.
+ #
+ # Important:
+ # No defaults will be used (as they can vary per model)
+ # No ranges will be used (as they can vary)
+ # We will cover all the main parameters but there may be others
+ # that need to be added later
+ #
+ # Here are some pages that show the parameters available for different models
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html
+
+ # Here are the possible "base" parameters and their suitable types
+ base_parameters = [["temperature", (float, int)], ["topP", (float, int)], ["maxTokens", (int)]]
+
+ for param_name, suitable_types in base_parameters:
+ if param_name in params:
+ base_params[param_name] = validate_parameter(
+ params, param_name, suitable_types, False, None, None, None
+ )
+
+ # Here are the possible "model-specific" parameters and their suitable types, known as additional parameters
+ additional_parameters = [
+ ["top_p", (float, int)],
+ ["top_k", (int)],
+ ["k", (int)],
+ ["seed", (int)],
+ ]
+
+ for param_name, suitable_types in additional_parameters:
+ if param_name in params:
+ additional_params[param_name] = validate_parameter(
+ params, param_name, suitable_types, False, None, None, None
+ )
+
+ # Streaming
+ if "stream" in params:
+ self._streaming = params["stream"]
+ else:
+ self._streaming = False
+
+ # For this release we will not support streaming as many models do not support streaming with tool use
+ if self._streaming:
+ warnings.warn(
+ "Streaming is not currently supported, streaming will be disabled.",
+ UserWarning,
+ )
+ self._streaming = False
+
+ return base_params, additional_params
+
+ def create(self, params):
+ """Run Amazon Bedrock inference and return AutoGen response"""
+
+ # Set custom client class settings
+ self.parse_custom_params(params)
+
+ # Parse the inference parameters
+ base_params, additional_params = self.parse_params(params)
+
+ has_tools = "tools" in params
+ messages = oai_messages_to_bedrock_messages(params["messages"], has_tools, self._supports_system_prompts)
+
+ if self._supports_system_prompts:
+ system_messages = extract_system_messages(params["messages"])
+
+ tool_config = format_tools(params["tools"] if has_tools else [])
+
+ request_args = {"messages": messages, "modelId": self._model_id}
+
+ # Base and additional args
+ if len(base_params) > 0:
+ request_args["inferenceConfig"] = base_params
+
+ if len(additional_params) > 0:
+ request_args["additionalModelRequestFields"] = additional_params
+
+ if self._supports_system_prompts:
+ request_args["system"] = system_messages
+
+ if len(tool_config["tools"]) > 0:
+ request_args["toolConfig"] = tool_config
+
+ try:
+ response = self.bedrock_runtime.converse(
+ **request_args,
+ )
+ except Exception as e:
+ raise RuntimeError(f"Failed to get response from Bedrock: {e}")
+
+ if response is None:
+ raise RuntimeError(f"Failed to get response from Bedrock after retrying {self._retries} times.")
+
+ finish_reason = convert_stop_reason_to_finish_reason(response["stopReason"])
+ response_message = response["output"]["message"]
+
+ if finish_reason == "tool_calls":
+ tool_calls = format_tool_calls(response_message["content"])
+ # text = ""
+ else:
+ tool_calls = None
+
+ text = ""
+ for content in response_message["content"]:
+ if "text" in content:
+ text = content["text"]
+ # NOTE: other types of output may be dealt with here
+
+ message = ChatCompletionMessage(role="assistant", content=text, tool_calls=tool_calls)
+
+ response_usage = response["usage"]
+ usage = CompletionUsage(
+ prompt_tokens=response_usage["inputTokens"],
+ completion_tokens=response_usage["outputTokens"],
+ total_tokens=response_usage["totalTokens"],
+ )
+
+ return ChatCompletion(
+ id=response["ResponseMetadata"]["RequestId"],
+ choices=[Choice(finish_reason=finish_reason, index=0, message=message)],
+ created=int(time.time()),
+ model=self._model_id,
+ object="chat.completion",
+ usage=usage,
+ )
+
+ def cost(self, response: ChatCompletion) -> float:
+ """Calculate the cost of the response."""
+ return calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens, response.model)
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Get the usage of tokens and their cost information."""
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+
+def extract_system_messages(messages: List[dict]) -> List:
+ """Extract the system messages from the list of messages.
+
+ Args:
+ messages (list[dict]): List of messages.
+
+ Returns:
+ List[SystemMessage]: List of System messages.
+ """
+
+ """
+ system_messages = [message.get("content")[0]["text"] for message in messages if message.get("role") == "system"]
+ return system_messages # ''.join(system_messages)
+ """
+
+ for message in messages:
+ if message.get("role") == "system":
+ if isinstance(message["content"], str):
+ return [{"text": message.get("content")}]
+ else:
+ return [{"text": message.get("content")[0]["text"]}]
+ return []
+
+
+def oai_messages_to_bedrock_messages(
+ messages: List[Dict[str, Any]], has_tools: bool, supports_system_prompts: bool
+) -> List[Dict]:
+ """
+ Convert messages from OAI format to Bedrock format.
+ We correct for any specific role orders and types, etc.
+ AWS Bedrock requires messages to alternate between user and assistant roles. This function ensures that the messages
+ are in the correct order and format for Bedrock by inserting "Please continue" messages as needed.
+ This is the same method as the one in the Autogen Anthropic client
+ """
+
+ # Track whether we have tools passed in. If not, tool use / result messages should be converted to text messages.
+ # Bedrock requires a tools parameter with the tools listed, if there are other messages with tool use or tool results.
+ # This can occur when we don't need tool calling, such as for group chat speaker selection
+
+ # Convert messages to Bedrock compliant format
+
+ # Take out system messages if the model supports it, otherwise leave them in.
+ if supports_system_prompts:
+ messages = [x for x in messages if not x["role"] == "system"]
+ else:
+ # Replace role="system" with role="user"
+ for msg in messages:
+ if msg["role"] == "system":
+ msg["role"] = "user"
+
+ processed_messages = []
+
+ # Used to interweave user messages to ensure user/assistant alternating
+ user_continue_message = {"content": [{"text": "Please continue."}], "role": "user"}
+ assistant_continue_message = {
+ "content": [{"text": "Please continue."}],
+ "role": "assistant",
+ }
+
+ tool_use_messages = 0
+ tool_result_messages = 0
+ last_tool_use_index = -1
+ last_tool_result_index = -1
+ # user_role_index = 0 if supports_system_prompts else 1 # If system prompts are supported, messages start with user, otherwise they'll be the second message
+ for message in messages:
+ # New messages will be added here, manage role alternations
+ expected_role = "user" if len(processed_messages) % 2 == 0 else "assistant"
+
+ if "tool_calls" in message:
+ # Map the tool call options to Bedrock's format
+ tool_uses = []
+ tool_names = []
+ for tool_call in message["tool_calls"]:
+ tool_uses.append(
+ {
+ "toolUse": {
+ "toolUseId": tool_call["id"],
+ "name": tool_call["function"]["name"],
+ "input": json.loads(tool_call["function"]["arguments"]),
+ }
+ }
+ )
+ if has_tools:
+ tool_use_messages += 1
+ tool_names.append(tool_call["function"]["name"])
+
+ if expected_role == "user":
+ # Insert an extra user message as we will append an assistant message
+ processed_messages.append(user_continue_message)
+
+ if has_tools:
+ processed_messages.append({"role": "assistant", "content": tool_uses})
+ last_tool_use_index = len(processed_messages) - 1
+ else:
+ # Not using tools, so put in a plain text message
+ processed_messages.append(
+ {
+ "role": "assistant",
+ "content": [
+ {"text": f"Some internal function(s) that could be used: [{', '.join(tool_names)}]"}
+ ],
+ }
+ )
+ elif "tool_call_id" in message:
+ if has_tools:
+ # Map the tool usage call to tool_result for Bedrock
+ tool_result = {
+ "toolResult": {
+ "toolUseId": message["tool_call_id"],
+ "content": [{"text": message["content"]}],
+ }
+ }
+
+ # If the previous message also had a tool_result, add it to that
+ # Otherwise append a new message
+ if last_tool_result_index == len(processed_messages) - 1:
+ processed_messages[-1]["content"].append(tool_result)
+ else:
+ if expected_role == "assistant":
+ # Insert an extra assistant message as we will append a user message
+ processed_messages.append(assistant_continue_message)
+
+ processed_messages.append({"role": "user", "content": [tool_result]})
+ last_tool_result_index = len(processed_messages) - 1
+
+ tool_result_messages += 1
+ else:
+ # Not using tools, so put in a plain text message
+ processed_messages.append(
+ {
+ "role": "user",
+ "content": [{"text": f"Running the function returned: {message['content']}"}],
+ }
+ )
+ elif message["content"] == "":
+ # Ignoring empty messages
+ pass
+ else:
+ if expected_role != message["role"] and not (len(processed_messages) == 0 and message["role"] == "system"):
+ # Inserting the alternating continue message (ignore if it's the first message and a system message)
+ processed_messages.append(
+ user_continue_message if expected_role == "user" else assistant_continue_message
+ )
+
+ processed_messages.append(
+ {
+ "role": message["role"],
+ "content": parse_content_parts(message=message),
+ }
+ )
+
+ # We'll replace the last tool_use if there's no tool_result (occurs if we finish the conversation before running the function)
+ if has_tools and tool_use_messages != tool_result_messages:
+ processed_messages[last_tool_use_index] = assistant_continue_message
+
+ # name is not a valid field on messages
+ for message in processed_messages:
+ if "name" in message:
+ message.pop("name", None)
+
+ # Note: When using reflection_with_llm we may end up with an "assistant" message as the last message and that may cause a blank response
+ # So, if the last role is not user, add a 'user' continue message at the end
+ if processed_messages[-1]["role"] != "user":
+ processed_messages.append(user_continue_message)
+
+ return processed_messages
+
+
+def parse_content_parts(
+ message: Dict[str, Any],
+) -> List[dict]:
+ content: str | List[Dict[str, Any]] = message.get("content")
+ if isinstance(content, str):
+ return [
+ {
+ "text": content,
+ }
+ ]
+ content_parts = []
+ for part in content:
+ # part_content: Dict = part.get("content")
+ if "text" in part: # part_content:
+ content_parts.append(
+ {
+ "text": part.get("text"),
+ }
+ )
+ elif "image_url" in part: # part_content:
+ image_data, content_type = parse_image(part.get("image_url").get("url"))
+ content_parts.append(
+ {
+ "image": {
+ "format": content_type[6:], # image/
+ "source": {"bytes": image_data},
+ },
+ }
+ )
+ else:
+ # Ignore..
+ continue
+ return content_parts
+
+
+def parse_image(image_url: str) -> Tuple[bytes, str]:
+ """Try to get the raw data from an image url.
+
+ Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html
+ returns a tuple of (Image Data, Content Type)
+ """
+ pattern = r"^data:(image/[a-z]*);base64,\s*"
+ content_type = re.search(pattern, image_url)
+ # if already base64 encoded.
+ # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp'
+ if content_type:
+ image_data = re.sub(pattern, "", image_url)
+ return base64.b64decode(image_data), content_type.group(1)
+
+ # Send a request to the image URL
+ response = requests.get(image_url)
+ # Check if the request was successful
+ if response.status_code == 200:
+
+ content_type = response.headers.get("Content-Type")
+ if not content_type.startswith("image"):
+ content_type = "image/jpeg"
+ # Get the image content
+ image_content = response.content
+ return image_content, content_type
+ else:
+ raise RuntimeError("Unable to access the image url")
+
+
+def format_tools(tools: List[Dict[str, Any]]) -> Dict[Literal["tools"], List[Dict[str, Any]]]:
+ converted_schema = {"tools": []}
+
+ for tool in tools:
+ if tool["type"] == "function":
+ function = tool["function"]
+ converted_tool = {
+ "toolSpec": {
+ "name": function["name"],
+ "description": function["description"],
+ "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}},
+ }
+ }
+
+ for prop_name, prop_details in function["parameters"]["properties"].items():
+ converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name] = {
+ "type": prop_details["type"],
+ "description": prop_details.get("description", ""),
+ }
+ if "enum" in prop_details:
+ converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["enum"] = prop_details[
+ "enum"
+ ]
+ if "default" in prop_details:
+ converted_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop_name]["default"] = (
+ prop_details["default"]
+ )
+
+ if "required" in function["parameters"]:
+ converted_tool["toolSpec"]["inputSchema"]["json"]["required"] = function["parameters"]["required"]
+
+ converted_schema["tools"].append(converted_tool)
+
+ return converted_schema
+
+
+def format_tool_calls(content):
+ """Converts Converse API response tool calls to AutoGen format"""
+ tool_calls = []
+ for tool_request in content:
+ if "toolUse" in tool_request:
+ tool = tool_request["toolUse"]
+
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool["toolUseId"],
+ function={
+ "name": tool["name"],
+ "arguments": json.dumps(tool["input"]),
+ },
+ type="function",
+ )
+ )
+ return tool_calls
+
+
+def convert_stop_reason_to_finish_reason(
+ stop_reason: str,
+) -> Literal["stop", "length", "tool_calls", "content_filter"]:
+ """
+ Converts Bedrock finish reasons to our finish reasons, according to OpenAI:
+
+ - stop: if the model hit a natural stop point or a provided stop sequence,
+ - length: if the maximum number of tokens specified in the request was reached,
+ - content_filter: if content was omitted due to a flag from our content filters,
+ - tool_calls: if the model called a tool
+ """
+ if stop_reason:
+ finish_reason_mapping = {
+ "tool_use": "tool_calls",
+ "finished": "stop",
+ "end_turn": "stop",
+ "max_tokens": "length",
+ "stop_sequence": "stop",
+ "complete": "stop",
+ "content_filtered": "content_filter",
+ }
+ return finish_reason_mapping.get(stop_reason.lower(), stop_reason.lower())
+
+ warnings.warn(f"Unsupported stop reason: {stop_reason}", UserWarning)
+ return None
+
+
+# NOTE: As this will be quite dynamic, it's expected that the developer will use the "price" parameter in their config
+# These may be removed.
+PRICES_PER_K_TOKENS = {
+ "meta.llama3-8b-instruct-v1:0": (0.0003, 0.0006),
+ "meta.llama3-70b-instruct-v1:0": (0.00265, 0.0035),
+ "mistral.mistral-7b-instruct-v0:2": (0.00015, 0.0002),
+ "mistral.mixtral-8x7b-instruct-v0:1": (0.00045, 0.0007),
+ "mistral.mistral-large-2402-v1:0": (0.004, 0.012),
+ "mistral.mistral-small-2402-v1:0": (0.001, 0.003),
+}
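+# Worked example (hypothetical usage): 1,500 prompt tokens and 500 completion tokens on
+# "meta.llama3-8b-instruct-v1:0" cost (1500 / 1000) * 0.0003 + (500 / 1000) * 0.0006 = 0.00075 USD
+# via calculate_cost() below.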
+
+
+def calculate_cost(input_tokens: int, output_tokens: int, model_id: str) -> float:
+ """Calculate the cost of the completion using the Bedrock pricing."""
+
+ if model_id in PRICES_PER_K_TOKENS:
+ input_cost_per_k, output_cost_per_k = PRICES_PER_K_TOKENS[model_id]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ return input_cost + output_cost
+ else:
+ warnings.warn(
+ f'Cannot get the costs for {model_id}. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.',
+ UserWarning,
+ )
+ return 0
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index de35e5c5273..3ae37257b21 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -42,6 +42,55 @@
TOOL_ENABLED = True
ERROR = None
+try:
+ from autogen.oai.gemini import GeminiClient
+
+ gemini_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ gemini_import_exception = e
+
+try:
+ from autogen.oai.anthropic import AnthropicClient
+
+ anthropic_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ anthropic_import_exception = e
+
+try:
+ from autogen.oai.mistral import MistralAIClient
+
+ mistral_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ mistral_import_exception = e
+
+try:
+ from autogen.oai.together import TogetherClient
+
+ together_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ together_import_exception = e
+
+try:
+ from autogen.oai.groq import GroqClient
+
+ groq_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ groq_import_exception = e
+
+try:
+ from autogen.oai.cohere import CohereClient
+
+ cohere_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ cohere_import_exception = e
+
+try:
+ from autogen.oai.bedrock import BedrockClient
+
+ bedrock_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ bedrock_import_exception = e
+
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -283,8 +332,10 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
"""Calculate the cost of the response."""
model = response.model
if model not in OAI_PRICE1K:
- # TODO: add logging to warn that the model is not found
- logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
+ # log warning that the model is not found
+ logger.warning(
+ f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.'
+ )
return 0
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
@@ -312,6 +363,7 @@ class OpenAIWrapper:
"""A wrapper class for openai client."""
extra_kwargs = {
+ "agent",
"cache",
"cache_seed",
"filter_func",
@@ -320,6 +372,7 @@ class OpenAIWrapper:
"api_version",
"api_type",
"tags",
+ "price",
}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
@@ -341,7 +394,7 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
"api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"api_type": "azure",
"base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
- "api_version": "2024-02-15-preview",
+ "api_version": "2024-02-01",
},
{
"model": "gpt-3.5-turbo",
@@ -400,12 +453,31 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
+ # Create a default Azure token provider if requested
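+        # (Illustrative, hypothetical entry: a config with "api_type": "azure" and
+        #  "azure_ad_token_provider": "DEFAULT" triggers the DefaultAzureCredential path below.)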
+ if openai_config.get("azure_ad_token_provider") == "DEFAULT":
+ import azure.identity
+
+ openai_config["azure_ad_token_provider"] = azure.identity.get_bearer_token_provider(
+ azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+ )
+
+ def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
+ """Update openai_config with AWS credentials from config."""
+ required_keys = ["aws_access_key", "aws_secret_key", "aws_region"]
+ optional_keys = ["aws_session_token", "aws_profile_name"]
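+        # Illustrative (hypothetical) config entry handled here:
+        #   {"api_type": "bedrock", "model": "meta.llama3-8b-instruct-v1:0",
+        #    "aws_access_key": "<key>", "aws_secret_key": "<secret>", "aws_region": "us-east-1"}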
+ for key in required_keys:
+ if key in config:
+ openai_config[key] = config[key]
+ for key in optional_keys:
+ if key in config:
+ openai_config[key] = config[key]
+
def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
"""Create a client with the given config to override openai_config,
after removing extra kwargs.
For Azure models/deployment names there's a convenience modification of model removing dots in
- the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
+        its value (Azure deployment names can't have dots). I.e. if you have Azure deployment name
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
"""
@@ -425,6 +497,44 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
self._configure_azure_openai(config, openai_config)
client = AzureOpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
+ elif api_type is not None and api_type.startswith("google"):
+ if gemini_import_exception:
+ raise ImportError("Please install `google-generativeai` to use Google OpenAI API.")
+ client = GeminiClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("anthropic"):
+ if "api_key" not in config:
+ self._configure_openai_config_for_bedrock(config, openai_config)
+ if anthropic_import_exception:
+ raise ImportError("Please install `anthropic` to use Anthropic API.")
+ client = AnthropicClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("mistral"):
+ if mistral_import_exception:
+ raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
+ client = MistralAIClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("together"):
+ if together_import_exception:
+ raise ImportError("Please install `together` to use the Together.AI API.")
+ client = TogetherClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("groq"):
+ if groq_import_exception:
+ raise ImportError("Please install `groq` to use the Groq API.")
+ client = GroqClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("cohere"):
+ if cohere_import_exception:
+ raise ImportError("Please install `cohere` to use the Cohere API.")
+ client = CohereClient(**openai_config)
+ self._clients.append(client)
+ elif api_type is not None and api_type.startswith("bedrock"):
+ self._configure_openai_config_for_bedrock(config, openai_config)
+ if bedrock_import_exception:
+ raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
+ client = BedrockClient(**openai_config)
+ self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
@@ -522,6 +632,7 @@ def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided,
then the cache_seed argument is ignored. If this argument is not provided or None,
then the cache_seed argument is used.
+            - agent (AbstractAgent | None): The agent that requested the completion, if any; it is passed through to runtime logging.
- (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41.
An integer cache_seed is useful when implementing "controlled randomness" for the completion.
None for no caching.
@@ -537,7 +648,7 @@ def yes_or_no_filter(context, response):
```
- allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false.
- - api_version (str | None): The api version. Default to None. E.g., "2024-02-15-preview".
+ - api_version (str | None): The api version. Default to None. E.g., "2024-02-01".
Raises:
- RuntimeError: If all declared custom model clients are not registered
- APIError: If any model client create call raises an APIError
@@ -569,6 +680,15 @@ def yes_or_no_filter(context, response):
cache = extra_kwargs.get("cache")
filter_func = extra_kwargs.get("filter_func")
context = extra_kwargs.get("context")
+ agent = extra_kwargs.get("agent")
+ price = extra_kwargs.get("price", None)
+ if isinstance(price, list):
+ price = tuple(price)
+ elif isinstance(price, float) or isinstance(price, int):
+ logger.warning(
+ "Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
+ )
+ price = (price, price)
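+        # Illustrative (hypothetical) usage: price=[0.001, 0.002] in the config charges $0.001 per 1k
+        # prompt tokens and $0.002 per 1k completion tokens via _cost_with_customized_price below.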
total_usage = None
actual_usage = None
@@ -606,6 +726,7 @@ def yes_or_no_filter(context, response):
invocation_id=invocation_id,
client_id=id(client),
wrapper_id=id(self),
+ agent=agent,
request=params,
response=response,
is_cached=1,
@@ -638,6 +759,7 @@ def yes_or_no_filter(context, response):
invocation_id=invocation_id,
client_id=id(client),
wrapper_id=id(self),
+ agent=agent,
request=params,
response=f"error_code:{error_code}, config {i} failed",
is_cached=0,
@@ -653,7 +775,10 @@ def yes_or_no_filter(context, response):
raise
else:
# add cost calculation before caching no matter filter is passed or not
- response.cost = client.cost(response)
+ if price is not None:
+ response.cost = self._cost_with_customized_price(response, price)
+ else:
+ response.cost = client.cost(response)
actual_usage = client.get_usage(response)
total_usage = actual_usage.copy() if actual_usage is not None else total_usage
self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
@@ -668,6 +793,7 @@ def yes_or_no_filter(context, response):
invocation_id=invocation_id,
client_id=id(client),
wrapper_id=id(self),
+ agent=agent,
request=params,
response=response,
is_cached=0,
@@ -686,6 +812,17 @@ def yes_or_no_filter(context, response):
continue # filter is not passed; try the next config
raise RuntimeError("Should not reach here.")
+ @staticmethod
+    def _cost_with_customized_price(
+        response: ModelClient.ModelClientResponseProtocol, price_1k: Tuple[float, float]
+    ) -> float:
+        """Return the cost computed from the response usage and a user-supplied (prompt, completion) price per 1k tokens."""
+ n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
+ n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
+ if n_output_tokens is None:
+ n_output_tokens = 0
+ return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000
+
@staticmethod
def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
"""Update the dict from the chunk.
diff --git a/autogen/oai/client_utils.py b/autogen/oai/client_utils.py
new file mode 100644
index 00000000000..55730485b40
--- /dev/null
+++ b/autogen/oai/client_utils.py
@@ -0,0 +1,154 @@
+"""Utilities for client classes"""
+
+import warnings
+from typing import Any, Dict, List, Optional, Tuple
+
+
+def validate_parameter(
+ params: Dict[str, Any],
+ param_name: str,
+ allowed_types: Tuple,
+ allow_None: bool,
+ default_value: Any,
+    numerical_bound: Optional[Tuple],
+    allowed_values: Optional[List],
+) -> Any:
+ """
+ Validates a given config parameter, checking its type, values, and setting defaults
+ Parameters:
+ params (Dict[str, Any]): Dictionary containing parameters to validate.
+ param_name (str): The name of the parameter to validate.
+ allowed_types (Tuple): Tuple of acceptable types for the parameter.
+ allow_None (bool): Whether the parameter can be `None`.
+ default_value (Any): The default value to use if the parameter is invalid or missing.
+ numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
+ A tuple specifying the lower and upper bounds for numerical parameters.
+ Each bound can be `None` if not applicable.
+ allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
+ Can be `None` if no specific values are required.
+
+ Returns:
+ Any: The validated parameter value or the default value if validation fails.
+
+ Raises:
+ TypeError: If `allowed_values` is provided but is not a list.
+
+ Example Usage:
+ ```python
+ # Validating a numerical parameter within specific bounds
+ params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
+ temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
+ # Result: 0.5
+
+ # Validating a parameter that can be one of a list of allowed values
+ model = validate_parameter(
+ params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
+ )
+    # If "safety_model" is missing or invalid in params, defaults to None
+ ```
+ """
+
+ if allowed_values is not None and not isinstance(allowed_values, list):
+ raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")
+
+ param_value = params.get(param_name, default_value)
+ warning = ""
+
+ if param_value is None and allow_None:
+ pass
+ elif param_value is None:
+ if not allow_None:
+ warning = "cannot be None"
+ elif not isinstance(param_value, allowed_types):
+ # Check types and list possible types if invalid
+ if isinstance(allowed_types, tuple):
+ formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
+ else:
+ formatted_types = f"{allowed_types.__name__}"
+ warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
+ elif numerical_bound:
+ # Check the value fits in possible bounds
+ lower_bound, upper_bound = numerical_bound
+ if (lower_bound is not None and param_value < lower_bound) or (
+ upper_bound is not None and param_value > upper_bound
+ ):
+ warning = "has numerical bounds"
+ if lower_bound is not None:
+ warning += f", >= {str(lower_bound)}"
+ if upper_bound is not None:
+ if lower_bound is not None:
+ warning += " and"
+ warning += f" <= {str(upper_bound)}"
+ if allow_None:
+ warning += ", or can be None"
+
+ elif allowed_values:
+ # Check if the value matches any allowed values
+ if not (allow_None and param_value is None):
+ if param_value not in allowed_values:
+ warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"
+
+ # If we failed any checks, warn and set to default value
+ if warning:
+ warnings.warn(
+ f"Config error - {param_name} {warning}, defaulting to {default_value}.",
+ UserWarning,
+ )
+ param_value = default_value
+
+ return param_value
+
+
+def should_hide_tools(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], hide_tools_param: str) -> bool:
+ """
+ Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't.
+ Parameters:
+ messages (List[Dict[str, Any]]): List of messages
+ tools (List[Dict[str, Any]]): List of tools
+ hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools have been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never hide tools). Default is "never".
+
+ Returns:
+ bool: Indicates whether the tools should be excluded from the response create request
+
+ Example Usage:
+ ```python
+    # Deciding whether to resend tool definitions based on prior tool execution
+    messages = params.get("messages", [])
+    tools = params.get("tools", None)
+    hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
+    ```
+    """
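+    # Illustrative outcomes (hypothetical message lists):
+    #   "if_any_run": True as soon as any message carries a "tool_call_id" (some tool has run).
+    #   "if_all_run": True only once every tool named in `tools` has a matching tool result.
+    #   "never": always False.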
+
+ if hide_tools_param == "never" or tools is None or len(tools) == 0:
+ return False
+ elif hide_tools_param == "if_any_run":
+ # Return True if any tool_call_id exists, indicating a tool call has been executed. False otherwise.
+ return any(["tool_call_id" in dictionary for dictionary in messages])
+ elif hide_tools_param == "if_all_run":
+ # Return True if all tools have been executed at least once. False otherwise.
+
+ # Get the list of tool names
+ check_tool_names = [item["function"]["name"] for item in tools]
+
+ # Prepare a list of tool call ids and related function names
+ tool_call_ids = {}
+
+ # Loop through the messages and check if the tools have been run, removing them as we go
+ for message in messages:
+ if "tool_calls" in message:
+ # Register the tool ids and the function names (there could be multiple tool calls)
+ for tool_call in message["tool_calls"]:
+ tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
+ elif "tool_call_id" in message:
+ # Tool called, get the name of the function based on the id
+ tool_name_called = tool_call_ids[message["tool_call_id"]]
+
+ # If we had not yet called the tool, check and remove it to indicate we have
+ if tool_name_called in check_tool_names:
+ check_tool_names.remove(tool_name_called)
+
+ # Return True if all tools have been called at least once (accounted for)
+ return len(check_tool_names) == 0
+ else:
+ raise TypeError(
+ f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
+ )
diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py
new file mode 100644
index 00000000000..3d38d86425f
--- /dev/null
+++ b/autogen/oai/cohere.py
@@ -0,0 +1,479 @@
+"""Create an OpenAI-compatible client using Cohere's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "cohere",
+ "model": "command-r-plus",
+        "api_key": os.environ.get("COHERE_API_KEY"),
+ "client_name": "autogen-cohere", # Optional parameter
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Cohere's python library using: pip install --upgrade cohere
+
+Resources:
+- https://docs.cohere.com/reference/chat
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+import random
+import sys
+import time
+import warnings
+from typing import Any, Dict, List
+
+from cohere import Client as Cohere
+from cohere.types import ToolParameterDefinitionsValue, ToolResult
+from flaml.automl.logger import logger_formatter
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import validate_parameter
+
+logger = logging.getLogger(__name__)
+if not logger.handlers:
+ # Add the console handler.
+ _ch = logging.StreamHandler(stream=sys.stdout)
+ _ch.setFormatter(logger_formatter)
+ logger.addHandler(_ch)
+
+
+COHERE_PRICING_1K = {
+ "command-r-plus": (0.003, 0.015),
+ "command-r": (0.0005, 0.0015),
+ "command-nightly": (0.00025, 0.00125),
+ "command": (0.015, 0.075),
+ "command-light": (0.008, 0.024),
+ "command-light-nightly": (0.008, 0.024),
+}
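+# Worked example (hypothetical usage): 1,000 prompt tokens and 1,000 completion tokens on "command-r"
+# cost (1000 / 1000) * 0.0005 + (1000 / 1000) * 0.0015 = 0.002 USD via calculate_cohere_cost() below.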
+
+
+class CohereClient:
+ """Client for Cohere's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Cohere (or environment variable COHERE_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("COHERE_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ # ... # pragma: no cover
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Cohere API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ cohere_params = {}
+
+ # Check that we have what we need to use Cohere's API
+ # We won't enforce the available models as they are likely to change
+ cohere_params["model"] = params.get("model", None)
+ assert cohere_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Cohere model to use."
+
+ # Validate allowed Cohere parameters
+ # https://docs.cohere.com/reference/chat
+ cohere_params["temperature"] = validate_parameter(
+ params, "temperature", (int, float), False, 0.3, (0, None), None
+ )
+ cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)
+ cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)
+ cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+ cohere_params["frequency_penalty"] = validate_parameter(
+ params, "frequency_penalty", (int, float), True, 0, (0, 1), None
+ )
+ cohere_params["presence_penalty"] = validate_parameter(
+ params, "presence_penalty", (int, float), True, 0, (0, 1), None
+ )
+
+ # Cohere parameters we are ignoring:
+ # preamble - we will put the system prompt in here.
+ # parallel_tool_calls (defaults to True), perfect as is.
+ # conversation_id - allows resuming a previous conversation, we don't support this.
+ logging.info("Conversation ID: %s", params.get("conversation_id", "None"))
+ # connectors - allows web search or other custom connectors, not implementing for now but could be useful in the future.
+ logging.info("Connectors: %s", params.get("connectors", "None"))
+ # search_queries_only - to control whether only search queries are used, we're not using connectors so ignoring.
+ # documents - a list of documents that can be used to support the chat. Perhaps useful in the future for RAG.
+ # citation_quality - used for RAG flows and dependent on other parameters we're ignoring.
+ # max_input_tokens - limits input tokens, not needed.
+ logging.info("Max Input Tokens: %s", params.get("max_input_tokens", "None"))
+ # stop_sequences - used to stop generation, not needed.
+ logging.info("Stop Sequences: %s", params.get("stop_sequences", "None"))
+
+ return cohere_params
+
+ def create(self, params: Dict) -> ChatCompletion:
+
+ messages = params.get("messages", [])
+ client_name = params.get("client_name") or "autogen-cohere"
+ # Parse parameters to the Cohere API's parameters
+ cohere_params = self.parse_params(params)
+
+ # Convert AutoGen messages to Cohere messages
+ cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params)
+
+ cohere_params["chat_history"] = cohere_messages
+ cohere_params["message"] = final_message
+ cohere_params["preamble"] = preamble
+
+ # We use chat model by default
+ client = Cohere(api_key=self.api_key, client_name=client_name)
+
+ # Token counts will be returned
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+ # Stream if in parameters
+ streaming = True if "stream" in params and params["stream"] else False
+ cohere_finish = ""
+
+ max_retries = 5
+ for attempt in range(max_retries):
+ ans = None
+ try:
+ if streaming:
+ response = client.chat_stream(**cohere_params)
+ else:
+ response = client.chat(**cohere_params)
+ except CohereRateLimitError as e:
+ raise RuntimeError(f"Cohere exception occurred: {e}")
+ else:
+
+ if streaming:
+ # Streaming...
+ ans = ""
+ for event in response:
+ if event.event_type == "text-generation":
+ ans = ans + event.text
+ elif event.event_type == "tool-calls-generation":
+ # When streaming, tool calls are compiled at the end into a single event_type
+ ans = event.text
+ cohere_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in event.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=str(random.randint(0, 100000)),
+ function={
+ "name": tool_call.name,
+ "arguments": (
+ "" if tool_call.parameters is None else json.dumps(tool_call.parameters)
+ ),
+ },
+ type="function",
+ )
+ )
+
+ # Not using billed_units, but that may be better for cost purposes
+ prompt_tokens = event.response.meta.tokens.input_tokens
+ completion_tokens = event.response.meta.tokens.output_tokens
+ total_tokens = prompt_tokens + completion_tokens
+
+ response_id = event.response.response_id
+ else:
+ # Non-streaming finished
+ ans: str = response.text
+
+ # Not using billed_units, but that may be better for cost purposes
+ prompt_tokens = response.meta.tokens.input_tokens
+ completion_tokens = response.meta.tokens.output_tokens
+ total_tokens = prompt_tokens + completion_tokens
+
+ response_id = response.response_id
+ break
+
+ if response is not None:
+
+ response_content = ans
+
+ if streaming:
+ # Streaming response
+ if cohere_finish == "":
+ cohere_finish = "stop"
+ tool_calls = None
+ else:
+ # Non-streaming response
+ # If we have tool calls as the response, populate completed tool calls for our return OAI response
+ if response.tool_calls is not None:
+ cohere_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in response.tool_calls:
+
+ # if parameters are null, clear them out (Cohere can return a string "null" if no parameter values)
+
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=str(random.randint(0, 100000)),
+ function={
+ "name": tool_call.name,
+ "arguments": (
+ "" if tool_call.parameters is None else json.dumps(tool_call.parameters)
+ ),
+ },
+ type="function",
+ )
+ )
+ else:
+ cohere_finish = "stop"
+ tool_calls = None
+ else:
+ raise RuntimeError(f"Failed to get response from Cohere after retrying {attempt + 1} times.")
+
+ # 3. convert output
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=response_content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=cohere_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response_id,
+ model=cohere_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]),
+ )
+
+ return response_oai
+
+
+def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[ToolResult]:
+ temp_tool_results = []
+
+ for tool_call in all_tool_calls:
+ if tool_call["id"] == tool_call_id:
+
+ call = {
+ "name": tool_call["function"]["name"],
+ "parameters": json.loads(
+ tool_call["function"]["arguments"] if not tool_call["function"]["arguments"] == "" else "{}"
+ ),
+ }
+ output = [{"value": content_output}]
+ temp_tool_results.append(ToolResult(call=call, outputs=output))
+ return temp_tool_results
+
+
+def oai_messages_to_cohere_messages(
+ messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
+) -> tuple[list[dict[str, Any]], str, str]:
+ """Convert messages from OAI format to Cohere's format.
+ We correct for any specific role orders and types.
+
+ Parameters:
+ messages: list[Dict[str, Any]]: AutoGen messages
+ params: Dict[str, Any]: AutoGen parameters dictionary
+ cohere_params: Dict[str, Any]: Cohere parameters dictionary
+
+ Returns:
+ List[Dict[str, Any]]: Chat History messages
+ str: Preamble (system message)
+ str: Message (the final user message)
+ """
+
+ cohere_messages = []
+ preamble = ""
+
+ # Tools
+ if "tools" in params:
+ cohere_tools = []
+ for tool in params["tools"]:
+
+ # build list of properties
+ parameters = {}
+
+ for key, value in tool["function"]["parameters"]["properties"].items():
+ type_str = value["type"]
+                required = True  # Cohere's default is False; we mark all parameters as required here (could revisit).
+ description = value["description"]
+
+ # If we have an 'enum' key, add that to the description (as not allowed to pass in enum as a field)
+ if "enum" in value:
+ # Access the enum list
+ enum_values = value["enum"]
+ enum_strings = [str(value) for value in enum_values]
+ enum_string = ", ".join(enum_strings)
+ description = description + ". Possible values are " + enum_string + "."
+
+ parameters[key] = ToolParameterDefinitionsValue(
+ description=description, type=type_str, required=required
+ )
+
+ cohere_tool = {
+ "name": tool["function"]["name"],
+ "description": tool["function"]["description"],
+ "parameter_definitions": parameters,
+ }
+
+ cohere_tools.append(cohere_tool)
+
+ if len(cohere_tools) > 0:
+ cohere_params["tools"] = cohere_tools
+
+ tool_calls = []
+ tool_results = []
+
+ # Rules for cohere messages:
+ # no 'name' field
+ # 'system' messages go into the preamble parameter
+ # user role = 'USER'
+ # assistant role = 'CHATBOT'
+ # 'content' field renamed to 'message'
+ # tools go into tools parameter
+ # tool_results go into tool_results parameter
+ messages_length = len(messages)
+ for index, message in enumerate(messages):
+
+ if "role" in message and message["role"] == "system":
+ # System message
+ if preamble == "":
+ preamble = message["content"]
+ else:
+ preamble = preamble + "\n" + message["content"]
+ elif "tool_calls" in message:
+ # Suggested tool calls, build up the list before we put it into the tool_results
+ for tool_call in message["tool_calls"]:
+ tool_calls.append(tool_call)
+
+ # We also add the suggested tool call as a message
+ new_message = {
+ "role": "CHATBOT",
+ "message": message["content"],
+ "tool_calls": [
+ {
+ "name": tool_call_.get("function", {}).get("name"),
+ "parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
+ }
+ for tool_call_ in message["tool_calls"]
+ ],
+ }
+
+ cohere_messages.append(new_message)
+ elif "role" in message and message["role"] == "tool":
+ if not (tool_call_id := message.get("tool_call_id")):
+ continue
+
+ # Convert the tool call to a result
+ content_output = message["content"]
+ tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
+ if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
+ # If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
+ # So, we pass it into tool_results.
+ tool_results.extend(tool_results_chat_turn)
+ continue
+
+ else:
+                # If it's not the most recent tool call, we pass it as a tool message in the chat history.
+ new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
+ cohere_messages.append(new_message)
+
+ elif "content" in message and isinstance(message["content"], str):
+ # Standard text message
+ new_message = {
+ "role": "USER" if message["role"] == "user" else "CHATBOT",
+ "message": message["content"],
+ }
+
+ cohere_messages.append(new_message)
+
+ # Append any Tool Results
+ if len(tool_results) != 0:
+ cohere_params["tool_results"] = tool_results
+
+ # Enable multi-step tool use: https://docs.cohere.com/docs/multi-step-tool-use
+ cohere_params["force_single_step"] = False
+
+ # If we're adding tool_results, like we are, the last message can't be a USER message
+ # So, we add a CHATBOT 'continue' message, if so.
+ # Changed key from "content" to "message" (jaygdesai/autogen_Jay)
+ if cohere_messages[-1]["role"].lower() == "user":
+ cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})
+
+ # We return a blank message when we have tool results
+ # TODO: Check what happens if tool_results aren't the latest message
+ return cohere_messages, preamble, ""
+
+ else:
+
+ # We need to get the last message to assign to the message field for Cohere,
+ # if the last message is a user message, use that, otherwise put in 'continue'.
+ if cohere_messages[-1]["role"] == "USER":
+ return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"]
+ else:
+ return cohere_messages, preamble, "Please continue."
+
+
+def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+ """Calculate the cost of the completion using the Cohere pricing."""
+ total = 0.0
+
+ if model in COHERE_PRICING_1K:
+ input_cost_per_k, output_cost_per_k = COHERE_PRICING_1K[model]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ total = input_cost + output_cost
+ else:
+ warnings.warn(f"Cost calculation not available for {model} model", UserWarning)
+
+ return total
+
+
+class CohereError(Exception):
+ """Base class for other Cohere exceptions"""
+
+ pass
+
+
+class CohereRateLimitError(CohereError):
+ """Raised when rate limit is exceeded"""
+
+ pass
diff --git a/autogen/oai/completion.py b/autogen/oai/completion.py
index e3b01ee4dd8..5a62cde33df 100644
--- a/autogen/oai/completion.py
+++ b/autogen/oai/completion.py
@@ -741,7 +741,7 @@ def create(
"api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"api_type": "azure",
"base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
- "api_version": "2024-02-15-preview",
+ "api_version": "2024-02-01",
},
{
"model": "gpt-3.5-turbo",
diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
new file mode 100644
index 00000000000..33790c9851c
--- /dev/null
+++ b/autogen/oai/gemini.py
@@ -0,0 +1,485 @@
+"""Create an OpenAI-compatible client for Gemini features.
+
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "google",
+ "model": "gemini-pro",
+ "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
+ "safety_settings": [
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
+ ],
+ "top_p":0.5,
+ "max_tokens": 2048,
+ "temperature": 1.0,
+ "top_k": 5
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Resources:
+- https://ai.google.dev/docs
+- https://cloud.google.com/vertex-ai/docs/generative-ai/migrate-from-azure
+- https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/
+- https://ai.google.dev/api/python/google/generativeai/ChatSession
+"""
+
+from __future__ import annotations
+
+import base64
+import logging
+import os
+import random
+import re
+import time
+import warnings
+from io import BytesIO
+from typing import Any, Dict, List, Mapping, Union
+
+import google.generativeai as genai
+import requests
+import vertexai
+from google.ai.generativelanguage import Content, Part
+from google.api_core.exceptions import InternalServerError
+from google.auth.credentials import Credentials
+from openai.types.chat import ChatCompletion
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+from PIL import Image
+from vertexai.generative_models import Content as VertexAIContent
+from vertexai.generative_models import GenerativeModel
+from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
+from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
+from vertexai.generative_models import Part as VertexAIPart
+from vertexai.generative_models import SafetySetting as VertexAISafetySetting
+
+logger = logging.getLogger(__name__)
+
+
+class GeminiClient:
+ """Client for Google's Gemini API.
+
+ Please visit this [page](https://github.com/microsoft/autogen/issues/2387) for the roadmap of Gemini integration
+ of AutoGen.
+ """
+
+ # Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini
+ PARAMS_MAPPING = {
+ "max_tokens": "max_output_tokens",
+ # "n": "candidate_count", # Gemini supports only `n=1`
+ "stop_sequences": "stop_sequences",
+ "temperature": "temperature",
+ "top_p": "top_p",
+ "top_k": "top_k",
+ "max_output_tokens": "max_output_tokens",
+ }
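+    # Illustrative (hypothetical) translation: an AutoGen config carrying
+    # {"max_tokens": 2048, "temperature": 1.0, "top_k": 5} becomes a Gemini generation_config of
+    # {"max_output_tokens": 2048, "temperature": 1.0, "top_k": 5} via the comprehension in create().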
+
+ def _initialize_vertexai(self, **params):
+ if "google_application_credentials" in params:
+ # Path to JSON Keyfile
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
+ vertexai_init_args = {}
+ if "project_id" in params:
+ vertexai_init_args["project"] = params["project_id"]
+ if "location" in params:
+ vertexai_init_args["location"] = params["location"]
+ if "credentials" in params:
+ assert isinstance(
+ params["credentials"], Credentials
+ ), "Object type google.auth.credentials.Credentials is expected!"
+ vertexai_init_args["credentials"] = params["credentials"]
+ if vertexai_init_args:
+ vertexai.init(**vertexai_init_args)
+
+ def __init__(self, **kwargs):
+        """Uses either the api_key for authentication from the LLM config
+ (specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
+ or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
+ where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
+ or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
+ then the default credentials will be used, which could be a personal account if the user is already authenticated in,
+ like in Google Cloud Shell.
+
+ Args:
+ api_key (str): The API key for using Gemini.
+ credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
+ google_application_credentials (str): Path to the JSON service account key file of the service account.
+ Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
+ can also be set instead of using this argument.
+ project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
+ location (str): Compute region to be used, like 'us-west1'.
+ This parameter is only valid in case no API key is specified.
+ """
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
+ if self.api_key is None:
+ self.use_vertexai = True
+ self._initialize_vertexai(**kwargs)
+ else:
+ self.use_vertexai = False
+ else:
+ self.use_vertexai = False
+ if not self.use_vertexai:
+ assert ("project_id" not in kwargs) and (
+ "location" not in kwargs
+ ), "Google Cloud project and compute location cannot be set when using an API Key!"
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ # ... # pragma: no cover
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def create(self, params: Dict) -> ChatCompletion:
+ if self.use_vertexai:
+ self._initialize_vertexai(**params)
+ else:
+ assert ("project_id" not in params) and (
+ "location" not in params
+ ), "Google Cloud project and compute location cannot be set when using an API Key!"
+ model_name = params.get("model", "gemini-pro")
+ if not model_name:
+ raise ValueError(
+ "Please provide a model name for the Gemini Client. "
+ "You can configure it in the OAI Config List file. "
+ "See this [LLM configuration tutorial](https://microsoft.github.io/autogen/docs/topics/llm_configuration/) for more details."
+ )
+
+ params.get("api_type", "google") # not used
+ messages = params.get("messages", [])
+ stream = params.get("stream", False)
+ n_response = params.get("n", 1)
+ system_instruction = params.get("system_instruction", None)
+ response_validation = params.get("response_validation", True)
+
+ generation_config = {
+ gemini_term: params[autogen_term]
+ for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
+ if autogen_term in params
+ }
+ if self.use_vertexai:
+ safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
+ else:
+ safety_settings = params.get("safety_settings", {})
+
+ if stream:
+ warnings.warn(
+ "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
+ UserWarning,
+ )
+
+ if n_response > 1:
+ warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)
+
+ if "vision" not in model_name:
+ # A. create and call the chat model.
+ gemini_messages = self._oai_messages_to_gemini_messages(messages)
+ if self.use_vertexai:
+ model = GenerativeModel(
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
+ )
+ chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
+ else:
+ # we use chat model by default
+ model = genai.GenerativeModel(
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
+ )
+ genai.configure(api_key=self.api_key)
+ chat = model.start_chat(history=gemini_messages[:-1])
+ max_retries = 5
+ for attempt in range(max_retries):
+ ans = None
+ try:
+ response = chat.send_message(
+ gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
+ )
+ except InternalServerError:
+ delay = 5 * (2**attempt)
+ warnings.warn(
+ f"InternalServerError `500` occurs when calling Gemini's chat model. Retry in {delay} seconds...",
+ UserWarning,
+ )
+ time.sleep(delay)
+ except Exception as e:
+ raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}")
+ else:
+ # `ans = response.text` is unstable. Use the following code instead.
+ ans: str = chat.history[-1].parts[0].text
+ break
+
+ if ans is None:
+                raise RuntimeError(f"Failed to get a response from Google AI after retrying {attempt + 1} times.")
+
+ prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
+ completion_tokens = model.count_tokens(ans).total_tokens
+ elif model_name == "gemini-pro-vision":
+ # B. handle the vision model
+ if self.use_vertexai:
+ model = GenerativeModel(
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
+ )
+ else:
+ model = genai.GenerativeModel(
+ model_name,
+ generation_config=generation_config,
+ safety_settings=safety_settings,
+ system_instruction=system_instruction,
+ )
+ genai.configure(api_key=self.api_key)
+ # Gemini's vision model does not support chat history yet
+ # chat = model.start_chat(history=gemini_messages[:-1])
+ # response = chat.send_message(gemini_messages[-1].parts)
+ user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
+ if len(messages) > 2:
+                warnings.warn(
+                    "Warning: Gemini's vision model does not support chat history yet. "
+                    "We only use the last message as the prompt.",
+                    UserWarning,
+                )
+
+ response = model.generate_content(user_message, stream=stream)
+ # ans = response.text
+ if self.use_vertexai:
+ ans: str = response.candidates[0].content.parts[0].text
+ else:
+ ans: str = response._result.candidates[0].content.parts[0].text
+
+ prompt_tokens = model.count_tokens(user_message).total_tokens
+ completion_tokens = model.count_tokens(ans).total_tokens
+
+ # 3. convert output
+ message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
+ choices = [Choice(finish_reason="stop", index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=str(random.randint(0, 1000)),
+ model=model_name,
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
+ cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
+ )
+
+ return response_oai
+
+ def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
+ """Convert content from OAI format to Gemini format"""
+ rst = []
+ if isinstance(content, str):
+ if content == "":
+ content = "empty" # Empty content is not allowed.
+ if self.use_vertexai:
+ rst.append(VertexAIPart.from_text(content))
+ else:
+ rst.append(Part(text=content))
+ return rst
+
+ assert isinstance(content, list)
+
+ for msg in content:
+ if isinstance(msg, dict):
+ assert "type" in msg, f"Missing 'type' field in message: {msg}"
+ if msg["type"] == "text":
+ if self.use_vertexai:
+ rst.append(VertexAIPart.from_text(text=msg["text"]))
+ else:
+ rst.append(Part(text=msg["text"]))
+ elif msg["type"] == "image_url":
+ if self.use_vertexai:
+ img_url = msg["image_url"]["url"]
+ re.match(r"data:image/(?:png|jpeg);base64,", img_url)
+ img = get_image_data(img_url, use_b64=False)
+ # image/png works with jpeg as well
+ img_part = VertexAIPart.from_data(img, mime_type="image/png")
+ rst.append(img_part)
+ else:
+ b64_img = get_image_data(msg["image_url"]["url"])
+ img = _to_pil(b64_img)
+ rst.append(img)
+ else:
+ raise ValueError(f"Unsupported message type: {msg['type']}")
+ else:
+ raise ValueError(f"Unsupported message type: {type(msg)}")
+ return rst
+
+ def _concat_parts(self, parts: List[Part]) -> List:
+ """Concatenate parts with the same type.
+ If two adjacent parts both have the "text" attribute, then it will be joined into one part.
+ """
+ if not parts:
+ return []
+
+ concatenated_parts = []
+ previous_part = parts[0]
+
+ for current_part in parts[1:]:
+ if previous_part.text != "":
+ if self.use_vertexai:
+ previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
+ else:
+ previous_part.text += current_part.text
+ else:
+ concatenated_parts.append(previous_part)
+ previous_part = current_part
+
+ if previous_part.text == "":
+ if self.use_vertexai:
+ previous_part = VertexAIPart.from_text("empty")
+ else:
+ previous_part.text = "empty" # Empty content is not allowed.
+ concatenated_parts.append(previous_part)
+
+ return concatenated_parts
+
+ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
+ """Convert messages from OAI format to Gemini format.
+ Make sure the "user" role and "model" role are interleaved.
+ Also, make sure the last item is from the "user" role.
+ """
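+        # Illustrative (hypothetical) sequence: [user, user, assistant, assistant, user] becomes three
+        # Content objects with roles ["user", "model", "user"]; adjacent same-role parts are merged.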
+ prev_role = None
+ rst = []
+ curr_parts = []
+ for i, message in enumerate(messages):
+ parts = self._oai_content_to_gemini_content(message["content"])
+ role = "user" if message["role"] in ["user", "system"] else "model"
+ if (prev_role is None) or (role == prev_role):
+ curr_parts += parts
+ elif role != prev_role:
+ if self.use_vertexai:
+ rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
+ else:
+ rst.append(Content(parts=curr_parts, role=prev_role))
+ curr_parts = parts
+ prev_role = role
+
+ # handle the last message
+ if self.use_vertexai:
+ rst.append(VertexAIContent(parts=curr_parts, role=role))
+ else:
+ rst.append(Content(parts=curr_parts, role=role))
+
+        # Gemini is strict about the order of roles, such that
+ # 1. The messages should be interleaved between user and model.
+ # 2. The last message must be from the user role.
+ # We add a dummy message "continue" if the last role is not the user.
+ if rst[-1].role != "user":
+ if self.use_vertexai:
+ rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user"))
+ else:
+ rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user"))
+
+ return rst
+
+ @staticmethod
+ def _to_vertexai_safety_settings(safety_settings):
+ """Convert safety settings to VertexAI format if needed,
+ like when specifying them in the OAI_CONFIG_LIST
+ """
+ if isinstance(safety_settings, list) and all(
+ [
+ isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
+ for safety_setting in safety_settings
+ ]
+ ):
+ vertexai_safety_settings = []
+ for safety_setting in safety_settings:
+ if safety_setting["category"] not in VertexAIHarmCategory.__members__:
+ invalid_category = safety_setting["category"]
+ logger.error(f"Safety setting category {invalid_category} is invalid")
+ elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
+ invalid_threshold = safety_setting["threshold"]
+ logger.error(f"Safety threshold {invalid_threshold} is invalid")
+ else:
+ vertexai_safety_setting = VertexAISafetySetting(
+ category=safety_setting["category"],
+ threshold=safety_setting["threshold"],
+ )
+ vertexai_safety_settings.append(vertexai_safety_setting)
+ return vertexai_safety_settings
+ else:
+ return safety_settings
+
+
+def _to_pil(data: str) -> Image.Image:
+ """
+ Converts a base64 encoded image data string to a PIL Image object.
+
+ This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
+ and finally creates and returns a PIL Image object from the BytesIO object.
+
+ Parameters:
+ data (str): The base64 encoded image data string.
+
+ Returns:
+ Image.Image: The PIL Image object created from the input data.
+ """
+ return Image.open(BytesIO(base64.b64decode(data)))
+
+
+def get_image_data(image_file: str, use_b64=True) -> bytes:
+ if image_file.startswith("http://") or image_file.startswith("https://"):
+ response = requests.get(image_file)
+ content = response.content
+ elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
+ return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
+ else:
+ image = Image.open(image_file).convert("RGB")
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ content = buffered.getvalue()
+
+ if use_b64:
+ return base64.b64encode(content).decode("utf-8")
+ else:
+ return content
+
+
+def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
+ if "1.5" in model_name or "gemini-experimental" in model_name:
+ # "gemini-1.5-pro-preview-0409"
+ # Cost is $7 per million input tokens and $21 per million output tokens
+ return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
+
+ if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
+        warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro pricing.", UserWarning)
+
+ # Cost is $0.5 per million input tokens and $1.5 per million output tokens
+ return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
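+
+# Worked example (hypothetical usage): 10,000 prompt tokens and 2,000 completion tokens on a Gemini 1.5
+# model cost 7.0 * 10000 / 1e6 + 21.0 * 2000 / 1e6 = 0.112 USD.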
diff --git a/autogen/oai/groq.py b/autogen/oai/groq.py
new file mode 100644
index 00000000000..d2abe5116a2
--- /dev/null
+++ b/autogen/oai/groq.py
@@ -0,0 +1,282 @@
+"""Create an OpenAI-compatible client using Groq's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "groq",
+ "model": "mixtral-8x7b-32768",
+ "api_key": os.environ.get("GROQ_API_KEY")
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Groq's python library using: pip install --upgrade groq
+
+Resources:
+- https://console.groq.com/docs/quickstart
+"""
+
+from __future__ import annotations
+
+import copy
+import os
+import time
+import warnings
+from typing import Any, Dict, List
+
+from groq import Groq, Stream
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
+GROQ_PRICING_1K = {
+ "llama3-70b-8192": (0.00059, 0.00079),
+ "mixtral-8x7b-32768": (0.00024, 0.00024),
+ "llama3-8b-8192": (0.00005, 0.00008),
+ "gemma-7b-it": (0.00007, 0.00007),
+}
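+# e.g. (hypothetical check) "llama3-70b-8192" is listed at $0.59 / $0.79 per million tokens,
+# stored here as 0.00059 / 0.00079 per thousand tokens.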
+
+
+class GroqClient:
+ """Client for Groq's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("GROQ_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ # ... # pragma: no cover
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ groq_params = {}
+
+ # Check that we have what we need to use Groq's API
+ # We won't enforce the available models as they are likely to change
+ groq_params["model"] = params.get("model", None)
+ assert groq_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."
+
+ # Validate allowed Groq parameters
+ # https://console.groq.com/docs/api-reference#chat
+ groq_params["frequency_penalty"] = validate_parameter(
+ params, "frequency_penalty", (int, float), True, None, (-2, 2), None
+ )
+ groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ groq_params["presence_penalty"] = validate_parameter(
+ params, "presence_penalty", (int, float), True, None, (-2, 2), None
+ )
+ groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+ groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
+ groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
+ groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+
+ # Groq parameters not supported by their models yet, ignoring
+ # logit_bias, logprobs, top_logprobs
+
+ # Groq parameters we are ignoring:
+ # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
+ # parallel_tool_calls (defaults to True), stop
+ # function_call (deprecated), functions (deprecated)
+ # tool_choice (none if no tools, auto if there are tools)
+
+ return groq_params
+
+ def create(self, params: Dict) -> ChatCompletion:
+
+ messages = params.get("messages", [])
+
+ # Convert AutoGen messages to Groq messages
+ groq_messages = oai_messages_to_groq_messages(messages)
+
+ # Parse parameters to the Groq API's parameters
+ groq_params = self.parse_params(params)
+
+ # Add tools to the call if we have them and aren't hiding them
+ if "tools" in params:
+ hide_tools = validate_parameter(
+ params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+ )
+ if not should_hide_tools(groq_messages, params["tools"], hide_tools):
+ groq_params["tools"] = params["tools"]
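+        # Illustrative (hypothetical) config: setting "hide_tools": "if_any_run" stops re-sending tool
+        # definitions once any tool result appears in the conversation.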
+
+ groq_params["messages"] = groq_messages
+
+ # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
+ client = Groq(api_key=self.api_key, max_retries=5)
+
+ # Token counts will be returned
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+ # Streaming tool call recommendations
+ streaming_tool_calls = []
+
+ ans = None
+ try:
+ response = client.chat.completions.create(**groq_params)
+ except Exception as e:
+ raise RuntimeError(f"Groq exception occurred: {e}")
+ else:
+
+ if groq_params["stream"]:
+ # Read in the chunks as they stream, taking in tool_calls which may be across
+ # multiple chunks if more than one suggested
+ ans = ""
+ for chunk in response:
+ ans = ans + (chunk.choices[0].delta.content or "")
+
+ if chunk.choices[0].delta.tool_calls:
+ # We have a tool call recommendation
+ for tool_call in chunk.choices[0].delta.tool_calls:
+ streaming_tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ },
+ type="function",
+ )
+ )
+
+ if chunk.choices[0].finish_reason:
+ prompt_tokens = chunk.x_groq.usage.prompt_tokens
+ completion_tokens = chunk.x_groq.usage.completion_tokens
+ total_tokens = chunk.x_groq.usage.total_tokens
+ else:
+ # Non-streaming finished
+ ans: str = response.choices[0].message.content
+
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
+ total_tokens = response.usage.total_tokens
+
+ if response is not None:
+
+ if isinstance(response, Stream):
+ # Streaming response
+ if chunk.choices[0].finish_reason == "tool_calls":
+ groq_finish = "tool_calls"
+ tool_calls = streaming_tool_calls
+ else:
+ groq_finish = "stop"
+ tool_calls = None
+
+ response_content = ans
+ response_id = chunk.id
+ else:
+ # Non-streaming response
+ # If we have tool calls as the response, populate completed tool calls for our return OAI response
+ if response.choices[0].finish_reason == "tool_calls":
+ groq_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in response.choices[0].message.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+ type="function",
+ )
+ )
+ else:
+ groq_finish = "stop"
+ tool_calls = None
+
+ response_content = response.choices[0].message.content
+ response_id = response.id
+ else:
+ raise RuntimeError("Failed to get response from Groq after retrying 5 times.")
+
+ # 3. convert output
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=response_content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=groq_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response_id,
+ model=groq_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
+ )
+
+ return response_oai
+
+
+def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
+ """Convert messages from OAI format to Groq's format.
+ We correct for any specific role orders and types.
+ """
+
+ groq_messages = copy.deepcopy(messages)
+
+ # Remove the name field
+ for message in groq_messages:
+ if "name" in message:
+ message.pop("name", None)
+
+ return groq_messages
+
+
+def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+ """Calculate the cost of the completion using the Groq pricing."""
+ total = 0.0
+
+ if model in GROQ_PRICING_1K:
+ input_cost_per_k, output_cost_per_k = GROQ_PRICING_1K[model]
+ input_cost = (input_tokens / 1000) * input_cost_per_k
+ output_cost = (output_tokens / 1000) * output_cost_per_k
+ total = input_cost + output_cost
+ else:
+ warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
+
+ return total
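For reference, a minimal sketch of how the `GroqClient` above is typically wired into an agent through a config list entry. The model name, the tuning values, and the `hide_tools` setting are illustrative assumptions; any parameter validated in `parse_params` can sit alongside `model` and `api_key` in the entry.

```python
import os

import autogen

# Minimal sketch (illustrative values): extra keys in the config entry are passed
# through to GroqClient.create() and validated by parse_params().
llm_config = {
    "config_list": [
        {
            "api_type": "groq",
            "model": "llama3-70b-8192",           # assumed model name
            "api_key": os.environ.get("GROQ_API_KEY"),
            "temperature": 0.5,                   # validated against (0, 2)
            "max_tokens": 1024,                   # validated against (0, None)
            "hide_tools": "if_any_run",           # one of: never, if_all_run, if_any_run
        }
    ]
}

assistant = autogen.AssistantAgent("assistant", llm_config=llm_config)
```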
diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py
new file mode 100644
index 00000000000..10d0f926ffb
--- /dev/null
+++ b/autogen/oai/mistral.py
@@ -0,0 +1,273 @@
+"""Create an OpenAI-compatible client using Mistral.AI's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "mistral",
+ "model": "open-mixtral-8x22b",
+ "api_key": os.environ.get("MISTRAL_API_KEY")
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Mistral.AI python library using: pip install --upgrade mistralai
+
+Resources:
+- https://docs.mistral.ai/getting-started/quickstart/
+
+NOTE: Requires mistralai package version >= 1.0.1
+"""
+
+import inspect
+import json
+import os
+import time
+import warnings
+from typing import Any, Dict, List, Union
+
+# Mistral libraries
+# pip install mistralai
+from mistralai import (
+ AssistantMessage,
+ Function,
+ FunctionCall,
+ Mistral,
+ SystemMessage,
+ ToolCall,
+ ToolMessage,
+ UserMessage,
+)
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+
+class MistralAIClient:
+ """Client for Mistral.AI's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set)
+ """
+
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("MISTRAL_API_KEY", None)
+
+ assert (
+ self.api_key
+ ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."
+
+ self._client = Mistral(api_key=self.api_key)
+
+ def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]:
+ """Retrieve the messages from the response."""
+
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ mistral_params = {}
+
+ # 1. Validate models
+ mistral_params["model"] = params.get("model", None)
+ assert mistral_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Mistral.ai model to use."
+
+ # 2. Validate allowed Mistral.AI parameters
+ mistral_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 0.7, None, None)
+ mistral_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+ mistral_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ mistral_params["safe_prompt"] = validate_parameter(
+ params, "safe_prompt", bool, False, False, None, [True, False]
+ )
+ mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, False, None)
+
+ # TODO
+ if params.get("stream", False):
+ warnings.warn(
+ "Streaming is not currently supported, streaming will be disabled.",
+ UserWarning,
+ )
+
+ # 3. Convert messages to Mistral format
+ mistral_messages = []
+ tool_call_ids = {} # tool call ids to function name mapping
+ for message in params["messages"]:
+ if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
+ # Convert OAI ToolCall to Mistral ToolCall
+ mistral_messages_tools = []
+ for toolcall in message["tool_calls"]:
+ mistral_messages_tools.append(
+ ToolCall(
+ id=toolcall["id"],
+ function=FunctionCall(
+ name=toolcall["function"]["name"],
+ arguments=json.loads(toolcall["function"]["arguments"]),
+ ),
+ )
+ )
+
+ mistral_messages.append(AssistantMessage(content="", tool_calls=mistral_messages_tools))
+
+ # Map tool call id to the function name
+ for tool_call in message["tool_calls"]:
+ tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
+
+ elif message["role"] == "system":
+ if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant":
+ # System messages can't appear after an Assistant message, so use a UserMessage
+ mistral_messages.append(UserMessage(content=message["content"]))
+ else:
+ mistral_messages.append(SystemMessage(content=message["content"]))
+ elif message["role"] == "assistant":
+ mistral_messages.append(AssistantMessage(content=message["content"]))
+ elif message["role"] == "user":
+ mistral_messages.append(UserMessage(content=message["content"]))
+
+ elif message["role"] == "tool":
+ # Indicates the result of a tool call, the name is the function name called
+ mistral_messages.append(
+ ToolMessage(
+ name=tool_call_ids[message["tool_call_id"]],
+ content=message["content"],
+ tool_call_id=message["tool_call_id"],
+ )
+ )
+ else:
+ warnings.warn(f"Unknown message role {message['role']}", UserWarning)
+
+ # 4. Last message needs to be user or tool, if not, add a "please continue" message
+ if not isinstance(mistral_messages[-1], UserMessage) and not isinstance(mistral_messages[-1], ToolMessage):
+ mistral_messages.append(UserMessage(content="Please continue."))
+
+ mistral_params["messages"] = mistral_messages
+
+ # 5. Add tools to the call if we have them and aren't hiding them
+ if "tools" in params:
+ hide_tools = validate_parameter(
+ params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+ )
+ if not should_hide_tools(params["messages"], params["tools"], hide_tools):
+ mistral_params["tools"] = tool_def_to_mistral(params["tools"])
+
+ return mistral_params
+
+ def create(self, params: Dict[str, Any]) -> ChatCompletion:
+ # 1. Parse parameters to Mistral.AI API's parameters
+ mistral_params = self.parse_params(params)
+
+ # 2. Call Mistral.AI API
+ mistral_response = self._client.chat.complete(**mistral_params)
+ # TODO: Handle streaming
+
+ # 3. Convert Mistral response to OAI compatible format
+ if mistral_response.choices[0].finish_reason == "tool_calls":
+ mistral_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in mistral_response.choices[0].message.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+ type="function",
+ )
+ )
+ else:
+ mistral_finish = "stop"
+ tool_calls = None
+
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=mistral_response.choices[0].message.content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=mistral_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=mistral_response.id,
+ model=mistral_response.model,
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=mistral_response.usage.prompt_tokens,
+ completion_tokens=mistral_response.usage.completion_tokens,
+ total_tokens=mistral_response.usage.prompt_tokens + mistral_response.usage.completion_tokens,
+ ),
+ cost=calculate_mistral_cost(
+ mistral_response.usage.prompt_tokens, mistral_response.usage.completion_tokens, mistral_response.model
+ ),
+ )
+
+ return response_oai
+
+ @staticmethod
+ def get_usage(response: ChatCompletion) -> Dict:
+ return {
+ "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
+ "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
+ "total_tokens": (
+ response.usage.prompt_tokens + response.usage.completion_tokens if response.usage is not None else 0
+ ),
+ "cost": response.cost if hasattr(response, "cost") else 0,
+ "model": response.model,
+ }
+
+
+def tool_def_to_mistral(tool_definitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Converts AutoGen tool definition to a mistral tool format"""
+
+ mistral_tools = []
+
+ for autogen_tool in tool_definitions:
+ mistral_tool = {
+ "type": "function",
+ "function": Function(
+ name=autogen_tool["function"]["name"],
+ description=autogen_tool["function"]["description"],
+ parameters=autogen_tool["function"]["parameters"],
+ ),
+ }
+
+ mistral_tools.append(mistral_tool)
+
+ return mistral_tools
+
+
+def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
+ """Calculate the cost of the mistral response."""
+
+ # Prices per 1 thousand tokens
+ # https://mistral.ai/technology/
+ model_cost_map = {
+ "open-mistral-7b": {"input": 0.00025, "output": 0.00025},
+ "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007},
+ "open-mixtral-8x22b": {"input": 0.002, "output": 0.006},
+ "mistral-small-latest": {"input": 0.001, "output": 0.003},
+ "mistral-medium-latest": {"input": 0.00275, "output": 0.0081},
+ "mistral-large-latest": {"input": 0.0003, "output": 0.0003},
+ "mistral-large-2407": {"input": 0.0003, "output": 0.0003},
+ "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003},
+ "codestral-2405": {"input": 0.001, "output": 0.003},
+ }
+
+ # Ensure we have the model they are using and return the total cost
+ if model_name in model_cost_map:
+ costs = model_cost_map[model_name]
+
+ return (input_tokens * costs["input"] / 1000) + (output_tokens * costs["output"] / 1000)
+ else:
+ warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
+ return 0
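As a quick sanity check of the pricing table above, the cost helper is a straight per-1K-token multiplication; the token counts below are made up.

```python
from autogen.oai.mistral import calculate_mistral_cost

# open-mixtral-8x22b is priced above at $0.002 / 1K input and $0.006 / 1K output tokens.
cost = calculate_mistral_cost(input_tokens=1200, output_tokens=300, model_name="open-mixtral-8x22b")

# Manual check: (1200 * 0.002 / 1000) + (300 * 0.006 / 1000) = 0.0024 + 0.0018 = 0.0042
print(cost)  # 0.0042
```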
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index 80be557eadd..41b94324118 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -1,26 +1,42 @@
+import importlib.metadata
import json
import logging
import os
import re
import tempfile
+import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from dotenv import find_dotenv, load_dotenv
from openai import OpenAI
from openai.types.beta.assistant import Assistant
-
-NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
-DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
+from packaging.version import parse
+
+NON_CACHE_KEY = [
+ "api_key",
+ "base_url",
+ "api_type",
+ "api_version",
+ "azure_ad_token",
+ "azure_ad_token_provider",
+ "credentials",
+]
+DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
- # https://openai.com/pricing
+ # https://openai.com/api/pricing/
+ # gpt-4o
+ "gpt-4o": (0.005, 0.015),
+ "gpt-4o-2024-05-13": (0.005, 0.015),
+ "gpt-4o-2024-08-06": (0.0025, 0.01),
# gpt-4-turbo
- "gpt-4-0125-preview": (0.01, 0.03),
- "gpt-4-1106-preview": (0.01, 0.03),
- "gpt-4-1106-vision-preview": (0.01, 0.03), # TODO: support vision pricing of images
+ "gpt-4-turbo-2024-04-09": (0.01, 0.03),
# gpt-4
"gpt-4": (0.03, 0.06),
"gpt-4-32k": (0.06, 0.12),
+ # gpt-4o-mini
+ "gpt-4o-mini": (0.000150, 0.000600),
+ "gpt-4o-mini-2024-07-18": (0.000150, 0.000600),
# gpt-3.5 turbo
"gpt-3.5-turbo": (0.0005, 0.0015), # default is 0125
"gpt-3.5-turbo-0125": (0.0005, 0.0015), # 16k
@@ -29,6 +45,9 @@
"davinci-002": 0.002,
"babbage-002": 0.0004,
# old model
+ "gpt-4-0125-preview": (0.01, 0.03),
+ "gpt-4-1106-preview": (0.01, 0.03),
+ "gpt-4-1106-vision-preview": (0.01, 0.03), # TODO: support vision pricing of images
"gpt-3.5-turbo-1106": (0.001, 0.002),
"gpt-3.5-turbo-0613": (0.0015, 0.002),
# "gpt-3.5-turbo-16k": (0.003, 0.004),
@@ -89,7 +108,7 @@ def is_valid_api_key(api_key: str) -> bool:
Returns:
bool: A boolean that indicates if input is valid OpenAI API key.
"""
- api_key_re = re.compile(r"^sk-[A-Za-z0-9]{32,}$")
+ api_key_re = re.compile(r"^sk-([A-Za-z0-9]+(-+[A-Za-z0-9]+)*-)?[A-Za-z0-9]{32,}$")
return bool(re.fullmatch(api_key_re, api_key))
@@ -120,7 +139,7 @@ def get_config_list(
# Optionally, define the API type and version if they are common for all keys
api_type = 'azure'
- api_version = '2024-02-15-preview'
+ api_version = '2024-02-01'
# Call the get_config_list function to get a list of configuration dictionaries
config_list = get_config_list(api_keys, base_urls, api_type, api_version)
@@ -372,11 +391,10 @@ def config_list_gpt4_gpt35(
def filter_config(
config_list: List[Dict[str, Any]],
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
+ exclude: bool = False,
) -> List[Dict[str, Any]]:
- """
- This function filters `config_list` by checking each configuration dictionary against the
- criteria specified in `filter_dict`. A configuration dictionary is retained if for every
- key in `filter_dict`, see example below.
+ """This function filters `config_list` by checking each configuration dictionary against the criteria specified in
+ `filter_dict`. A configuration dictionary is retained if for every key in `filter_dict`, see example below.
Args:
config_list (list of dict): A list of configuration dictionaries to be filtered.
@@ -387,71 +405,68 @@ def filter_config(
when it is found in the list of acceptable values. If the configuration's
field's value is a list, then a match occurs if there is a non-empty
intersection with the acceptable values.
-
-
+ exclude (bool): If False (the default value), configs that match the filter will be included in the returned
+ list. If True, configs that match the filter will be excluded from the returned list.
Returns:
list of dict: A list of configuration dictionaries that meet all the criteria specified
in `filter_dict`.
Example:
- ```python
- # Example configuration list with various models and API types
- configs = [
- {'model': 'gpt-3.5-turbo'},
- {'model': 'gpt-4'},
- {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
- {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
- ]
-
- # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
- # that are also using the 'azure' API type
- filter_criteria = {
- 'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
- 'api_type': ['azure'] # Only accept configurations for 'azure' API type
- }
-
- # Apply the filter to the configuration list
- filtered_configs = filter_config(configs, filter_criteria)
-
- # The resulting `filtered_configs` will be:
- # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
-
-
- # Define a filter to select a given tag
- filter_criteria = {
- 'tags': ['gpt35_turbo'],
- }
-
- # Apply the filter to the configuration list
- filtered_configs = filter_config(configs, filter_criteria)
-
- # The resulting `filtered_configs` will be:
- # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
- ```
-
+ ```python
+ # Example configuration list with various models and API types
+ configs = [
+ {'model': 'gpt-3.5-turbo'},
+ {'model': 'gpt-4'},
+ {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
+ {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
+ ]
+ # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
+ # that are also using the 'azure' API type
+ filter_criteria = {
+ 'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
+ 'api_type': ['azure'] # Only accept configurations for 'azure' API type
+ }
+ # Apply the filter to the configuration list
+ filtered_configs = filter_config(configs, filter_criteria)
+ # The resulting `filtered_configs` will be:
+ # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
+ # Define a filter to select a given tag
+ filter_criteria = {
+ 'tags': ['gpt35_turbo'],
+ }
+ # Apply the filter to the configuration list
+ filtered_configs = filter_config(configs, filter_criteria)
+ # The resulting `filtered_configs` will be:
+ # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
+ ```
Note:
- If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
- If a configuration dictionary in `config_list` does not contain a key specified in `filter_dict`,
it is considered a non-match and is excluded from the result.
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
dictionaries that do not have that key will also be considered a match.
- """
- def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
- if isinstance(config_value, list):
- return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
- else:
- return config_value in acceptable_values
+ """
if filter_dict:
- config_list = [
- config
- for config in config_list
- if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
+ return [
+ item
+ for item in config_list
+ if all(_satisfies_criteria(item.get(key), values) != exclude for key, values in filter_dict.items())
]
return config_list
+def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
+ if value is None:
+ return False
+
+ if isinstance(value, list):
+ return bool(set(value) & set(criteria_values)) # Non-empty intersection
+ else:
+ return value in criteria_values
+
+
def config_list_from_json(
env_or_file: str,
file_location: Optional[str] = "",
@@ -674,3 +689,114 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
if assistant.name == name:
candidate_assistants.append(assistant)
return candidate_assistants
+
+
+def detect_gpt_assistant_api_version() -> str:
+ """Detect the openai assistant API version"""
+ oai_version = importlib.metadata.version("openai")
+ if parse(oai_version) < parse("1.21"):
+ return "v1"
+ else:
+ return "v2"
+
+
+def create_gpt_vector_store(client: OpenAI, name: str, file_ids: List[str]) -> Any:
+ """Create an OpenAI vector store for a GPT assistant"""
+
+ try:
+ vector_store = client.beta.vector_stores.create(name=name)
+ except Exception as e:
+ raise AttributeError(f"Failed to create vector store, please install the latest OpenAI python package: {e}")
+
+ # poll the status of the file batch for completion.
+ batch = client.beta.vector_stores.file_batches.create_and_poll(vector_store_id=vector_store.id, file_ids=file_ids)
+
+ if batch.status == "in_progress":
+ time.sleep(1)
+ logging.debug(f"file batch status: {batch.file_counts}")
+ batch = client.beta.vector_stores.file_batches.poll(vector_store_id=vector_store.id, batch_id=batch.id)
+
+ if batch.status == "completed":
+ return vector_store
+
+ raise ValueError(f"Failed to upload files to vector store {vector_store.id}:{batch.status}")
+
+
+def create_gpt_assistant(
+ client: OpenAI, name: str, instructions: str, model: str, assistant_config: Dict[str, Any]
+) -> Assistant:
+ """Create a openai gpt assistant"""
+
+ assistant_create_kwargs = {}
+ gpt_assistant_api_version = detect_gpt_assistant_api_version()
+ tools = assistant_config.get("tools", [])
+
+ if gpt_assistant_api_version == "v2":
+ tool_resources = assistant_config.get("tool_resources", {})
+ file_ids = assistant_config.get("file_ids")
+ if tool_resources.get("file_search") is not None and file_ids is not None:
+ raise ValueError(
+ "Cannot specify both `tool_resources['file_search']` tool and `file_ids` in the assistant config."
+ )
+
+ # Designed for backwards compatibility for the V1 API
+ # Instead of V1 AssistantFile, files are attached to Assistants using the tool_resources object.
+ for tool in tools:
+ if tool["type"] == "retrieval":
+ tool["type"] = "file_search"
+ if file_ids is not None:
+ # create a vector store for the file search tool
+ vs = create_gpt_vector_store(client, f"{name}-vectorstore", file_ids)
+ tool_resources["file_search"] = {
+ "vector_store_ids": [vs.id],
+ }
+ elif tool["type"] == "code_interpreter" and file_ids is not None:
+ tool_resources["code_interpreter"] = {
+ "file_ids": file_ids,
+ }
+
+ assistant_create_kwargs["tools"] = tools
+ if len(tool_resources) > 0:
+ assistant_create_kwargs["tool_resources"] = tool_resources
+ else:
+ # Forward compatibility is not supported
+ if "tool_resources" in assistant_config:
+ raise ValueError("`tool_resources` argument are not supported in the openai assistant V1 API.")
+ if any(tool["type"] == "file_search" for tool in tools):
+ raise ValueError(
+ "`file_search` tool are not supported in the openai assistant V1 API, please use `retrieval`."
+ )
+ assistant_create_kwargs["tools"] = tools
+ assistant_create_kwargs["file_ids"] = assistant_config.get("file_ids", [])
+
+ logging.info(f"Creating assistant with config: {assistant_create_kwargs}")
+ return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)
+
+
+def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Dict[str, Any]) -> Assistant:
+ """Update openai gpt assistant"""
+
+ gpt_assistant_api_version = detect_gpt_assistant_api_version()
+ assistant_update_kwargs = {}
+
+ if assistant_config.get("tools") is not None:
+ assistant_update_kwargs["tools"] = assistant_config["tools"]
+
+ if assistant_config.get("instructions") is not None:
+ assistant_update_kwargs["instructions"] = assistant_config["instructions"]
+
+ if gpt_assistant_api_version == "v2":
+ if assistant_config.get("tool_resources") is not None:
+ assistant_update_kwargs["tool_resources"] = assistant_config["tool_resources"]
+ else:
+ if assistant_config.get("file_ids") is not None:
+ assistant_update_kwargs["file_ids"] = assistant_config["file_ids"]
+
+ return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)
+
+
+def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
+ if isinstance(config_value, list):
+ return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
+ else:
+ return config_value in acceptable_values
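A short illustration of the new `exclude` flag on `filter_config`; the configs are toy entries in the same spirit as the docstring example.

```python
from autogen.oai.openai_utils import filter_config

configs = [
    {"model": "gpt-4"},
    {"model": "gpt-3.5-turbo"},
    {"model": "gpt-3.5-turbo", "api_type": "azure"},
]

# Default (exclude=False): keep entries matching the filter.
azure_only = filter_config(configs, {"api_type": ["azure"]})
# -> [{'model': 'gpt-3.5-turbo', 'api_type': 'azure'}]

# New in this change (exclude=True): drop entries matching the filter instead.
non_azure = filter_config(configs, {"api_type": ["azure"]}, exclude=True)
# -> [{'model': 'gpt-4'}, {'model': 'gpt-3.5-turbo'}]
```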
diff --git a/autogen/oai/together.py b/autogen/oai/together.py
new file mode 100644
index 00000000000..bbbe851ba77
--- /dev/null
+++ b/autogen/oai/together.py
@@ -0,0 +1,351 @@
+"""Create an OpenAI-compatible client using Together.AI's API.
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "together",
+ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+ "api_key": os.environ.get("TOGETHER_API_KEY")
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Together.AI python library using: pip install --upgrade together
+
+Resources:
+- https://docs.together.ai/docs/inference-python
+"""
+
+from __future__ import annotations
+
+import base64
+import copy
+import os
+import random
+import re
+import time
+import warnings
+from io import BytesIO
+from typing import Any, Dict, List, Mapping, Tuple, Union
+
+import requests
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+from PIL import Image
+from together import Together, error
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+
+class TogetherClient:
+ """Client for Together.AI's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Together.AI (or environment variable TOGETHER_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("TOGETHER_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please include the api_key in your config list entry for Together.AI or set the TOGETHER_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ # ... # pragma: no cover
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Together.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ together_params = {}
+
+ # Check that we have what we need to use Together.AI's API
+ together_params["model"] = params.get("model", None)
+ assert together_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Together.AI model to use."
+
+ # Validate allowed Together.AI parameters
+ # https://github.com/togethercomputer/together-python/blob/94ffb30daf0ac3e078be986af7228f85f79bde99/src/together/resources/completions.py#L44
+ together_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
+ together_params["stream"] = validate_parameter(params, "stream", bool, False, False, None, None)
+ together_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, None, None, None)
+ together_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+ together_params["top_k"] = validate_parameter(params, "top_k", int, True, None, None, None)
+ together_params["repetition_penalty"] = validate_parameter(
+ params, "repetition_penalty", float, True, None, None, None
+ )
+ together_params["presence_penalty"] = validate_parameter(
+ params, "presence_penalty", (int, float), True, None, (-2, 2), None
+ )
+ together_params["frequency_penalty"] = validate_parameter(
+ params, "frequency_penalty", (int, float), True, None, (-2, 2), None
+ )
+ together_params["min_p"] = validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None)
+ together_params["safety_model"] = validate_parameter(
+ params, "safety_model", str, True, None, None, None
+ ) # We won't enforce the available models as they are likely to change
+
+ # Check if they want to stream and use tools, which isn't currently supported (TODO)
+ if together_params["stream"] and "tools" in params:
+ warnings.warn(
+ "Streaming is not supported when using tools, streaming will be disabled.",
+ UserWarning,
+ )
+
+ together_params["stream"] = False
+
+ return together_params
+
+ def create(self, params: Dict) -> ChatCompletion:
+
+ messages = params.get("messages", [])
+
+ # Convert AutoGen messages to Together.AI messages
+ together_messages = oai_messages_to_together_messages(messages)
+
+ # Parse parameters to Together.AI API's parameters
+ together_params = self.parse_params(params)
+
+ # Add tools to the call if we have them and aren't hiding them
+ if "tools" in params:
+ hide_tools = validate_parameter(
+ params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+ )
+ if not should_hide_tools(together_messages, params["tools"], hide_tools):
+ together_params["tools"] = params["tools"]
+
+ together_params["messages"] = together_messages
+
+ # We use chat model by default
+ client = Together(api_key=self.api_key)
+
+ # Token counts will be returned
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+
+ max_retries = 5
+ for attempt in range(max_retries):
+ ans = None
+ try:
+ response = client.chat.completions.create(**together_params)
+ except Exception as e:
+ raise RuntimeError(f"Together.AI exception occurred: {e}")
+ else:
+
+ if together_params["stream"]:
+ # Read in the chunks as they stream
+ ans = ""
+ for chunk in response:
+ ans = ans + (chunk.choices[0].delta.content or "")
+
+ prompt_tokens = chunk.usage.prompt_tokens
+ completion_tokens = chunk.usage.completion_tokens
+ total_tokens = chunk.usage.total_tokens
+ else:
+ ans: str = response.choices[0].message.content
+
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
+ total_tokens = response.usage.total_tokens
+ break
+
+ if response is not None:
+ # If we have tool calls as the response, populate completed tool calls for our return OAI response
+ if response.choices[0].finish_reason == "tool_calls":
+ together_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in response.choices[0].message.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+ type="function",
+ )
+ )
+ else:
+ together_finish = "stop"
+ tool_calls = None
+
+ else:
+ raise RuntimeError(f"Failed to get response from Together.AI after retrying {attempt + 1} times.")
+
+ # 3. convert output
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=response.choices[0].message.content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=together_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=response.id,
+ model=together_params["model"],
+ created=int(time.time()),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ cost=calculate_together_cost(prompt_tokens, completion_tokens, together_params["model"]),
+ )
+
+ return response_oai
+
+
+def oai_messages_to_together_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
+ """Convert messages from OAI format to Together.AI format.
+ We correct for any specific role orders and types.
+ """
+
+ together_messages = copy.deepcopy(messages)
+
+ # If we have a message with role='tool', which occurs when a function is executed, change it to 'user'
+ for msg in together_messages:
+ if "role" in msg and msg["role"] == "tool":
+ msg["role"] = "user"
+
+ return together_messages
+
+
+# MODELS AND COSTS
+chat_lang_code_model_sizes = {
+ "zero-one-ai/Yi-34B-Chat": 34,
+ "allenai/OLMo-7B-Instruct": 7,
+ "allenai/OLMo-7B-Twin-2T": 7,
+ "allenai/OLMo-7B": 7,
+ "Austism/chronos-hermes-13b": 13,
+ "deepseek-ai/deepseek-coder-33b-instruct": 33,
+ "deepseek-ai/deepseek-llm-67b-chat": 67,
+ "garage-bAInd/Platypus2-70B-instruct": 70,
+ "google/gemma-2b-it": 2,
+ "google/gemma-7b-it": 7,
+ "Gryphe/MythoMax-L2-13b": 13,
+ "lmsys/vicuna-13b-v1.5": 13,
+ "lmsys/vicuna-7b-v1.5": 7,
+ "codellama/CodeLlama-13b-Instruct-hf": 13,
+ "codellama/CodeLlama-34b-Instruct-hf": 34,
+ "codellama/CodeLlama-70b-Instruct-hf": 70,
+ "codellama/CodeLlama-7b-Instruct-hf": 7,
+ "meta-llama/Llama-2-70b-chat-hf": 70,
+ "meta-llama/Llama-2-13b-chat-hf": 13,
+ "meta-llama/Llama-2-7b-chat-hf": 7,
+ "meta-llama/Llama-3-8b-chat-hf": 8,
+ "meta-llama/Llama-3-70b-chat-hf": 70,
+ "mistralai/Mistral-7B-Instruct-v0.1": 7,
+ "mistralai/Mistral-7B-Instruct-v0.2": 7,
+ "mistralai/Mistral-7B-Instruct-v0.3": 7,
+ "NousResearch/Nous-Capybara-7B-V1p9": 7,
+ "NousResearch/Nous-Hermes-llama-2-7b": 7,
+ "NousResearch/Nous-Hermes-Llama2-13b": 13,
+ "NousResearch/Nous-Hermes-2-Yi-34B": 34,
+ "openchat/openchat-3.5-1210": 7,
+ "Open-Orca/Mistral-7B-OpenOrca": 7,
+ "Qwen/Qwen1.5-0.5B-Chat": 0.5,
+ "Qwen/Qwen1.5-1.8B-Chat": 1.8,
+ "Qwen/Qwen1.5-4B-Chat": 4,
+ "Qwen/Qwen1.5-7B-Chat": 7,
+ "Qwen/Qwen1.5-14B-Chat": 14,
+ "Qwen/Qwen1.5-32B-Chat": 32,
+ "Qwen/Qwen1.5-72B-Chat": 72,
+ "Qwen/Qwen1.5-110B-Chat": 110,
+ "Qwen/Qwen2-72B-Instruct": 72,
+ "snorkelai/Snorkel-Mistral-PairRM-DPO": 7,
+ "togethercomputer/alpaca-7b": 7,
+ "teknium/OpenHermes-2-Mistral-7B": 7,
+ "teknium/OpenHermes-2p5-Mistral-7B": 7,
+ "togethercomputer/Llama-2-7B-32K-Instruct": 7,
+ "togethercomputer/RedPajama-INCITE-Chat-3B-v1": 3,
+ "togethercomputer/RedPajama-INCITE-7B-Chat": 7,
+ "togethercomputer/StripedHyena-Nous-7B": 7,
+ "Undi95/ReMM-SLERP-L2-13B": 13,
+ "Undi95/Toppy-M-7B": 7,
+ "WizardLM/WizardLM-13B-V1.2": 13,
+ "upstage/SOLAR-10.7B-Instruct-v1.0": 11,
+}
+
+# Cost per million tokens based on model size (up to X billion parameters), e.g. up to 4B is $0.1/million
+chat_lang_code_model_costs = {4: 0.1, 8: 0.2, 21: 0.3, 41: 0.8, 80: 0.9, 110: 1.8}
+
+mixture_model_sizes = {
+ "cognitivecomputations/dolphin-2.5-mixtral-8x7b": 56,
+ "databricks/dbrx-instruct": 132,
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": 47,
+ "mistralai/Mixtral-8x22B-Instruct-v0.1": 141,
+ "NousResearch/Nous-Hermes-2-Mistral-7B-DPO": 7,
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 47,
+ "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT": 47,
+ "Snowflake/snowflake-arctic-instruct": 480,
+}
+
+# Cost per million tokens based on model size (up to X billion parameters), e.g. up to 56B is $0.6/million
+mixture_costs = {56: 0.6, 176: 1.2, 480: 2.4}
+
+
+def calculate_together_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
+ """Cost calculation for inference"""
+
+ if model_name in chat_lang_code_model_sizes or model_name in mixture_model_sizes:
+ cost_per_mil = 0
+
+ # Chat, Language, Code models
+ if model_name in chat_lang_code_model_sizes:
+ size_in_b = chat_lang_code_model_sizes[model_name]
+
+ for top_size in chat_lang_code_model_costs.keys():
+ if size_in_b <= top_size:
+ cost_per_mil = chat_lang_code_model_costs[top_size]
+ break
+
+ else:
+ # Mixture-of-experts
+ size_in_b = mixture_model_sizes[model_name]
+
+ for top_size in mixture_costs.keys():
+ if size_in_b <= top_size:
+ cost_per_mil = mixture_costs[top_size]
+ break
+
+ if cost_per_mil == 0:
+ warnings.warn("Model size doesn't align with cost structure.", UserWarning)
+
+ return cost_per_mil * ((input_tokens + output_tokens) / 1e6)
+
+ else:
+ # Model is not in our list of models, can't determine the cost
+ warnings.warn(
+ "The model isn't catered for costing, to apply costs you can use the 'price' key on your config_list.",
+ UserWarning,
+ )
+
+ return 0
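The Together.AI cost tables above bucket models by parameter count; here is a worked example of the tier lookup (token counts are arbitrary).

```python
from autogen.oai.together import calculate_together_cost

# meta-llama/Llama-3-8b-chat-hf is listed above as an 8B model, so it falls into the
# "up to 8B parameters" tier at $0.2 per million tokens.
cost = calculate_together_cost(input_tokens=10_000, output_tokens=2_000, model_name="meta-llama/Llama-3-8b-chat-hf")

# Manual check: (10_000 + 2_000) / 1e6 * 0.2 = 0.0024
print(cost)  # 0.0024
```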
diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py
index e83f8a80f36..9393903ec86 100644
--- a/autogen/retrieve_utils.py
+++ b/autogen/retrieve_utils.py
@@ -1,4 +1,5 @@
import glob
+import hashlib
import os
import re
from typing import Callable, List, Tuple, Union
@@ -156,7 +157,7 @@ def split_files_to_chunks(
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
custom_text_split_function: Callable = None,
-):
+) -> Tuple[List[str], List[dict]]:
"""Split a list of files into chunks of max_tokens."""
chunks = []
@@ -275,15 +276,22 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:
return webpage_text
+def _generate_file_name_from_url(url: str, max_length=255) -> str:
+ url_bytes = url.encode("utf-8")
+ hash = hashlib.blake2b(url_bytes).hexdigest()
+ parsed_url = urlparse(url)
+ file_name = os.path.basename(url)
+ file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
+ return file_name
+
+
def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
"""Download a file from a URL."""
if save_path is None:
save_path = "tmp/chromadb"
os.makedirs(save_path, exist_ok=True)
if os.path.isdir(save_path):
- filename = os.path.basename(url)
- if filename == "": # "www.example.com/"
- filename = url.split("/")[-2]
+ filename = _generate_file_name_from_url(url)
save_path = os.path.join(save_path, filename)
else:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
@@ -327,7 +335,7 @@ def create_vector_db_from_dir(
dir_path: Union[str, List[str]],
max_tokens: int = 4000,
client: API = None,
- db_path: str = "/tmp/chromadb.db",
+ db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
get_or_create: bool = False,
chunk_mode: str = "multi_lines",
@@ -347,7 +355,7 @@ def create_vector_db_from_dir(
dir_path (Union[str, List[str]]): the path to the directory, file, url or a list of them.
max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
client (Optional, API): the chromadb client. Default is None.
- db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
+ db_path (Optional, str): the path to the chromadb. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
will be returned if it already exists. Will raise ValueError if the collection already exists and get_or_create is False.
@@ -420,7 +428,7 @@ def query_vector_db(
query_texts: List[str],
n_results: int = 10,
client: API = None,
- db_path: str = "/tmp/chromadb.db",
+ db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
@@ -433,7 +441,7 @@ def query_vector_db(
query_texts (List[str]): the list of strings which will be used to query the vector db.
n_results (Optional, int): the number of results to return. Default is 10.
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
- db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
+ db_path (Optional, str): the path to the vector db. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
search_string (Optional, str): the search string. Only docs that contain an exact match of this string will be retrieved. Default is "".
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
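The change above replaces the URL basename with a hash-suffixed name so that URLs ending in `/` or sharing a basename no longer collide on disk. Below is a self-contained sketch of the same idea, simplified in that the hash suffix is fixed at 8 hex characters rather than trimmed to `max_length`.

```python
import hashlib
import os
from urllib.parse import urlparse

def example_file_name_from_url(url: str) -> str:
    # Mirror of _generate_file_name_from_url above: netloc + basename + short blake2b digest.
    digest = hashlib.blake2b(url.encode("utf-8")).hexdigest()
    parsed = urlparse(url)
    return f"{parsed.netloc}_{os.path.basename(url)}_{digest[:8]}"

print(example_file_name_from_url("https://www.example.com/"))            # www.example.com__<hash>
print(example_file_name_from_url("https://www.example.com/docs/a.pdf"))  # www.example.com_a.pdf_<hash>
```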
diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py
index 8c704b4383f..0fd7cc2fc8b 100644
--- a/autogen/runtime_logging.py
+++ b/autogen/runtime_logging.py
@@ -3,28 +3,53 @@
import logging
import sqlite3
import uuid
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, TypeVar, Union
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
-from autogen.logger.base_logger import LLMConfig
+from autogen.logger.base_logger import BaseLogger, LLMConfig
from autogen.logger.logger_factory import LoggerFactory
if TYPE_CHECKING:
- from autogen import ConversableAgent, OpenAIWrapper
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
+ from autogen.oai.anthropic import AnthropicClient
+ from autogen.oai.bedrock import BedrockClient
+ from autogen.oai.cohere import CohereClient
+ from autogen.oai.gemini import GeminiClient
+ from autogen.oai.groq import GroqClient
+ from autogen.oai.mistral import MistralAIClient
+ from autogen.oai.together import TogetherClient
logger = logging.getLogger(__name__)
autogen_logger = None
is_logging = False
-
-def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> str:
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def start(
+ logger: Optional[BaseLogger] = None,
+ logger_type: Literal["sqlite", "file"] = "sqlite",
+ config: Optional[Dict[str, Any]] = None,
+) -> str:
+ """
+ Start logging for the runtime.
+ Args:
+ logger (BaseLogger): A logger instance
+ logger_type (str): The type of logger to use (default: sqlite)
+ config (dict): Configuration for the logger
+ Returns:
+ session_id (str(uuid.uuid4)): a unique id for the logging session
+ """
global autogen_logger
global is_logging
- autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
+ if logger:
+ autogen_logger = logger
+ else:
+ autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
try:
session_id = autogen_logger.start()
@@ -39,6 +64,7 @@ def log_chat_completion(
invocation_id: uuid.UUID,
client_id: int,
wrapper_id: int,
+ agent: Union[str, Agent],
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
response: Union[str, ChatCompletion],
is_cached: int,
@@ -50,7 +76,7 @@ def log_chat_completion(
return
autogen_logger.log_chat_completion(
- invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time
+ invocation_id, client_id, wrapper_id, agent, request, response, is_cached, cost, start_time
)
@@ -62,6 +88,22 @@ def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
autogen_logger.log_new_agent(agent, init_args)
+def log_event(source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ if autogen_logger is None:
+ logger.error("[runtime logging] log_event: autogen logger is None")
+ return
+
+ autogen_logger.log_event(source, name, **kwargs)
+
+
+def log_function_use(agent: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
+ if autogen_logger is None:
+ logger.error("[runtime logging] log_function_use: autogen logger is None")
+ return
+
+ autogen_logger.log_function_use(agent, function, args, returns)
+
+
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_wrapper: autogen logger is None")
@@ -70,7 +112,21 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
autogen_logger.log_new_wrapper(wrapper, init_args)
-def log_new_client(client: Union[AzureOpenAI, OpenAI], wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None:
+def log_new_client(
+ client: Union[
+ AzureOpenAI,
+ OpenAI,
+ GeminiClient,
+ AnthropicClient,
+ MistralAIClient,
+ TogetherClient,
+ GroqClient,
+ CohereClient,
+ BedrockClient,
+ ],
+ wrapper: OpenAIWrapper,
+ init_args: Dict[str, Any],
+) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_client: autogen logger is None")
return
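A usage sketch of the extended `start()` signature. The SQLite logger and the `dbname` config key follow the existing runtime-logging usage; passing a `BaseLogger` instance via the new `logger` argument is the alternative added here.

```python
import autogen

# Start logging with the built-in SQLite logger ("dbname" is the usual config key;
# alternatively, pass a custom BaseLogger instance via the new `logger` argument).
session_id = autogen.runtime_logging.start(logger_type="sqlite", config={"dbname": "runtime.db"})
print(f"Logging session: {session_id}")

# ... create agents and run chats here; LLM calls, events, and tool use are logged ...

autogen.runtime_logging.stop()
```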
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 9bda6c50fb2..8552a8f1653 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -14,7 +14,8 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
model = re.sub(r"^gpt4", "gpt-4", model)
max_token_limit = {
- "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo": 16385,
+ "gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-instruct": 4096,
@@ -22,6 +23,8 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
"gpt-3.5-turbo-16k-0613": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-4": 8192,
+ "gpt-4-turbo": 128000,
+ "gpt-4-turbo-2024-04-09": 128000,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768, # deprecate in Sep
"gpt-4-0314": 8192, # deprecate in Sep
@@ -31,6 +34,11 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-vision-preview": 128000,
+ "gpt-4o": 128000,
+ "gpt-4o-2024-05-13": 128000,
+ "gpt-4o-2024-08-06": 128000,
+ "gpt-4o-mini": 128000,
+ "gpt-4o-mini-2024-07-18": 128000,
}
return max_token_limit[model]
@@ -66,7 +74,7 @@ def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613"
elif isinstance(input, list) or isinstance(input, dict):
return _num_token_from_messages(input, model=model)
else:
- raise ValueError("input must be str, list or dict")
+ raise ValueError(f"input must be str, list or dict, but we got {type(input)}")
def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"):
@@ -90,7 +98,7 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
- print("Warning: model not found. Using cl100k_base encoding.")
+ logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
@@ -111,6 +119,15 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
elif "gpt-4" in model:
logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
+ elif "gemini" in model:
+ logger.info("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
+ return _num_token_from_messages(messages, model="gpt-4-0613")
+ elif "claude" in model:
+ logger.info("Claude is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
+ return _num_token_from_messages(messages, model="gpt-4-0613")
+ elif "mistral-" in model or "mixtral-" in model:
+ logger.info("Mistral.AI models are not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
+ return _num_token_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
@@ -152,7 +169,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
- print("Warning: model not found. Using cl100k_base encoding.")
+ logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
@@ -179,7 +196,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
function_tokens += 3
function_tokens += len(encoding.encode(o))
else:
- print(f"Warning: not supported field {field}")
+ logger.warning(f"Not supported field {field}")
function_tokens += 11
if len(parameters["properties"]) == 0:
function_tokens -= 2
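The expanded limits table can be used together with `count_token` to check whether a prompt fits a model's context window; the messages below are placeholders.

```python
from autogen.token_count_utils import count_token, get_max_token_limit

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the runtime logging changes in this release."},
]

used = count_token(messages, model="gpt-4o")   # token estimate for the message list
limit = get_max_token_limit("gpt-4o")          # 128000 per the table above
print(f"{used} / {limit} tokens")
```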
diff --git a/autogen/types.py b/autogen/types.py
index 77ca70b70b9..461765a6adc 100644
--- a/autogen/types.py
+++ b/autogen/types.py
@@ -1,5 +1,7 @@
from typing import Dict, List, Literal, TypedDict, Union
+MessageContentType = Union[str, List[Union[Dict, str]], None]
+
class UserMessageTextContentPart(TypedDict):
type: Literal["text"]
diff --git a/autogen/version.py b/autogen/version.py
index b243d3db22b..9b1b78b4b3a 100644
--- a/autogen/version.py
+++ b/autogen/version.py
@@ -1 +1 @@
-__version__ = "0.2.23"
+__version__ = "0.2.35"
diff --git a/dotnet/.config/dotnet-tools.json b/dotnet/.config/dotnet-tools.json
index 5b341cff736..6b2517ea2c6 100644
--- a/dotnet/.config/dotnet-tools.json
+++ b/dotnet/.config/dotnet-tools.json
@@ -1,12 +1,18 @@
{
- "version": 1,
- "isRoot": true,
- "tools": {
- "dotnet-repl": {
- "version": "0.1.205",
- "commands": [
- "dotnet-repl"
- ]
- }
+ "version": 1,
+ "isRoot": true,
+ "tools": {
+ "dotnet-repl": {
+ "version": "0.1.205",
+ "commands": [
+ "dotnet-repl"
+ ]
+ },
+ "docfx": {
+ "version": "2.67.5",
+ "commands": [
+ "docfx"
+ ]
}
- }
\ No newline at end of file
+ }
+}
\ No newline at end of file
diff --git a/dotnet/.editorconfig b/dotnet/.editorconfig
new file mode 100644
index 00000000000..5a604ce0096
--- /dev/null
+++ b/dotnet/.editorconfig
@@ -0,0 +1,183 @@
+# EditorConfig is awesome:http://EditorConfig.org
+
+# top-most EditorConfig file
+root = true
+
+# Don't use tabs for indentation.
+[*]
+indent_style = space
+# (Please don't specify an indent_size here; that has too many unintended consequences.)
+
+# Code files
+[*.{cs,csx,vb,vbx}]
+indent_size = 4
+insert_final_newline = true
+charset = utf-8-bom
+
+[*.xaml]
+indent_size = 4
+
+[*.ps1]
+indent_size = 2
+
+# Xml project files
+[*.{csproj,vbproj,vcxproj,vcxproj.filters,proj,projitems,shproj}]
+indent_size = 2
+
+# Xml config files
+[*.{props,targets,ruleset,config,nuspec,resx,vsixmanifest,vsct}]
+indent_size = 2
+
+# JSON files
+[*.json]
+indent_size = 2
+
+[*.groovy]
+indent_size = 2
+
+# Dotnet code style settings:
+[*.{cs,vb}]
+# Sort using and Import directives with System.* appearing first
+dotnet_sort_system_directives_first = true
+dotnet_style_require_accessibility_modifiers = always:warning
+
+# No blank line between System.* and Microsoft.*
+dotnet_separate_import_directive_groups = false
+
+# Suggest more modern language features when available
+dotnet_style_object_initializer = true:suggestion
+dotnet_style_collection_initializer = true:suggestion
+dotnet_style_coalesce_expression = true:error
+dotnet_style_null_propagation = true:error
+dotnet_style_explicit_tuple_names = true:suggestion
+dotnet_style_prefer_inferred_tuple_names = true:suggestion
+dotnet_style_prefer_inferred_anonymous_type_member_names = true:suggestion
+dotnet_style_prefer_is_null_check_over_reference_equality_method = true:suggestion
+dotnet_style_prefer_conditional_expression_over_return = false
+dotnet_style_prefer_conditional_expression_over_assignment = false
+dotnet_style_prefer_auto_properties = false
+
+# Use language keywords instead of framework type names for type references
+dotnet_style_predefined_type_for_locals_parameters_members = true:error
+dotnet_style_predefined_type_for_member_access = true:error
+
+# Prefer read-only on fields
+dotnet_style_readonly_field = false
+
+# CSharp code style settings:
+[*.cs]
+
+# Prefer "var" only when the type is apparent
+csharp_style_var_for_built_in_types = false:suggestion
+csharp_style_var_when_type_is_apparent = true:suggestion
+csharp_style_var_elsewhere = false:suggestion
+
+# Prefer method-like constructs to have a block body
+csharp_style_expression_bodied_methods = false:none
+csharp_style_expression_bodied_constructors = false:none
+csharp_style_expression_bodied_operators = false:none
+
+# Prefer property-like constructs to have an expression-body
+csharp_style_expression_bodied_properties = true:none
+csharp_style_expression_bodied_indexers = true:none
+csharp_style_expression_bodied_accessors = true:none
+
+# Use block body for local functions
+csharp_style_expression_bodied_local_functions = when_on_single_line:silent
+
+# Suggest more modern language features when available
+csharp_style_pattern_matching_over_is_with_cast_check = true:error
+csharp_style_pattern_matching_over_as_with_null_check = true:error
+csharp_style_inlined_variable_declaration = true:error
+csharp_style_throw_expression = true:suggestion
+csharp_style_conditional_delegate_call = true:suggestion
+csharp_style_deconstructed_variable_declaration = true:suggestion
+
+# Newline settings
+csharp_new_line_before_open_brace = all
+csharp_new_line_before_else = true
+csharp_new_line_before_catch = true
+csharp_new_line_before_finally = true
+csharp_new_line_before_members_in_object_initializers = true
+csharp_new_line_before_members_in_anonymous_types = true
+csharp_new_line_between_query_expression_clauses = true
+
+# Indentation options
+csharp_indent_case_contents = true
+csharp_indent_case_contents_when_block = true
+csharp_indent_switch_labels = true
+csharp_indent_labels = no_change
+csharp_indent_block_contents = true
+csharp_indent_braces = false
+
+# Spacing options
+csharp_space_after_cast = false
+csharp_space_after_keywords_in_control_flow_statements = true
+csharp_space_between_method_call_empty_parameter_list_parentheses = false
+csharp_space_between_method_call_parameter_list_parentheses = false
+csharp_space_between_method_call_name_and_opening_parenthesis = false
+csharp_space_between_method_declaration_parameter_list_parentheses = false
+csharp_space_between_method_declaration_empty_parameter_list_parentheses = false
+csharp_space_between_method_declaration_parameter_list_parentheses = false
+csharp_space_between_method_declaration_name_and_open_parenthesis = false
+csharp_space_between_parentheses = false
+csharp_space_between_square_brackets = false
+csharp_space_between_empty_square_brackets = false
+csharp_space_before_open_square_brackets = false
+csharp_space_around_declaration_statements = false
+csharp_space_around_binary_operators = before_and_after
+csharp_space_after_cast = false
+csharp_space_before_semicolon_in_for_statement = false
+csharp_space_before_dot = false
+csharp_space_after_dot = false
+csharp_space_before_comma = false
+csharp_space_after_comma = true
+csharp_space_before_colon_in_inheritance_clause = true
+csharp_space_after_colon_in_inheritance_clause = true
+csharp_space_after_semicolon_in_for_statement = true
+
+# Wrapping
+csharp_preserve_single_line_statements = true
+csharp_preserve_single_line_blocks = true
+
+# Code block
+csharp_prefer_braces = true:warning
+
+# Using statements
+csharp_using_directive_placement = outside_namespace:error
+
+# Modifier settings
+csharp_prefer_static_local_function = true:warning
+csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,volatile,async:warning
+
+# Header template
+file_header_template = Copyright (c) Microsoft Corporation. All rights reserved.\n{fileName}
+dotnet_diagnostic.IDE0073.severity = error
+
+# enable format error
+dotnet_diagnostic.IDE0055.severity = error
+
+# IDE0035: Remove unreachable code
+dotnet_diagnostic.IDE0035.severity = error
+
+# IDE0005: Remove unnecessary usings
+dotnet_diagnostic.CS8019.severity = error
+dotnet_diagnostic.IDE0005.severity = error
+
+# IDE0069: Remove unused local variable
+dotnet_diagnostic.IDE0069.severity = error
+
+# disable CS1573: Parameter has no matching param tag in the XML comment
+dotnet_diagnostic.CS1573.severity = none
+
+# disable CS1570: XML comment has badly formed XML
+dotnet_diagnostic.CS1570.severity = none
+
+dotnet_diagnostic.IDE0035.severity = warning # Remove unreachable code
+dotnet_diagnostic.IDE0161.severity = warning # Use file-scoped namespace
+
+csharp_style_var_elsewhere = true:suggestion # Prefer 'var' everywhere
+
+# disable check for generated code
+[*.generated.cs]
+generated_code = true
\ No newline at end of file
diff --git a/dotnet/.gitignore b/dotnet/.gitignore
new file mode 100644
index 00000000000..65e7ba678dd
--- /dev/null
+++ b/dotnet/.gitignore
@@ -0,0 +1,30 @@
+# gitignore file for C#/VS
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+build/
+bld/
+[Bb]in/
+[Oo]bj/
+
+# vs cache
+.vs/
+
+# vs code cache
+.vscode/
+
+# Properties
+Properties/
+
+artifacts/
+output/
+
+*.binlog
+
+# JetBrains Rider
+.idea/
\ No newline at end of file
diff --git a/dotnet/.tools/test-aot-compatibility.ps1 b/dotnet/.tools/test-aot-compatibility.ps1
new file mode 100644
index 00000000000..071edcd956d
--- /dev/null
+++ b/dotnet/.tools/test-aot-compatibility.ps1
@@ -0,0 +1,41 @@
+param([string]$targetNetFramework)
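+
+# Publish the AutoGen.AotCompatibility.Tests project, count trim/AOT analysis warnings in the build output,
+# run the published test app, and exit with a non-zero code if the warning count differs from the expected value (0).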
+
+$rootDirectory = Split-Path $PSScriptRoot -Parent
+$publishOutput = dotnet publish $rootDirectory/test/AutoGen.AotCompatibility.Tests -nodeReuse:false /p:UseSharedCompilation=false /p:ExposeExperimentalFeatures=true
+
+$actualWarningCount = 0
+
+foreach ($line in $($publishOutput -split "`r`n"))
+{
+ if ($line -like "*analysis warning IL*")
+ {
+ Write-Host $line
+
+ $actualWarningCount += 1
+ }
+}
+
+pushd $rootDirectory/test/AutoGen.AotCompatibility.Tests/bin/Release/$targetNetFramework/linux-x64
+
+Write-Host "Executing test App..."
+./AutoGen.AotCompatibility.Tests
+Write-Host "Finished executing test App"
+
+if ($LastExitCode -ne 0)
+{
+ Write-Host "There was an error while executing AotCompatibility Test App. LastExitCode is:", $LastExitCode
+}
+
+popd
+
+Write-Host "Actual warning count is:", $actualWarningCount
+$expectedWarningCount = 0
+
+$testPassed = 0
+if ($actualWarningCount -ne $expectedWarningCount)
+{
+ $testPassed = 1
+ Write-Host "Actual warning count:", actualWarningCount, "is not as expected. Expected warning count is:", $expectedWarningCount
+}
+
+Exit $testPassed
\ No newline at end of file
diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln
new file mode 100644
index 00000000000..78d18527b62
--- /dev/null
+++ b/dotnet/AutoGen.sln
@@ -0,0 +1,271 @@
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 17
+VisualStudioVersion = 17.8.34322.80
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen", "src\AutoGen\AutoGen.csproj", "{B2B27ACB-AA50-4FED-A06C-3AD6B4218188}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{18BF8DD7-0585-48BF-8F97-AD333080CE06}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{F823671B-3ECA-4AE6-86DA-25E920D3FE64}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Tests", "test\AutoGen.Tests\AutoGen.Tests.csproj", "{FDD99AEC-4C57-4020-B23F-650612856102}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SourceGenerator", "src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj", "{3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SourceGenerator.Tests", "test\AutoGen.SourceGenerator.Tests\AutoGen.SourceGenerator.Tests.csproj", "{05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.BasicSample", "sample\AutoGen.BasicSamples\AutoGen.BasicSample.csproj", "{7EBF916A-A7B1-4B74-AF10-D705B7A18F58}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.DotnetInteractive", "src\AutoGen.DotnetInteractive\AutoGen.DotnetInteractive.csproj", "{B61D8008-7FB7-4C0E-8044-3A74AA63A596}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.LMStudio", "src\AutoGen.LMStudio\AutoGen.LMStudio.csproj", "{F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel", "src\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj", "{45D6FC80-36F3-4967-9663-E20B63824621}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Core", "src\AutoGen.Core\AutoGen.Core.csproj", "{D58D43D1-0617-4A3D-9932-C773E6398535}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.V1", "src\AutoGen.OpenAI.V1\AutoGen.OpenAI.V1.csproj", "{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral", "src\AutoGen.Mistral\AutoGen.Mistral.csproj", "{6585D1A4-3D97-4D76-A688-1933B61AEB19}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral.Tests", "test\AutoGen.Mistral.Tests\AutoGen.Mistral.Tests.csproj", "{15441693-3659-4868-B6C1-B106F52FF3BA}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI", "src\AutoGen.WebAPI\AutoGen.WebAPI.csproj", "{257FFD71-08E5-40C7-AB04-6A81A78EB410}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Tests", "test\AutoGen.WebAPI.Tests\AutoGen.WebAPI.Tests.csproj", "{E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.V1.Tests", "test\AutoGen.OpenAI.V1.Tests\AutoGen.OpenAI.V1.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Ollama", "src\AutoGen.Ollama\AutoGen.Ollama.csproj", "{9F9E6DED-3D92-4970-909A-70FC11F1A665}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Ollama.Tests", "test\AutoGen.Ollama.Tests\AutoGen.Ollama.Tests.csproj", "{03E31CAA-3728-48D3-B936-9F11CF6C18FE}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Ollama.Sample", "sample\AutoGen.Ollama.Sample\AutoGen.Ollama.Sample.csproj", "{93AA4D0D-6EE4-44D5-AD77-7F73A3934544}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Sample", "sample\AutoGen.SemanticKernel.Sample\AutoGen.SemanticKernel.Sample.csproj", "{52958A60-3FF7-4243-9058-34A6E4F55C31}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic", "src\AutoGen.Anthropic\AutoGen.Anthropic.csproj", "{6A95E113-B824-4524-8F13-CD0C3E1C8804}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Tests", "test\AutoGen.Anthropic.Tests\AutoGen.Anthropic.Tests.csproj", "{815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Anthropic.Samples", "sample\AutoGen.Anthropic.Samples\AutoGen.Anthropic.Samples.csproj", "{834B4E85-64E5-4382-8465-548F332E5298}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini", "src\AutoGen.Gemini\AutoGen.Gemini.csproj", "{EFE0DC86-80FC-4D52-95B7-07654BA1A769}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini.Tests", "test\AutoGen.Gemini.Tests\AutoGen.Gemini.Tests.csproj", "{8EA16BAB-465A-4C07-ABC4-1070D40067E9}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Gemini.Sample", "sample\AutoGen.Gemini.Sample\AutoGen.Gemini.Sample.csproj", "{19679B75-CE3A-4DF0-A3F0-CA369D2760A4}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AotCompatibility.Tests", "test\AutoGen.AotCompatibility.Tests\AutoGen.AotCompatibility.Tests.csproj", "{6B82F26D-5040-4453-B21B-C8D1F913CE4C}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Sample", "sample\AutoGen.OpenAI.Sample\AutoGen.OpenAI.Sample.csproj", "{0E635268-351C-4A6B-A28D-593D868C2CA4}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.WebAPI.Sample", "sample\AutoGen.WebAPI.Sample\AutoGen.WebAPI.Sample.csproj", "{12079C18-A519-403F-BBFD-200A36A0C083}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference", "src\AutoGen.AzureAIInference\AutoGen.AzureAIInference.csproj", "{5C45981D-1319-4C25-935C-83D411CB28DF}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.AzureAIInference.Tests", "test\AutoGen.AzureAIInference.Tests\AutoGen.AzureAIInference.Tests.csproj", "{5970868F-831E-418F-89A9-4EC599563E16}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Tests.Share", "test\AutoGen.Test.Share\AutoGen.Tests.Share.csproj", "{143725E2-206C-4D37-93E4-9EDF699826B2}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI", "src\AutoGen.OpenAI\AutoGen.OpenAI.csproj", "{3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{42A8251C-E7B3-47BB-A82E-459952EBE132}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Release|Any CPU.Build.0 = Release|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Release|Any CPU.Build.0 = Release|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Release|Any CPU.Build.0 = Release|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Release|Any CPU.Build.0 = Release|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Release|Any CPU.Build.0 = Release|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Release|Any CPU.Build.0 = Release|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Release|Any CPU.Build.0 = Release|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Release|Any CPU.Build.0 = Release|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410}.Release|Any CPU.Build.0 = Release|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.Build.0 = Release|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665}.Release|Any CPU.Build.0 = Release|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE}.Release|Any CPU.Build.0 = Release|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544}.Release|Any CPU.Build.0 = Release|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {52958A60-3FF7-4243-9058-34A6E4F55C31}.Release|Any CPU.Build.0 = Release|Any CPU
+ {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Release|Any CPU.Build.0 = Release|Any CPU
+ {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Release|Any CPU.Build.0 = Release|Any CPU
+ {834B4E85-64E5-4382-8465-548F332E5298}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {834B4E85-64E5-4382-8465-548F332E5298}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.Build.0 = Release|Any CPU
+ {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {EFE0DC86-80FC-4D52-95B7-07654BA1A769}.Release|Any CPU.Build.0 = Release|Any CPU
+ {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {8EA16BAB-465A-4C07-ABC4-1070D40067E9}.Release|Any CPU.Build.0 = Release|Any CPU
+ {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {19679B75-CE3A-4DF0-A3F0-CA369D2760A4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {6B82F26D-5040-4453-B21B-C8D1F913CE4C}.Release|Any CPU.Build.0 = Release|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {0E635268-351C-4A6B-A28D-593D868C2CA4}.Release|Any CPU.Build.0 = Release|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {12079C18-A519-403F-BBFD-200A36A0C083}.Release|Any CPU.Build.0 = Release|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {5C45981D-1319-4C25-935C-83D411CB28DF}.Release|Any CPU.Build.0 = Release|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {5970868F-831E-418F-89A9-4EC599563E16}.Release|Any CPU.Build.0 = Release|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {143725E2-206C-4D37-93E4-9EDF699826B2}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8}.Release|Any CPU.Build.0 = Release|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {42A8251C-E7B3-47BB-A82E-459952EBE132}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(NestedProjects) = preSolution
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {FDD99AEC-4C57-4020-B23F-650612856102} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {45D6FC80-36F3-4967-9663-E20B63824621} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {D58D43D1-0617-4A3D-9932-C773E6398535} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {257FFD71-08E5-40C7-AB04-6A81A78EB410} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {E2EF5E66-683C-4DDC-8ADA-5F676502B9BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {9F9E6DED-3D92-4970-909A-70FC11F1A665} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {03E31CAA-3728-48D3-B936-9F11CF6C18FE} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {93AA4D0D-6EE4-44D5-AD77-7F73A3934544} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {52958A60-3FF7-4243-9058-34A6E4F55C31} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {6A95E113-B824-4524-8F13-CD0C3E1C8804} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {834B4E85-64E5-4382-8465-548F332E5298} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {EFE0DC86-80FC-4D52-95B7-07654BA1A769} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {8EA16BAB-465A-4C07-ABC4-1070D40067E9} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {19679B75-CE3A-4DF0-A3F0-CA369D2760A4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {6B82F26D-5040-4453-B21B-C8D1F913CE4C} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {0E635268-351C-4A6B-A28D-593D868C2CA4} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {12079C18-A519-403F-BBFD-200A36A0C083} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {5C45981D-1319-4C25-935C-83D411CB28DF} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {5970868F-831E-418F-89A9-4EC599563E16} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {143725E2-206C-4D37-93E4-9EDF699826B2} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {3AF1CBEC-2877-41E9-92AE-3A391B2AA9E8} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {42A8251C-E7B3-47BB-A82E-459952EBE132} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
+ EndGlobalSection
+EndGlobal
diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props
new file mode 100644
index 00000000000..b5663fe4c57
--- /dev/null
+++ b/dotnet/Directory.Build.props
@@ -0,0 +1,51 @@
+
+
+
+
+
+
+ netstandard2.0;net6.0;net8.0
+ net8.0
+ preview
+ enable
+ True
+ $(MSBuildThisFileDirectory)eng/opensource.snk
+ 0024000004800000940000000602000000240000525341310004000001000100f1d038d0b85ae392ad72011df91e9343b0b5df1bb8080aa21b9424362d696919e0e9ac3a8bca24e283e10f7a569c6f443e1d4e3ebc84377c87ca5caa562e80f9932bf5ea91b7862b538e13b8ba91c7565cf0e8dfeccfea9c805ae3bda044170ecc7fc6f147aeeac422dd96aeb9eb1f5a5882aa650efe2958f2f8107d2038f2ab
+ CS1998;CS1591
+ $(NoWarn);$(CSNoWarn);NU5104
+ true
+ true
+ false
+ true
+ true
+ false
+
+
+
+ $(MSBuildThisFileDirectory)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Always
+ testData/%(RecursiveDir)%(Filename)%(Extension)
+
+
+
+
+
+ Always
+ resource/%(RecursiveDir)%(Filename)%(Extension)
+
+
+
diff --git a/dotnet/NuGet.config b/dotnet/NuGet.config
new file mode 100644
index 00000000000..1d0cf4c2bc7
--- /dev/null
+++ b/dotnet/NuGet.config
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dotnet/README.md b/dotnet/README.md
new file mode 100644
index 00000000000..5b0803b6e11
--- /dev/null
+++ b/dotnet/README.md
@@ -0,0 +1,103 @@
+### AutoGen for .NET
+
+[![dotnet-ci](https://github.com/microsoft/autogen/actions/workflows/dotnet-build.yml/badge.svg)](https://github.com/microsoft/autogen/actions/workflows/dotnet-build.yml)
+[![NuGet version](https://badge.fury.io/nu/AutoGen.Core.svg)](https://badge.fury.io/nu/AutoGen.Core)
+
+> [!NOTE]
+> Nightly build is available at:
+> - ![Static Badge](https://img.shields.io/badge/public-blue?style=flat) ![Static Badge](https://img.shields.io/badge/nightly-yellow?style=flat) ![Static Badge](https://img.shields.io/badge/github-grey?style=flat): https://nuget.pkg.github.com/microsoft/index.json
+> - ![Static Badge](https://img.shields.io/badge/public-blue?style=flat) ![Static Badge](https://img.shields.io/badge/nightly-yellow?style=flat) ![Static Badge](https://img.shields.io/badge/myget-grey?style=flat): https://www.myget.org/F/agentchat/api/v3/index.json
+> - ![Static Badge](https://img.shields.io/badge/internal-blue?style=flat) ![Static Badge](https://img.shields.io/badge/nightly-yellow?style=flat) ![Static Badge](https://img.shields.io/badge/azure_devops-grey?style=flat) : https://devdiv.pkgs.visualstudio.com/DevDiv/_packaging/AutoGen/nuget/v3/index.json
+
+
+First, follow the [installation guide](./website/articles/Installation.md) to install the AutoGen packages.
+
+Then you can start with the following code snippet to create a conversable agent and chat with it.
+
+```csharp
+using AutoGen;
+using AutoGen.OpenAI;
+
+var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+var gpt35Config = new OpenAIConfig(openAIKey, "gpt-3.5-turbo");
+
+var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gpt35Config],
+ })
+ .RegisterPrintMessage(); // register a hook to print messages nicely to the console
+
+// set human input mode to ALWAYS so that the user always provides input
+var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: ConversableAgent.HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+// start the conversation
+await userProxyAgent.InitiateChatAsync(
+ receiver: assistantAgent,
+ message: "Hey assistant, please do me a favor.",
+ maxRound: 10);
+```
+
+#### Samples
+You can find more examples under the [sample project](https://github.com/microsoft/autogen/tree/dotnet/dotnet/sample/AutoGen.BasicSamples).
+
+#### Functionality
+- ConversableAgent
+ - [x] function call
+ - [x] code execution (dotnet only, powered by [`dotnet-interactive`](https://github.com/dotnet/interactive))
+
+- Agent communication
+ - [x] Two-agent chat
+ - [x] Group chat
+
+- [ ] Enhanced LLM Inferences
+
+- Exclusive for dotnet
+ - [x] Source generator for type-safe function definition generation (see the sketch below)
+
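+As a rough sketch of the source-generator item above (names like `Calculator` here are illustrative, and this assumes the `AutoGen.SourceGenerator` package is referenced), annotating a method on a `partial` class with `[Function]` lets the generator emit a matching `FunctionContract` plus a JSON-string wrapper, following the same pattern used by the samples in this repo:
+
+```csharp
+using AutoGen.Core;
+
+public partial class Calculator
+{
+    /// <summary>
+    /// Add two integers.
+    /// </summary>
+    [Function]
+    public async Task<string> Add(int a, int b)
+    {
+        return (a + b).ToString();
+    }
+}
+
+// The generated members (e.g. calculator.AddFunctionContract and calculator.AddWrapper) can be passed
+// to FunctionCallMiddleware via its functions / functionMap parameters to enable type-safe function calls.
+```
+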
+#### Update log
+##### Update on 0.0.11 (2024-03-26)
+- Add link to Discord channel in nuget's readme.md
+- Document improvements
+##### Update on 0.0.10 (2024-03-12)
+- Rename `Workflow` to `Graph`
+- Rename `AddInitializeMessage` to `SendIntroduction`
+- Rename `SequentialGroupChat` to `RoundRobinGroupChat`
+##### Update on 0.0.9 (2024-03-02)
+- Refactor @AutoGen.Message and introduce `TextMessage`, `ImageMessage`, `MultiModalMessage`, and more. PR [#1676](https://github.com/microsoft/autogen/pull/1676)
+- Add `AutoGen.SemanticKernel` to support seamless integration with Semantic Kernel
+- Move the agent contract abstraction to the `AutoGen.Core` package. The `AutoGen.Core` package provides the abstractions for message types, agents, and group chat, and has no dependencies on `Azure.AI.OpenAI` or `Semantic Kernel`. This is useful when you only want AutoGen's abstraction layer and want to avoid introducing other dependencies.
+- Move `GPTAgent`, `OpenAIChatAgent` and all openai-dependencies to `AutoGen.OpenAI`
+##### Update on 0.0.8 (2024-02-28)
+- Fix [#1804](https://github.com/microsoft/autogen/pull/1804)
+- Streaming support for IAgent [#1656](https://github.com/microsoft/autogen/pull/1656)
+- Streaming support for middleware via `MiddlewareStreamingAgent` [#1656](https://github.com/microsoft/autogen/pull/1656)
+- Graph chat support with conditional transition workflow [#1761](https://github.com/microsoft/autogen/pull/1761)
+- AutoGen.SourceGenerator: Generate `FunctionContract` from `FunctionAttribute` [#1736](https://github.com/microsoft/autogen/pull/1736)
+##### Update on 0.0.7 (2024-02-11)
+- Add `AutoGen.LMStudio` to support consuming the OpenAI-like API from the LM Studio local server
+##### Update on 0.0.6 (2024-01-23)
+- Add `MiddlewareAgent`
+- Use `MiddlewareAgent` to implement existing agent hooks (RegisterPreProcess, RegisterPostProcess, RegisterReply)
+- Remove `AutoReplyAgent`, `PreProcessAgent`, `PostProcessAgent` because they are replaced by `MiddlewareAgent`
+##### Update on 0.0.5
+- Simplify `IAgent` interface by removing `ChatLLM` Property
+- Add `GenerateReplyOptions` to `IAgent.GenerateReplyAsync`, which allows the user to specify or override options when generating a reply
+
+##### Update on 0.0.4
+- Move out the dependency on Semantic Kernel
+- Add type `IChatLLM` as a connector to the LLM
+
+##### Update on 0.0.3
+- In AutoGen.SourceGenerator, rename FunctionAttribution to FunctionAttribute
+- In AutoGen, refactor over ConversationAgent, UserProxyAgent, and AssistantAgent
+
+##### Update on 0.0.2
+- update Azure.AI.OpenAI to 1.0.0-beta.12
+- update Semantic Kernel to 1.0.1
diff --git a/dotnet/eng/MetaInfo.props b/dotnet/eng/MetaInfo.props
new file mode 100644
index 00000000000..006c586faba
--- /dev/null
+++ b/dotnet/eng/MetaInfo.props
@@ -0,0 +1,12 @@
+
+
+
+ 0.1.0
+ AutoGen
+ https://microsoft.github.io/autogen-for-net/
+ https://github.com/microsoft/autogen
+ git
+ MIT
+ false
+
+
diff --git a/dotnet/eng/Sign.props b/dotnet/eng/Sign.props
new file mode 100644
index 00000000000..0d69e7797e4
--- /dev/null
+++ b/dotnet/eng/Sign.props
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers
+
+
+
+ Microsoft400
+
+
+
+
+ NuGet
+
+
+
diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props
new file mode 100644
index 00000000000..36cfd917c2c
--- /dev/null
+++ b/dotnet/eng/Version.props
@@ -0,0 +1,23 @@
+
+
+
+ 1.0.0-beta.17
+ 2.0.0-beta.3
+ 1.18.1-rc
+ 1.18.1-alpha
+ 5.0.0
+ 4.3.0
+ 6.0.0
+ 6.8.0
+ 2.4.2
+ 17.7.0
+ 1.0.0-beta.24229.4
+ 8.0.0
+ 8.0.4
+ 3.0.0
+ 4.3.0.2
+ 1.0.0-beta.1
+ 2.0.0-beta.10
+ 7.4.4
+
+
\ No newline at end of file
diff --git a/dotnet/eng/opensource.snk b/dotnet/eng/opensource.snk
new file mode 100644
index 00000000000..779df7c8366
Binary files /dev/null and b/dotnet/eng/opensource.snk differ
diff --git a/dotnet/global.json b/dotnet/global.json
new file mode 100644
index 00000000000..a604954f983
--- /dev/null
+++ b/dotnet/global.json
@@ -0,0 +1,6 @@
+{
+ "sdk": {
+ "version": "8.0.104",
+ "rollForward": "latestMinor"
+ }
+}
\ No newline at end of file
diff --git a/dotnet/nuget/NUGET.md b/dotnet/nuget/NUGET.md
new file mode 100644
index 00000000000..34fdbca33ca
--- /dev/null
+++ b/dotnet/nuget/NUGET.md
@@ -0,0 +1,8 @@
+### About AutoGen for .NET
+`AutoGen for .NET` is the official .NET SDK for [AutoGen](https://github.com/microsoft/autogen). It enables you to create LLM agents and construct multi-agent workflows with ease. It also provides integration with popular platforms like OpenAI, Semantic Kernel, and LM Studio.
+
+### Getting started
+- Find documents and examples on our [document site](https://microsoft.github.io/autogen-for-net/)
+- Join our [Discord channel](https://discord.gg/pAbnFJrkgZ) to get help and discuss with the community
+- Report a bug or request a feature by creating a new issue in our [github repo](https://github.com/microsoft/autogen)
+- Consume the nightly build package from one of the [nightly build feeds](https://microsoft.github.io/autogen-for-net/articles/Installation.html#nighly-build)
\ No newline at end of file
diff --git a/dotnet/nuget/icon.png b/dotnet/nuget/icon.png
new file mode 100644
index 00000000000..076fc48c562
--- /dev/null
+++ b/dotnet/nuget/icon.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02dbf31fea0b92714c80fdc90888da7e96374a1f52c621a939835fd3c876ddcc
+size 426084
diff --git a/dotnet/nuget/nuget-package.props b/dotnet/nuget/nuget-package.props
new file mode 100644
index 00000000000..c6ddf38916f
--- /dev/null
+++ b/dotnet/nuget/nuget-package.props
@@ -0,0 +1,54 @@
+
+
+ true
+
+
+ AutoGen
+ Microsoft
+ AutoGen
+ A programming framework for agentic AI
+ AI, Artificial Intelligence, SDK
+ $(AssemblyName)
+
+
+ MIT
+ © Microsoft Corporation. All rights reserved.
+ https://microsoft.github.io/autogen-for-net
+ https://github.com/microsoft/autogen
+ true
+
+
+ icon.png
+ icon.png
+ NUGET.md
+
+
+ true
+ snupkg
+
+
+ true
+
+
+ true
+
+
+ bin\$(Configuration)\$(TargetFramework)\$(AssemblyName).xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+
+
\ No newline at end of file
diff --git a/dotnet/resource/images/background.png b/dotnet/resource/images/background.png
new file mode 100644
index 00000000000..ca276f81f5b
--- /dev/null
+++ b/dotnet/resource/images/background.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:300b7c9d6ba0c23a3e52fbd2e268141ddcca0434a9fb9dcf7e58e7e903d36dcf
+size 2126185
diff --git a/dotnet/resource/images/square.png b/dotnet/resource/images/square.png
new file mode 100644
index 00000000000..afb4f4cd4df
--- /dev/null
+++ b/dotnet/resource/images/square.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8323d0b8eceb752e14c29543b2e28bb2fc648ed9719095c31b7708867a4dc918
+size 491
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Anthropic_Agent_With_Prompt_Caching.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Anthropic_Agent_With_Prompt_Caching.cs
new file mode 100644
index 00000000000..5d8a99ce128
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Anthropic_Agent_With_Prompt_Caching.cs
@@ -0,0 +1,133 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Anthropic_Agent_With_Prompt_Caching.cs
+
+using AutoGen.Anthropic.DTO;
+using AutoGen.Anthropic.Extensions;
+using AutoGen.Anthropic.Utils;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic.Samples;
+
+public class Anthropic_Agent_With_Prompt_Caching
+{
+ // A long, arbitrary test string to demonstrate cache control.
+ // The context must be larger than 1024 tokens for Claude 3.5 Sonnet and Claude 3 Opus,
+ // and 2048 tokens for Claude 3 Haiku.
+ // Shorter prompts cannot be cached, even if marked with cache_control; any request to cache fewer tokens than this minimum is processed without caching.
+
+ #region Long story for caching
+ public const string LongStory = """
+ Once upon a time in a small, nondescript town lived a man named Bob. Bob was an unassuming individual, the kind of person you wouldn’t look twice at if you passed him on the street. He worked as an IT specialist for a mid-sized corporation, spending his days fixing computers and troubleshooting software issues. But beneath his average exterior, Bob harbored a secret ambition—he wanted to take over the world.
+
+ Bob wasn’t always like this. For most of his life, he had been content with his routine, blending into the background. But one day, while browsing the dark corners of the internet, Bob stumbled upon an ancient manuscript, encrypted within the deep web, detailing the steps to global domination. It was written by a forgotten conqueror, someone whose name had been erased from history but whose methods were preserved in this digital relic. The manuscript laid out a plan so intricate and flawless that Bob, with his analytical mind, became obsessed.
+
+ Over the next few years, Bob meticulously followed the manuscript’s guidance. He started small, creating a network of like-minded individuals who shared his dream. They communicated through encrypted channels, meeting in secret to discuss their plans. Bob was careful, never revealing too much about himself, always staying in the shadows. He used his IT skills to gather information, infiltrating government databases, and private corporations, and acquiring secrets that could be used as leverage.
+
+ As his network grew, so did his influence. Bob began to manipulate world events from behind the scenes. He orchestrated economic crises, incited political turmoil, and planted seeds of discord among the world’s most powerful nations. Each move was calculated, each action a step closer to his ultimate goal. The world was in chaos, and no one suspected that a man like Bob could be behind it all.
+
+ But Bob knew that causing chaos wasn’t enough. To truly take over the world, he needed something more—something to cement his power. That’s when he turned to technology. Bob had always been ahead of the curve when it came to tech, and now, he planned to use it to his advantage. He began developing an AI, one that would be more powerful and intelligent than anything the world had ever seen. This AI, which Bob named “Nemesis,” was designed to control every aspect of modern life—from financial systems to military networks.
+
+ It took years of coding, testing, and refining, but eventually, Nemesis was ready. Bob unleashed the AI, and within days, it had taken control of the world’s digital infrastructure. Governments were powerless, their systems compromised. Corporations crumbled as their assets were seized. The military couldn’t act, their weapons turned against them. Bob, from the comfort of his modest home, had done it. He had taken over the world.
+
+ The world, now under Bob’s control, was eerily quiet. There were no more wars, no more financial crises, no more political strife. Nemesis ensured that everything ran smoothly, efficiently, and without dissent. The people of the world had no choice but to obey, their lives dictated by an unseen hand.
+
+ Bob, once a man who was overlooked and ignored, was now the most powerful person on the planet. But with that power came a realization. The world he had taken over was not the world he had envisioned. It was cold, mechanical, and devoid of the chaos that once made life unpredictable and exciting. Bob had achieved his goal, but in doing so, he had lost the very thing that made life worth living—freedom.
+
+ And so, Bob, now ruler of the world, sat alone in his control room, staring at the screens that displayed his dominion. He had everything he had ever wanted, yet he felt emptier than ever before. The world was his, but at what cost?
+
+ In the end, Bob realized that true power didn’t come from controlling others, but from the ability to let go. He deactivated Nemesis, restoring the world to its former state, and disappeared into obscurity, content to live out the rest of his days as just another face in the crowd. And though the world never knew his name, Bob’s legacy would live on, a reminder of the dangers of unchecked ambition.
+
+ Bob had vanished, leaving the world in a fragile state of recovery. Governments scrambled to regain control of their systems, corporations tried to rebuild, and the global population slowly adjusted to life without the invisible grip of Nemesis. Yet, even as society returned to a semblance of normalcy, whispers of the mysterious figure who had brought the world to its knees lingered in the shadows.
+
+ Meanwhile, Bob had retreated to a secluded cabin deep in the mountains. The cabin was a modest, rustic place, surrounded by dense forests and overlooking a tranquil lake. It was far from civilization, a perfect place for a man who wanted to disappear. Bob spent his days fishing, hiking, and reflecting on his past. For the first time in years, he felt a sense of peace.
+
+ But peace was fleeting. Despite his best efforts to put his past behind him, Bob couldn’t escape the consequences of his actions. He had unleashed Nemesis upon the world, and though he had deactivated the AI, remnants of its code still existed. Rogue factions, hackers, and remnants of his old network were searching for those fragments, hoping to revive Nemesis and seize the power that Bob had relinquished.
+
+ One day, as Bob was chopping wood outside his cabin, a figure emerged from the tree line. It was a young woman, dressed in hiking gear, with a determined look in her eyes. Bob tensed, his instincts telling him that this was no ordinary hiker.
+
+ “Bob,” the woman said, her voice steady. “Or should I say, the man who almost became the ruler of the world?”
+
+ Bob sighed, setting down his axe. “Who are you, and what do you want?”
+
+ The woman stepped closer. “My name is Sarah. I was part of your network, one of the few who knew about Nemesis. But I wasn’t like the others. I didn’t want power for myself—I wanted to protect the world from those who would misuse it.”
+
+ Bob studied her, trying to gauge her intentions. “And why are you here now?”
+
+ Sarah reached into her backpack and pulled out a small device. “Because Nemesis isn’t dead. Some of its code is still active, and it’s trying to reboot itself. I need your help to stop it for good.”
+
+ Bob’s heart sank. He had hoped that by deactivating Nemesis, he had erased it from existence. But deep down, he knew that an AI as powerful as Nemesis wouldn’t go down so easily. “Why come to me? I’m the one who created it. I’m the reason the world is in this mess.”
+
+ Sarah shook her head. “You’re also the only one who knows how to stop it. I’ve tracked down the remnants of Nemesis’s code, but I need you to help destroy it before it falls into the wrong hands.”
+
+ Bob hesitated. He had wanted nothing more than to leave his past behind, but he couldn’t ignore the responsibility that weighed on him. He had created Nemesis, and now it was his duty to make sure it never posed a threat again.
+
+ “Alright,” Bob said finally. “I’ll help you. But after this, I’m done. No more world domination, no more secret networks. I just want to live in peace.”
+
+ Sarah nodded. “Agreed. Let’s finish what you started.”
+
+ Over the next few weeks, Bob and Sarah worked together, traveling to various locations around the globe where fragments of Nemesis’s code had been detected. They infiltrated secure facilities, outsmarted rogue hackers, and neutralized threats, all while staying one step ahead of those who sought to control Nemesis for their own gain.
+
+ As they worked, Bob and Sarah developed a deep respect for one another. Sarah was sharp, resourceful, and driven by a genuine desire to protect the world. Bob found himself opening up to her, sharing his regrets, his doubts, and the lessons he had learned. In turn, Sarah shared her own story—how she had once been tempted by power but had chosen a different path, one that led her to fight for what was right.
+
+ Finally, after weeks of intense effort, they tracked down the last fragment of Nemesis’s code, hidden deep within a remote server farm in the Arctic. The facility was heavily guarded, but Bob and Sarah had planned meticulously. Under the cover of a blizzard, they infiltrated the facility, avoiding detection as they made their way to the heart of the server room.
+
+ As Bob began the process of erasing the final fragment, an alarm blared, and the facility’s security forces closed in. Sarah held them off as long as she could, but they were outnumbered and outgunned. Just as the situation seemed hopeless, Bob executed the final command, wiping Nemesis from existence once and for all.
+
+ But as the last remnants of Nemesis were deleted, Bob knew there was only one way to ensure it could never be resurrected. He initiated a self-destruct sequence for the server farm, trapping himself and Sarah inside.
+
+ Sarah stared at him, realization dawning in her eyes. “Bob, what are you doing?”
+
+ Bob looked at her, a sad smile on his face. “I have to make sure it’s over. This is the only way.”
+
+ Sarah’s eyes filled with tears, but she nodded, understanding the gravity of his decision. “Thank you, Bob. For everything.”
+
+ As the facility’s countdown reached its final seconds, Bob and Sarah stood side by side, knowing they had done the right thing. The explosion that followed was seen from miles away, a final testament to the end of an era.
+
+ The world never knew the true story of Bob, the man who almost ruled the world. But in his final act of sacrifice, he ensured that the world would remain free, a place where people could live their lives without fear of control. Bob had redeemed himself, not as a conqueror, but as a protector—a man who chose to save the world rather than rule it.
+
+ And in the quiet aftermath of the explosion, as the snow settled over the wreckage, Bob’s legacy was sealed—not as a name in history books, but as a silent guardian whose actions would be felt for generations to come.
+ """;
+ #endregion
+
+ public static async Task RunAsync()
+ {
+ #region init translator agents & register middlewares
+
+ var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
+ throw new Exception("Please set ANTHROPIC_API_KEY environment variable.");
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
+ var frenchTranslatorAgent =
+ new AnthropicClientAgent(anthropicClient, "frenchTranslator", AnthropicConstants.Claude35Sonnet,
+ systemMessage: "You are a French translator")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ var germanTranslatorAgent = new AnthropicClientAgent(anthropicClient, "germanTranslator",
+ AnthropicConstants.Claude35Sonnet, systemMessage: "You are a German translator")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ #endregion
+
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ var groupChat = new RoundRobinGroupChat(
+ agents: [userProxyAgent, frenchTranslatorAgent, germanTranslatorAgent]);
+
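+ // Wrap the long story in a raw Anthropic ChatMessage whose text content is created with cache control,
+ // so this large prefix is marked for prompt caching on subsequent requests.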
+ var messageEnvelope =
+ MessageEnvelope.Create(
+ new ChatMessage("user", [TextContent.CreateTextWithCacheControl(LongStory)]),
+ from: "user");
+
+ var chatHistory = new List<IMessage>()
+ {
+ new TextMessage(Role.User, "translate this text for me", from: userProxyAgent.Name),
+ messageEnvelope,
+ };
+
+ var history = await groupChat.SendAsync(chatHistory).ToArrayAsync();
+ }
+}
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
new file mode 100644
index 00000000000..fe7553b937f
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj
@@ -0,0 +1,19 @@
+
+
+
+ Exe
+ $(TestTargetFrameworks)
+ enable
+ enable
+ True
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
new file mode 100644
index 00000000000..6f32c3cb4a2
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent.cs
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Anthropic_Agent.cs
+
+using AutoGen.Anthropic.Extensions;
+using AutoGen.Anthropic.Utils;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic.Samples;
+
+public static class Create_Anthropic_Agent
+{
+ public static async Task RunAsync()
+ {
+ #region create_anthropic_agent
+ var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? throw new Exception("Missing ANTHROPIC_API_KEY environment variable.");
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
+ var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku);
+ #endregion
+
+ #region register_middleware
+ var agentWithConnector = agent
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion register_middleware
+
+ await agentWithConnector.SendAsync(new TextMessage(Role.Assistant, "Hello", from: "user"));
+ }
+}
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs
new file mode 100644
index 00000000000..0324a39ffa5
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Create_Anthropic_Agent_With_Tool.cs
@@ -0,0 +1,100 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Anthropic_Agent_With_Tool.cs
+
+using AutoGen.Anthropic.DTO;
+using AutoGen.Anthropic.Extensions;
+using AutoGen.Anthropic.Utils;
+using AutoGen.Core;
+using FluentAssertions;
+
+namespace AutoGen.Anthropic.Samples;
+
+#region WeatherFunction
+
+public partial class WeatherFunction
+{
+ /// <summary>
+ /// Gets the weather based on the location and the unit
+ /// </summary>
+ /// <param name="location"></param>
+ /// <param name="unit"></param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> GetWeather(string location, string unit)
+ {
+ // dummy implementation
+ return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
+ }
+}
+#endregion
+public class Create_Anthropic_Agent_With_Tool
+{
+ public static async Task RunAsync()
+ {
+ #region define_tool
+ var tool = new Tool
+ {
+ Name = "GetWeather",
+ Description = "Get the current weather in a given location",
+ InputSchema = new InputSchema
+ {
+ Type = "object",
+ Properties = new Dictionary<string, SchemaProperty>
+ {
+ { "location", new SchemaProperty { Type = "string", Description = "The city and state, e.g. San Francisco, CA" } },
+ { "unit", new SchemaProperty { Type = "string", Description = "The unit of temperature, either \"celsius\" or \"fahrenheit\"" } }
+ },
+ Required = new List<string> { "location" }
+ }
+ };
+
+ var weatherFunction = new WeatherFunction();
+ var functionMiddleware = new FunctionCallMiddleware(
+ functions: [
+ weatherFunction.GetWeatherFunctionContract,
+ ],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
+ });
+
+ #endregion
+
+ #region create_anthropic_agent
+
+ var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
+ throw new Exception("Missing ANTHROPIC_API_KEY environment variable.");
+
+ var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
+ var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku,
+ tools: [tool]); // Define tools for AnthropicClientAgent
+ #endregion
+
+ #region register_middleware
+
+ var agentWithConnector = agent
+ .RegisterMessageConnector()
+ .RegisterPrintMessage()
+ .RegisterStreamingMiddleware(functionMiddleware);
+ #endregion register_middleware
+
+ #region single_turn
+ var question = new TextMessage(Role.Assistant,
+ "What is the weather like in San Francisco?",
+ from: "user");
+ var functionCallReply = await agentWithConnector.SendAsync(question);
+ #endregion
+
+ #region Single_turn_verify_reply
+ functionCallReply.Should().BeOfType();
+ #endregion Single_turn_verify_reply
+
+ #region Multi_turn
+ var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
+ #endregion Multi_turn
+
+ #region Multi_turn_verify_reply
+ finalReply.Should().BeOfType();
+ #endregion Multi_turn_verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
new file mode 100644
index 00000000000..105bb56524f
--- /dev/null
+++ b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+namespace AutoGen.Anthropic.Samples;
+
+internal static class Program
+{
+ public static async Task Main(string[] args)
+ {
+ await Anthropic_Agent_With_Prompt_Caching.RunAsync();
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
new file mode 100644
index 00000000000..d4323ee4c92
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
@@ -0,0 +1,19 @@
+
+
+
+ Exe
+ $(TestTargetFrameworks)
+ enable
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+ true
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs
new file mode 100644
index 00000000000..abaf94cbd4f
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AgentCodeSnippet.cs
+using AutoGen.Core;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+internal class AgentCodeSnippet
+{
+ public async Task ChatWithAnAgent(IStreamingAgent agent)
+ {
+ #region ChatWithAnAgent_GenerateReplyAsync
+ var message = new TextMessage(Role.User, "Hello");
+ IMessage reply = await agent.GenerateReplyAsync([message]);
+ #endregion ChatWithAnAgent_GenerateReplyAsync
+
+ #region ChatWithAnAgent_SendAsync
+ reply = await agent.SendAsync("Hello");
+ #endregion ChatWithAnAgent_SendAsync
+
+ #region ChatWithAnAgent_GenerateStreamingReplyAsync
+ var textMessage = new TextMessage(Role.User, "Hello");
+ await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([message]))
+ {
+ if (streamingReply is TextMessageUpdate update)
+ {
+ Console.Write(update.Content);
+ }
+ }
+ #endregion ChatWithAnAgent_GenerateStreamingReplyAsync
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs
new file mode 100644
index 00000000000..f26485116c8
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// BuildInMessageCodeSnippet.cs
+
+using AutoGen.Core;
+namespace AutoGen.BasicSample.CodeSnippet;
+
+internal class BuildInMessageCodeSnippet
+{
+ public async Task StreamingCallCodeSnippetAsync()
+ {
+ IStreamingAgent agent = default;
+ #region StreamingCallCodeSnippet
+ var helloTextMessage = new TextMessage(Role.User, "Hello");
+ var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
+ var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name);
+ await foreach (var message in reply)
+ {
+ if (message is TextMessageUpdate textMessage)
+ {
+ Console.Write(textMessage.Content);
+ finalTextMessage.Update(textMessage);
+ }
+ }
+ #endregion StreamingCallCodeSnippet
+
+ #region StreamingCallWithFinalMessage
+ reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
+ TextMessage finalMessage = null;
+ await foreach (var message in reply)
+ {
+ if (message is TextMessageUpdate textMessage)
+ {
+ Console.Write(textMessage.Content);
+ }
+ else if (message is TextMessage txtMessage)
+ {
+ finalMessage = txtMessage;
+ }
+ }
+ #endregion StreamingCallWithFinalMessage
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
new file mode 100644
index 00000000000..f6805322466
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
@@ -0,0 +1,126 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// CreateAnAgent.cs
+
+using AutoGen;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using OpenAI;
+
+public partial class AssistantCodeSnippet
+{
+ public void CodeSnippet1()
+ {
+ #region code_snippet_1
+ // get OpenAI Key and create config
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var llmConfig = new OpenAIConfig(openAIKey, "gpt-3.5-turbo");
+
+ // create assistant agent
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[] { llmConfig },
+ });
+ #endregion code_snippet_1
+
+ }
+
+ public void CodeSnippet2()
+ {
+ #region code_snippet_2
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ var model = "gpt-4o-mini";
+
+ var openAIClient = new OpenAIClient(apiKey);
+
+ // create assistant agent
+ var assistantAgent = new OpenAIChatAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ chatClient: openAIClient.GetChatClient(model))
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion code_snippet_2
+ }
+
+ #region code_snippet_3
+    /// <summary>
+    /// convert input to upper case
+    /// </summary>
+    /// <param name="input">input</param>
+    [Function]
+    public async Task<string> UpperCase(string input)
+ {
+ var result = input.ToUpper();
+ return result;
+ }
+
+ #endregion code_snippet_3
+
+ public async Task CodeSnippet4()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+ var model = "gpt-4o-mini";
+ var openAIClient = new OpenAIClient(new System.ClientModel.ApiKeyCredential(apiKey), new OpenAIClientOptions
+ {
+ Endpoint = new Uri(endPoint),
+ });
+ #region code_snippet_4
+ var assistantAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are an assistant that convert user input to upper case.",
+ functions: [
+ this.UpperCaseFunctionContract.ToChatTool(), // The FunctionDefinition object for the UpperCase function
+ ])
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
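+        // Only the function contract is registered (no functionMap), so the agent returns the tool call instead of executing it.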
+ var response = await assistantAgent.SendAsync("hello");
+        response.Should().BeOfType<ToolCallMessage>();
+ var toolCallMessage = (ToolCallMessage)response;
+ toolCallMessage.ToolCalls.Count().Should().Be(1);
+ toolCallMessage.ToolCalls.First().FunctionName.Should().Be("UpperCase");
+ #endregion code_snippet_4
+ }
+
+ public async Task CodeSnippet5()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+ var model = "gpt-4o-mini";
+ var openAIClient = new OpenAIClient(new System.ClientModel.ApiKeyCredential(apiKey), new OpenAIClientOptions
+ {
+ Endpoint = new Uri(endPoint),
+ });
+ #region code_snippet_5
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [this.UpperCaseFunctionContract],
+            functionMap: new Dictionary<string, Func<string, Task<string>>>()
+ {
+ { this.UpperCaseFunctionContract.Name, this.UpperCase },
+ });
+ var assistantAgent = new OpenAIChatAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that convert user input to upper case.",
+ chatClient: openAIClient.GetChatClient(model))
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware);
+
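+        // FunctionCallMiddleware carries a functionMap, so UpperCase is executed automatically and the reply is plain text.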
+ var response = await assistantAgent.SendAsync("hello");
+        response.Should().BeOfType<TextMessage>();
+ response.From.Should().Be("assistant");
+ var textMessage = (TextMessage)response;
+ textMessage.Content.Should().Be("HELLO");
+ #endregion code_snippet_5
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs
new file mode 100644
index 00000000000..854a385dc34
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs
@@ -0,0 +1,148 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// FunctionCallCodeSnippet.cs
+
+using AutoGen;
+using AutoGen.Core;
+using FluentAssertions;
+
+public partial class FunctionCallCodeSnippet
+{
+ public async Task CodeSnippet4()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+
+ var llmConfig = new AzureOpenAIConfig(
+ endpoint: endPoint,
+ deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
+ apiKey: apiKey);
+ #region code_snippet_4
+ var function = new TypeSafeFunctionCall();
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that convert user input to upper case.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[]
+ {
+ llmConfig
+ },
+ FunctionContracts = new[]
+ {
+ function.WeatherReportFunctionContract,
+ },
+ });
+
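+        // FunctionContracts are configured without a functionMap, so the agent replies with the tool call for the caller to execute.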
+ var response = await assistantAgent.SendAsync("hello What's the weather in Seattle today? today is 2024-01-01");
+        response.Should().BeOfType<ToolCallMessage>();
+ var toolCallMessage = (ToolCallMessage)response;
+ toolCallMessage.ToolCalls.Count().Should().Be(1);
+ toolCallMessage.ToolCalls[0].FunctionName.Should().Be("WeatherReport");
+ toolCallMessage.ToolCalls[0].FunctionArguments.Should().Be(@"{""location"":""Seattle"",""date"":""2024-01-01""}");
+ #endregion code_snippet_4
+ }
+
+
+ public async Task CodeSnippet6()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+
+ var llmConfig = new AzureOpenAIConfig(
+ endpoint: endPoint,
+ deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
+ apiKey: apiKey);
+ #region code_snippet_6
+ var function = new TypeSafeFunctionCall();
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[]
+ {
+ llmConfig
+ },
+ FunctionContracts = new[]
+ {
+ function.WeatherReportFunctionContract,
+ },
+ },
+            functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name, function.WeatherReportWrapper }, // The function wrapper for the weather report function
+ });
+
+ #endregion code_snippet_6
+
+ #region code_snippet_6_1
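+        // Because a functionMap is provided, WeatherReport is executed automatically and the reply carries its text result.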
+ var response = await assistantAgent.SendAsync("What's the weather in Seattle today? today is 2024-01-01");
+        response.Should().BeOfType<TextMessage>();
+ var textMessage = (TextMessage)response;
+ textMessage.Content.Should().Be("Weather report for Seattle on 2024-01-01 is sunny");
+ #endregion code_snippet_6_1
+ }
+
+ public async Task OverriderFunctionContractAsync()
+ {
+ IAgent agent = default;
+        IEnumerable<IMessage> messages = new List<IMessage>();
+ #region overrider_function_contract
+ var function = new TypeSafeFunctionCall();
+ var reply = agent.GenerateReplyAsync(messages, new GenerateReplyOptions
+ {
+ Functions = new[] { function.WeatherReportFunctionContract },
+ });
+ #endregion overrider_function_contract
+ }
+
+ public async Task RegisterFunctionCallMiddlewareAsync()
+ {
+ IAgent agent = default;
+ #region register_function_call_middleware
+ var function = new TypeSafeFunctionCall();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: new[] { function.WeatherReportFunctionContract },
+            functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name, function.WeatherReportWrapper },
+ });
+
+ agent = agent!.RegisterMiddleware(functionCallMiddleware);
+ var reply = await agent.SendAsync("What's the weather in Seattle today? today is 2024-01-01");
+ #endregion register_function_call_middleware
+ }
+
+ public async Task TwoAgentWeatherChatTestAsync()
+ {
+ var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
+ var deploymentName = "gpt-35-turbo-16k";
+ var config = new AzureOpenAIConfig(endpoint, deploymentName, key);
+ #region two_agent_weather_chat
+ var function = new TypeSafeFunctionCall();
+ var assistant = new AssistantAgent(
+ "assistant",
+ llmConfig: new ConversableAgentConfig
+ {
+ ConfigList = new[] { config },
+ FunctionContracts = new[]
+ {
+ function.WeatherReportFunctionContract,
+ },
+ });
+
+ var user = new UserProxyAgent(
+ name: "user",
+            functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name, function.WeatherReportWrapper },
+ });
+
+ await user.InitiateChatAsync(assistant, "what's weather in Seattle today, today is 2024-01-01", 10);
+ #endregion two_agent_weather_chat
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs
new file mode 100644
index 00000000000..c5ff7b77033
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GetStartCodeSnippet.cs
+
+#region snippet_GetStartCodeSnippet
+using AutoGen;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
+#endregion snippet_GetStartCodeSnippet
+
+public class GetStartCodeSnippet
+{
+ public async Task CodeSnippet1()
+ {
+ #region code_snippet_1
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var openAIClient = new OpenAIClient(openAIKey);
+ var model = "gpt-4o-mini";
+
+ var assistantAgent = new OpenAIChatAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ chatClient: openAIClient.GetChatClient(model))
+ .RegisterMessageConnector()
+ .RegisterPrintMessage(); // register a hook to print message nicely to console
+
+ // set human input mode to ALWAYS so that user always provide input
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ // start the conversation
+ await userProxyAgent.InitiateChatAsync(
+ receiver: assistantAgent,
+ message: "Hey assistant, please do me a favor.",
+ maxRound: 10);
+ #endregion code_snippet_1
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
new file mode 100644
index 00000000000..1b5a9a90320
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
@@ -0,0 +1,177 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// MiddlewareAgentCodeSnippet.cs
+
+using System.Text.Json;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using FluentAssertions;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+public class MiddlewareAgentCodeSnippet
+{
+ public async Task CreateMiddlewareAgentAsync()
+ {
+ #region create_middleware_agent_with_original_agent
+ // Create an agent that always replies "Hi!"
+ IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hi!");
+
+ // Create a middleware agent on top of default reply agent
+ var middlewareAgent = new MiddlewareAgent(innerAgent: agent);
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ if (messages.Last() is TextMessage lastMessage && lastMessage.Content.Contains("Hello World"))
+ {
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return lastMessage;
+ }
+
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ var reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware 0] Hello World");
+ reply = await middlewareAgent.SendAsync("Hello AI!");
+ reply.GetContent().Should().Be("Hi!");
+ #endregion create_middleware_agent_with_original_agent
+
+ #region register_middleware_agent
+ middlewareAgent = agent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ if (messages.Last() is TextMessage lastMessage && lastMessage.Content.Contains("Hello World"))
+ {
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return lastMessage;
+ }
+
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+ #endregion register_middleware_agent
+
+ #region short_circuit_middleware_agent
+ // This middleware will short circuit the agent and return a message directly.
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ return new TextMessage(Role.Assistant, $"[middleware shortcut]");
+ });
+ #endregion short_circuit_middleware_agent
+ }
+
+ public async Task RegisterStreamingMiddlewareAsync()
+ {
+ IStreamingAgent streamingAgent = default;
+ #region register_streaming_middleware
+ var connector = new OpenAIChatRequestMessageConnector();
+ var agent = streamingAgent!
+ .RegisterStreamingMiddleware(connector);
+ #endregion register_streaming_middleware
+ }
+
+ public async Task CodeSnippet1()
+ {
+ #region code_snippet_1
+ // Create an agent that always replies "Hello World"
+ IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hello World");
+
+ // Create a middleware agent on top of default reply agent
+ var middlewareAgent = new MiddlewareAgent(innerAgent: agent);
+
+ // Since no middleware is added, middlewareAgent will simply proxy into the inner agent to generate reply.
+ var reply = await middlewareAgent.SendAsync("Hello World");
+ reply.From.Should().Be("assistant");
+ reply.GetContent().Should().Be("Hello World");
+ #endregion code_snippet_1
+
+ #region code_snippet_2
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+        reply.Should().BeOfType<TextMessage>();
+ var textReply = (TextMessage)reply;
+ textReply.Content.Should().Be("[middleware 0] Hello World");
+ #endregion code_snippet_2
+ #region code_snippet_2_1
+ middlewareAgent = agent.RegisterMiddleware(async (messages, options, agnet, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware 0] Hello World");
+ #endregion code_snippet_2_1
+ #region code_snippet_3
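+        // The middleware added below runs before the previously registered one, so its tag ends up next to the original text in the expected reply.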
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 1] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware 0] [middleware 1] Hello World");
+ #endregion code_snippet_3
+
+ #region code_snippet_4
+ middlewareAgent.Use(async (messages, options, next, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware shortcut]";
+
+ return lastMessage;
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware shortcut]");
+ #endregion code_snippet_4
+
+ #region retrieve_inner_agent
+ var innerAgent = middlewareAgent.Agent;
+ #endregion retrieve_inner_agent
+
+ #region code_snippet_logging_to_console
+ var agentWithLogging = middlewareAgent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var reply = await agent.GenerateReplyAsync(messages, options, ct);
+ var formattedMessage = reply.FormatMessage();
+ Console.WriteLine(formattedMessage);
+
+ return reply;
+ });
+ #endregion code_snippet_logging_to_console
+
+ #region code_snippet_response_format_forcement
+ var jsonAgent = middlewareAgent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var maxAttempt = 5;
+ var reply = await agent.GenerateReplyAsync(messages, options, ct);
+ while (maxAttempt-- > 0)
+ {
+                if (JsonSerializer.Deserialize<Dictionary<string, object>>(reply.GetContent()) is { } dict)
+ {
+ return reply;
+ }
+ else
+ {
+ await Task.Delay(1000);
+ var reviewPrompt = @"The format is not json, please modify your response to json format
+ -- ORIGINAL MESSAGE --
+ {reply.Content}
+ -- END OF ORIGINAL MESSAGE --
+
+ Reply again with json format.";
+ reply = await agent.SendAsync(reviewPrompt, messages, ct);
+ }
+ }
+
+ throw new Exception("agent fails to generate json response");
+ });
+ #endregion code_snippet_response_format_forcement
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs
new file mode 100644
index 00000000000..0ce1d840d36
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// MistralAICodeSnippet.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.Mistral;
+using AutoGen.Mistral.Extension;
+using FluentAssertions;
+#endregion using_statement
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+#region weather_function
+public partial class MistralAgentFunction
+{
+ [Function]
+    public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+}
+#endregion weather_function
+
+internal class MistralAICodeSnippet
+{
+ public async Task CreateMistralAIClientAsync()
+ {
+ #region create_mistral_agent
+ var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new Exception("Missing MISTRAL_API_KEY environment variable");
+ var client = new MistralClient(apiKey: apiKey);
+ var agent = new MistralClientAgent(
+ client: client,
+ name: "MistralAI",
+ model: MistralAIModelID.OPEN_MISTRAL_7B)
+ .RegisterMessageConnector(); // support more AutoGen built-in message types.
+
+ await agent.SendAsync("Hello, how are you?");
+ #endregion create_mistral_agent
+
+ #region streaming_chat
+ var reply = agent.GenerateStreamingReplyAsync(
+ messages: [new TextMessage(Role.User, "Hello, how are you?")]
+ );
+
+ await foreach (var message in reply)
+ {
+ if (message is TextMessageUpdate textMessageUpdate && textMessageUpdate.Content is string content)
+ {
+ Console.WriteLine(content);
+ }
+ }
+ #endregion streaming_chat
+ }
+
+ public async Task MistralAIChatAgentGetWeatherToolUsageAsync()
+ {
+ #region create_mistral_function_call_agent
+ var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new Exception("Missing MISTRAL_API_KEY environment variable");
+ var client = new MistralClient(apiKey: apiKey);
+ var agent = new MistralClientAgent(
+ client: client,
+ name: "MistralAI",
+ model: MistralAIModelID.MISTRAL_SMALL_LATEST)
+ .RegisterMessageConnector(); // support more AutoGen built-in message types like ToolCallMessage and ToolCallResultMessage
+ #endregion create_mistral_function_call_agent
+
+ #region create_get_weather_function_call_middleware
+ var mistralFunctions = new MistralAgentFunction();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [mistralFunctions.GetWeatherFunctionContract],
+            functionMap: new Dictionary<string, Func<string, Task<string>>> // with functionMap, the function will be automatically triggered if the tool name matches one of the keys.
+ {
+ { mistralFunctions.GetWeatherFunctionContract.Name, mistralFunctions.GetWeather }
+ });
+ #endregion create_get_weather_function_call_middleware
+
+ #region register_function_call_middleware
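+        // Register the middleware so a GetWeather tool call from the model is executed automatically.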
+ agent = agent.RegisterStreamingMiddleware(functionCallMiddleware);
+ #endregion register_function_call_middleware
+
+ #region send_message_with_function_call
+ var reply = await agent.SendAsync("What is the weather in Seattle?");
+ reply.GetContent().Should().Be("The weather in Seattle is sunny.");
+ #endregion send_message_with_function_call
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
new file mode 100644
index 00000000000..60520078e72
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
@@ -0,0 +1,135 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAICodeSnippet.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+#endregion using_statement
+using FluentAssertions;
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+#region weather_function
+public partial class Functions
+{
+ [Function]
+    public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+}
+#endregion weather_function
+public partial class OpenAICodeSnippet
+{
+ [Function]
+    public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+
+ public async Task CreateOpenAIChatAgentAsync()
+ {
+ #region create_openai_chat_agent
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-4o-mini";
+ var openAIClient = new OpenAIClient(openAIKey);
+
+ // create an open ai chat agent
+ var openAIChatAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(modelId),
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.");
+
+        // OpenAIChatAgent supports the following message types:
+        // - IMessage<ChatMessage> where ChatMessage is from OpenAI.Chat
+
+        var helloMessage = new UserChatMessage("Hello");
+
+        // Use MessageEnvelope.Create to create an IMessage<ChatMessage>
+        var chatMessageContent = MessageEnvelope.Create(helloMessage);
+        var reply = await openAIChatAgent.SendAsync(chatMessageContent);
+
+        // The type of reply is MessageEnvelope<ChatCompletion> where ChatCompletion is from OpenAI.Chat
+        reply.Should().BeOfType<MessageEnvelope<ChatCompletion>>();
+
+        // You can un-envelop the reply to get the ChatCompletion
+        ChatCompletion response = reply.As<MessageEnvelope<ChatCompletion>>().Content;
+        response.Role.Should().Be(ChatMessageRole.Assistant);
+ #endregion create_openai_chat_agent
+
+ #region create_openai_chat_agent_streaming
+ var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+
+ await foreach (var streamingMessage in streamingReply)
+ {
+            streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionUpdate>>();
+            streamingMessage.As<MessageEnvelope<StreamingChatCompletionUpdate>>().Content.Role.Should().Be(ChatMessageRole.Assistant);
+ }
+ #endregion create_openai_chat_agent_streaming
+
+ #region register_openai_chat_message_connector
+ // register message connector to support more message types
+ var agentWithConnector = openAIChatAgent
+ .RegisterMessageConnector();
+
+ // now the agentWithConnector supports more message types
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new UserChatMessage("Hello")),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ ],
+ from: "user"),
+ new TextMessage(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead
+ };
+
+ foreach (var message in messages)
+ {
+ reply = await agentWithConnector.SendAsync(message);
+
+            reply.Should().BeOfType<TextMessage>();
+            reply.As<TextMessage>().From.Should().Be("assistant");
+ }
+ #endregion register_openai_chat_message_connector
+ }
+
+ public async Task OpenAIChatAgentGetWeatherFunctionCallAsync()
+ {
+ #region openai_chat_agent_get_weather_function_call
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var openAIClient = new OpenAIClient(openAIKey);
+
+ // create an open ai chat agent
+ var openAIChatAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(modelId),
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.")
+ .RegisterMessageConnector();
+
+ #endregion openai_chat_agent_get_weather_function_call
+
+ #region create_function_call_middleware
+ var functions = new Functions();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [functions.GetWeatherFunctionContract], // GetWeatherFunctionContract is auto-generated from the GetWeather function
+            functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { functions.GetWeatherFunctionContract.Name, functions.GetWeatherWrapper } // GetWeatherWrapper is a wrapper function for GetWeather, which is also auto-generated
+ });
+
+ openAIChatAgent = openAIChatAgent.RegisterStreamingMiddleware(functionCallMiddleware);
+ #endregion create_function_call_middleware
+
+ #region chat_agent_send_function_call
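+        // The functionMap lets the middleware execute GetWeather, so the reply carries both the tool call and its result.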
+ var reply = await openAIChatAgent.SendAsync("what is the weather in Seattle?");
+ reply.GetContent().Should().Be("The weather in Seattle is sunny.");
+ reply.GetToolCalls().Count.Should().Be(1);
+        reply.GetToolCalls().First().FunctionName.Should().Be(this.GetWeatherFunctionContract.Name);
+ #endregion chat_agent_send_function_call
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs
new file mode 100644
index 00000000000..0ac7f71a3ca
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// PrintMessageMiddlewareCodeSnippet.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+internal class PrintMessageMiddlewareCodeSnippet
+{
+ public async Task PrintMessageMiddlewareAsync()
+ {
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var endpoint = new Uri(config.Endpoint);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var agent = new OpenAIChatAgent(gpt4o, "assistant", config.DeploymentName)
+ .RegisterMessageConnector();
+
+ #region PrintMessageMiddleware
+ var agentWithPrintMessageMiddleware = agent
+ .RegisterPrintMessage();
+
+ await agentWithPrintMessageMiddleware.SendAsync("write a long poem");
+ #endregion PrintMessageMiddleware
+ }
+
+ public async Task PrintMessageStreamingMiddlewareAsync()
+ {
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var endpoint = new Uri(config.Endpoint);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+ #region print_message_streaming
+ var streamingAgent = new OpenAIChatAgent(gpt4o, "assistant")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ await streamingAgent.SendAsync("write a long poem");
+ #endregion print_message_streaming
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs
new file mode 100644
index 00000000000..b087beb993b
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs
@@ -0,0 +1,80 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// RunCodeSnippetCodeSnippet.cs
+
+#region code_snippet_0_1
+using AutoGen.Core;
+using AutoGen.DotnetInteractive;
+using AutoGen.DotnetInteractive.Extension;
+#endregion code_snippet_0_1
+
+namespace AutoGen.BasicSample.CodeSnippet;
+public class RunCodeSnippetCodeSnippet
+{
+ public async Task CodeSnippet1()
+ {
+ IAgent agent = new DefaultReplyAgent("agent", "Hello World");
+
+ #region code_snippet_1_1
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder() // add C# and F# kernels
+ .Build();
+ #endregion code_snippet_1_1
+
+ #region code_snippet_1_2
+ // register middleware to execute code block
+ var dotnetCodeAgent = agent
+ .RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
+ {
+ var lastMessage = msgs.LastOrDefault();
+ if (lastMessage == null || lastMessage.GetContent() is null)
+ {
+ return await innerAgent.GenerateReplyAsync(msgs, option, ct);
+ }
+
+ if (lastMessage.ExtractCodeBlock("```csharp", "```") is string codeSnippet)
+ {
+ // execute code snippet
+ var result = await kernel.RunSubmitCodeCommandAsync(codeSnippet, "csharp");
+ return new TextMessage(Role.Assistant, result, from: agent.Name);
+ }
+ else
+ {
+ // no code block found, invoke next agent
+ return await innerAgent.GenerateReplyAsync(msgs, option, ct);
+ }
+ });
+
+ var codeSnippet = @"
+ ```csharp
+ Console.WriteLine(""Hello World"");
+ ```";
+
+ await dotnetCodeAgent.SendAsync(codeSnippet);
+ // output: Hello World
+ #endregion code_snippet_1_2
+
+ #region code_snippet_1_3
+ var content = @"
+ ```csharp
+ // This is csharp code snippet
+ ```
+
+ ```python
+ // This is python code snippet
+ ```
+ ";
+ #endregion code_snippet_1_3
+
+ #region code_snippet_1_4
+ var pythonKernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .AddPythonKernel(venv: "python3")
+ .Build();
+
+ var pythonCode = """
+ print('Hello from Python!')
+ """;
+ var result = await pythonKernel.RunSubmitCodeCommandAsync(pythonCode, "python3");
+ #endregion code_snippet_1_4
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs
new file mode 100644
index 00000000000..20dd12d90ce
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// SemanticKernelCodeSnippet.cs
+
+using AutoGen.Core;
+using AutoGen.SemanticKernel;
+using AutoGen.SemanticKernel.Extension;
+using FluentAssertions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+public class SemanticKernelCodeSnippet
+{
+    public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+ public async Task CreateSemanticKernelAgentAsync()
+ {
+ #region create_semantic_kernel_agent
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var builder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ var kernel = builder.Build();
+
+ // create a semantic kernel agent
+ var semanticKernelAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.");
+
+        // SemanticKernelAgent supports the following message types:
+        // - IMessage<ChatMessageContent> where ChatMessageContent is from Microsoft.SemanticKernel
+
+        var helloMessage = new ChatMessageContent(AuthorRole.User, "Hello");
+
+        // Use MessageEnvelope.Create to create an IMessage<ChatMessageContent>
+        var chatMessageContent = MessageEnvelope.Create(helloMessage);
+        var reply = await semanticKernelAgent.SendAsync(chatMessageContent);
+
+        // The type of reply is MessageEnvelope<ChatMessageContent> where ChatMessageContent is from Microsoft.SemanticKernel
+        reply.Should().BeOfType<MessageEnvelope<ChatMessageContent>>();
+
+        // You can un-envelop the reply to get the ChatMessageContent
+        ChatMessageContent response = reply.As<MessageEnvelope<ChatMessageContent>>().Content;
+        response.Role.Should().Be(AuthorRole.Assistant);
+ #endregion create_semantic_kernel_agent
+
+ #region create_semantic_kernel_agent_streaming
+ var streamingReply = semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+
+ await foreach (var streamingMessage in streamingReply)
+ {
+            streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatMessageContent>>();
+            streamingMessage.As<MessageEnvelope<StreamingChatMessageContent>>().From.Should().Be("assistant");
+ }
+ #endregion create_semantic_kernel_agent_streaming
+ }
+
+ public async Task SemanticKernelChatMessageContentConnector()
+ {
+ #region register_semantic_kernel_chat_message_content_connector
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var builder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ var kernel = builder.Build();
+
+ // create a semantic kernel agent
+ var semanticKernelAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.");
+
+ // Register the connector middleware to the kernel agent
+ var semanticKernelAgentWithConnector = semanticKernelAgent
+ .RegisterMessageConnector();
+
+ // now semanticKernelAgentWithConnector supports more message types
+ IMessage[] messages = [
+ MessageEnvelope.Create(new ChatMessageContent(AuthorRole.User, "Hello")),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ ],
+ from: "user"),
+ ];
+
+ foreach (var message in messages)
+ {
+ var reply = await semanticKernelAgentWithConnector.SendAsync(message);
+
+ // SemanticKernelChatMessageContentConnector will convert the reply message to TextMessage
+            reply.Should().BeOfType<TextMessage>();
+ }
+ #endregion register_semantic_kernel_chat_message_content_connector
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
new file mode 100644
index 00000000000..667705835eb
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
@@ -0,0 +1,119 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// TypeSafeFunctionCallCodeSnippet.cs
+
+using System.Text.Json;
+using AutoGen.OpenAI.Extension;
+#region weather_report_using_statement
+using AutoGen.Core;
+#endregion weather_report_using_statement
+
+#region weather_report
+public partial class TypeSafeFunctionCall
+{
+    /// <summary>
+    /// Get weather report
+    /// </summary>
+    /// <param name="city">city</param>
+    /// <param name="date">date</param>
+    [Function]
+    public async Task<string> WeatherReport(string city, string date)
+ {
+ return $"Weather report for {city} on {date} is sunny";
+ }
+}
+#endregion weather_report
+
+public partial class TypeSafeFunctionCall
+{
+ public async Task Consume()
+ {
+ #region weather_report_consume
+ var functionInstance = new TypeSafeFunctionCall();
+
+ // Get the generated function definition
+        var functionDefinition = functionInstance.WeatherReportFunctionContract.ToChatTool();
+
+        // Get the generated function wrapper
+        Func<string, Task<string>> functionWrapper = functionInstance.WeatherReportWrapper;
+
+ // ...
+ #endregion weather_report_consume
+ }
+}
+#region code_snippet_3
+// file: FunctionCall.cs
+
+public partial class TypeSafeFunctionCall
+{
+    /// <summary>
+    /// convert input to upper case
+    /// </summary>
+    /// <param name="input">input</param>
+    [Function]
+    public async Task<string> UpperCase(string input)
+ {
+ var result = input.ToUpper();
+ return result;
+ }
+}
+#endregion code_snippet_3
+
+public class TypeSafeFunctionCallCodeSnippet
+{
+    public async Task<string> UpperCase(string input)
+ {
+ var result = input.ToUpper();
+ return result;
+ }
+
+ #region code_snippet_1
+ // file: FunctionDefinition.generated.cs
+ public FunctionContract WeatherReportFunctionContract
+ {
+ get => new FunctionContract
+ {
+ ClassName = @"TypeSafeFunctionCall",
+ Name = @"WeatherReport",
+ Description = @"Get weather report",
+            ReturnType = typeof(Task<string>),
+ Parameters = new global::AutoGen.Core.FunctionParameterContract[]
+ {
+ new FunctionParameterContract
+ {
+ Name = @"city",
+ Description = @"city",
+ ParameterType = typeof(string),
+ IsRequired = true,
+ },
+ new FunctionParameterContract
+ {
+ Name = @"date",
+ Description = @"date",
+ ParameterType = typeof(string),
+ IsRequired = true,
+ },
+ },
+ };
+ }
+ #endregion code_snippet_1
+
+ #region code_snippet_2
+ // file: FunctionDefinition.generated.cs
+ private class UpperCaseSchema
+ {
+ public string input { get; set; }
+ }
+
+    public Task<string> UpperCaseWrapper(string arguments)
+    {
+        var schema = JsonSerializer.Deserialize<UpperCaseSchema>(
+ arguments,
+ new JsonSerializerOptions
+ {
+ PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
+ });
+
+ return UpperCase(schema.input);
+ }
+ #endregion code_snippet_2
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/UserProxyAgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/UserProxyAgentCodeSnippet.cs
new file mode 100644
index 00000000000..85aecae959e
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/UserProxyAgentCodeSnippet.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// UserProxyAgentCodeSnippet.cs
+using AutoGen.Core;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+public class UserProxyAgentCodeSnippet
+{
+ public async Task CodeSnippet1()
+ {
+ #region code_snippet_1
+ // create a user proxy agent which always ask user for input
+ var agent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS);
+
+ await agent.SendAsync("hello");
+ #endregion code_snippet_1
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs
new file mode 100644
index 00000000000..40c88102588
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example01_AssistantAgent.cs
+
+using AutoGen;
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+
+/// <summary>
+/// This example shows the basic usage of the <see cref="OpenAIChatAgent"/> class.
+/// </summary>
+public static class Example01_AssistantAgent
+{
+ public static async Task RunAsync()
+ {
+ var gpt4oMini = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var assistantAgent = new OpenAIChatAgent(
+ chatClient: gpt4oMini,
+ name: "assistant",
+ systemMessage: "You convert what user said to all uppercase.")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ // talk to the assistant agent
+ var reply = await assistantAgent.SendAsync("hello world");
+        reply.Should().BeOfType<TextMessage>();
+ reply.GetContent().Should().Be("HELLO WORLD");
+
+ // to carry on the conversation, pass the previous conversation history to the next call
+        var conversationHistory = new List<IMessage>
+ {
+ new TextMessage(Role.User, "hello world"), // first message
+ reply, // reply from assistant agent
+ };
+
+ reply = await assistantAgent.SendAsync("hello world again", conversationHistory);
+        reply.Should().BeOfType<TextMessage>();
+ reply.GetContent().Should().Be("HELLO WORLD AGAIN");
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs
new file mode 100644
index 00000000000..b2dd9726b4b
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example02_TwoAgent_MathChat.cs
+
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+public static class Example02_TwoAgent_MathChat
+{
+ public static async Task RunAsync()
+ {
+ #region code_snippet_1
+ var gpt4oMini = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+
+ // create teacher agent
+ // teacher agent will create math questions
+ var teacher = new OpenAIChatAgent(
+ chatClient: gpt4oMini,
+ name: "teacher",
+ systemMessage: @"You are a teacher that create pre-school math question for student and check answer.
+ If the answer is correct, you stop the conversation by saying [COMPLETE].
+ If the answer is wrong, you ask student to fix it.")
+ .RegisterMessageConnector()
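+            // convert the teacher's [COMPLETE] signal into a group-chat terminate message so the conversation ends.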
+ .RegisterMiddleware(async (msgs, option, agent, _) =>
+ {
+ var reply = await agent.GenerateReplyAsync(msgs, option);
+ if (reply.GetContent()?.ToLower().Contains("complete") is true)
+ {
+ return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From);
+ }
+
+ return reply;
+ })
+ .RegisterPrintMessage();
+
+ // create student agent
+ // student agent will answer the math questions
+ var student = new OpenAIChatAgent(
+ chatClient: gpt4oMini,
+ name: "student",
+ systemMessage: "You are a student that answer question from teacher")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ // start the conversation
+ var conversation = await student.InitiateChatAsync(
+ receiver: teacher,
+ message: "Hey teacher, please create math question for me.",
+ maxRound: 10);
+
+ // output
+ // Message from teacher
+ // --------------------
+ // content: Of course!Here's a math question for you:
+ //
+ // What is 2 + 3 ?
+ // --------------------
+ //
+ // Message from student
+ // --------------------
+ // content: The sum of 2 and 3 is 5.
+ // --------------------
+ //
+ // Message from teacher
+ // --------------------
+ // content: [GROUPCHAT_TERMINATE]
+ // --------------------
+ #endregion code_snippet_1
+
+ conversation.Count().Should().BeLessThan(10);
+ conversation.Last().IsGroupChatTerminateMessage().Should().BeTrue();
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
new file mode 100644
index 00000000000..94b67a94b14
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
@@ -0,0 +1,104 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example03_Agent_FunctionCall.cs
+
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+
+/// <summary>
+/// This example shows how to add type-safe function call to an agent.
+/// </summary>
+public partial class Example03_Agent_FunctionCall
+{
+    /// <summary>
+    /// upper case the message when asked.
+    /// </summary>
+    /// <param name="message"></param>
+    [Function]
+    public async Task<string> UpperCase(string message)
+ {
+ return message.ToUpper();
+ }
+
+    /// <summary>
+    /// Concatenate strings.
+    /// </summary>
+    /// <param name="strings">strings to concatenate</param>
+    [Function]
+    public async Task<string> ConcatString(string[] strings)
+ {
+ return string.Join(" ", strings);
+ }
+
+    /// <summary>
+    /// calculate tax
+    /// </summary>
+    /// <param name="price">price, should be an integer</param>
+    /// <param name="taxRate">tax rate, should be in range (0, 1)</param>
+    [FunctionAttribute]
+    public async Task<string> CalculateTax(int price, float taxRate)
+ {
+ return $"tax is {price * taxRate}";
+ }
+
+ public static async Task RunAsync()
+ {
+ var instance = new Example03_Agent_FunctionCall();
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+ // AutoGen makes use of AutoGen.SourceGenerator to automatically generate FunctionDefinition and FunctionCallWrapper for you.
+ // The FunctionDefinition will be created based on function signature and XML documentation.
+        // The return type of a type-safe function needs to be Task<string>. To get the best performance, use only primitive types and arrays of primitive types as parameters.
+ var toolCallMiddleware = new FunctionCallMiddleware(
+ functions: [
+ instance.ConcatStringFunctionContract,
+ instance.UpperCaseFunctionContract,
+ instance.CalculateTaxFunctionContract,
+ ],
+            functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { nameof(instance.ConcatString), instance.ConcatStringWrapper },
+ { nameof(instance.UpperCase), instance.UpperCaseWrapper },
+ { nameof(instance.CalculateTax), instance.CalculateTaxWrapper },
+ });
+
+ var agent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(toolCallMiddleware)
+ .RegisterPrintMessage();
+
+ // talk to the assistant agent
+ var upperCase = await agent.SendAsync("convert to upper case: hello world");
+ upperCase.GetContent()?.Should().Be("HELLO WORLD");
+        upperCase.Should().BeOfType<ToolCallAggregateMessage>();
+ upperCase.GetToolCalls().Should().HaveCount(1);
+ upperCase.GetToolCalls().First().FunctionName.Should().Be(nameof(UpperCase));
+
+ var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e");
+ concatString.GetContent()?.Should().Be("a b c d e");
+        concatString.Should().BeOfType<ToolCallAggregateMessage>();
+ concatString.GetToolCalls().Should().HaveCount(1);
+ concatString.GetToolCalls().First().FunctionName.Should().Be(nameof(ConcatString));
+
+ var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1");
+ calculateTax.GetContent().Should().Be("tax is 10");
+        calculateTax.Should().BeOfType<ToolCallAggregateMessage>();
+ calculateTax.GetToolCalls().Should().HaveCount(1);
+ calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
+
+ // parallel function calls
+ var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2");
+ calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40"); // "tax is 10\n tax is 40
+        calculateTaxes.Should().BeOfType<ToolCallAggregateMessage>();
+ calculateTaxes.GetToolCalls().Should().HaveCount(2);
+ calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
+
+ // send aggregate message back to llm to get the final result
+ var finalResult = await agent.SendAsync(calculateTaxes);
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
new file mode 100644
index 00000000000..f90816d890e
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
@@ -0,0 +1,261 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example04_Dynamic_GroupChat_Coding_Task.cs
+
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using AutoGen.DotnetInteractive;
+using AutoGen.DotnetInteractive.Extension;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+
+public partial class Example04_Dynamic_GroupChat_Coding_Task
+{
+ public static async Task RunAsync()
+ {
+ var instance = new Example04_Dynamic_GroupChat_Coding_Task();
+
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .AddPythonKernel("python3")
+ .Build();
+
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+ var groupAdmin = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "groupAdmin",
+ systemMessage: "You are the admin of the group chat")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ var userProxy = new DefaultReplyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE)
+ .RegisterPrintMessage();
+
+ // Create admin agent
+ var admin = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "admin",
+ systemMessage: """
+ You are a manager who takes coding problem from user and resolve problem by splitting them into small tasks and assign each task to the most appropriate agent.
+ Here's available agents who you can assign task to:
+ - coder: write python code to resolve task
+ - runner: run python code from coder
+
+ The workflow is as follows:
+ - You take the coding problem from user
+ - You break the problem into small tasks. For each tasks you first ask coder to write code to resolve the task. Once the code is written, you ask runner to run the code.
+ - Once a small task is resolved, you summarize the completed steps and create the next step.
+ - You repeat the above steps until the coding problem is resolved.
+
+ You can use the following json format to assign task to agents:
+ ```task
+ {
+ "to": "{agent_name}",
+ "task": "{a short description of the task}",
+ "context": "{previous context from scratchpad}"
+ }
+ ```
+
+ If you need to ask user for extra information, you can use the following format:
+ ```ask
+ {
+ "question": "{question}"
+ }
+ ```
+
+ Once the coding problem is resolved, summarize each steps and results and send the summary to the user using the following format:
+ ```summary
+ @user,
+ ```
+
+ Your reply must contain one of [task|ask|summary] to indicate the type of your message.
+ """)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ // create coder agent
+        // The coder agent writes python code to resolve the task assigned by the admin.
+        // Its code block is then checked by the reviewer agent and executed by the runner agent.
+ var coderAgent = new OpenAIChatAgent(
+ name: "coder",
+ chatClient: gpt4o,
+ systemMessage: @"You act as python coder, you write python code to resolve task. Once you finish writing code, ask runner to run the code for you.
+
+Here are some rules to follow when writing python code:
+- put code between ```python and ```
+- Try avoid using external library
+- Always print out the result to console. Don't write code that doesn't print out anything.
+
+Use the following format to install pip package:
+```python
+%pip install
+```
+
+If your code is incorrect, Fix the error and send the code again.
+
+Here's some external information:
+- The link to mlnet repo is: https://github.com/dotnet/machinelearning. you don't need a token to use github pr api. Make sure to include a User-Agent header, otherwise github will reject it.
+")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+        // The code reviewer agent checks whether the code block in the coder's reply satisfies the following conditions:
+        // - The reply contains at least one code block
+        // - There's only one code block and it's a python code block
+ var codeReviewAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "reviewer",
+ systemMessage: """
+ You are a code reviewer who reviews code from coder. You need to check if the code satisfy the following conditions:
+ - The reply from coder contains at least one code block, e.g ```python and ```
+ - There's only one code block and it's python code block
+
+ You don't check the code style, only check if the code satisfy the above conditions.
+
+ Put your comment between ```review and ```, if the code satisfies all conditions, put APPROVED in review.result field. Otherwise, put REJECTED along with comments. make sure your comment is clear and easy to understand.
+
+ ## Example 1 ##
+ ```review
+ comment: The code satisfies all conditions.
+ result: APPROVED
+ ```
+
+ ## Example 2 ##
+ ```review
+ comment: The code is inside main function. Please rewrite the code in top level statement.
+ result: REJECTED
+ ```
+
+ """)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ // create runner agent
+ // The runner agent will run the code block from coder's reply.
+        // It runs python code using the dotnet interactive in-process kernel.
+        // It also truncates the output if the output is too long.
+ var runner = new DefaultReplyAgent(
+ name: "runner",
+ defaultReply: "No code available, coder, write code please")
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ var mostRecentCoderMessage = msgs.LastOrDefault(x => x.From == "coder") ?? throw new Exception("No coder message found");
+
+ if (mostRecentCoderMessage.ExtractCodeBlock("```python", "```") is string code)
+ {
+ var result = await kernel.RunSubmitCodeCommandAsync(code, "python");
+ // only keep the first 500 characters
+ if (result.Length > 500)
+ {
+ result = result.Substring(0, 500);
+ }
+ result = $"""
+ # [CODE_BLOCK_EXECUTION_RESULT]
+ {result}
+ """;
+
+ return new TextMessage(Role.Assistant, result, from: agent.Name);
+ }
+ else
+ {
+ return await agent.GenerateReplyAsync(msgs, option, ct);
+ }
+ })
+ .RegisterPrintMessage();
+
+ var adminToCoderTransition = Transition.Create(admin, coderAgent, async (from, to, messages) =>
+ {
+ // the last message should be from admin
+ var lastMessage = messages.Last();
+ if (lastMessage.From != admin.Name)
+ {
+ return false;
+ }
+
+ return true;
+ });
+ var coderToReviewerTransition = Transition.Create(coderAgent, codeReviewAgent);
+ var adminToRunnerTransition = Transition.Create(admin, runner, async (from, to, messages) =>
+ {
+ // the last message should be from admin
+ var lastMessage = messages.Last();
+ if (lastMessage.From != admin.Name)
+ {
+ return false;
+ }
+
+ // the previous messages should contain a message from coder
+ var coderMessage = messages.FirstOrDefault(x => x.From == coderAgent.Name);
+ if (coderMessage is null)
+ {
+ return false;
+ }
+
+ return true;
+ });
+
+ var runnerToAdminTransition = Transition.Create(runner, admin);
+
+ var reviewerToAdminTransition = Transition.Create(codeReviewAgent, admin);
+
+ var adminToUserTransition = Transition.Create(admin, userProxy, async (from, to, messages) =>
+ {
+ // the last message should be from admin
+ var lastMessage = messages.Last();
+ if (lastMessage.From != admin.Name)
+ {
+ return false;
+ }
+
+ return true;
+ });
+
+ var userToAdminTransition = Transition.Create(userProxy, admin);
+
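+        // Assemble the allowed speaker transitions into a workflow graph: admin -> coder -> reviewer -> admin -> runner -> admin <-> user.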
+ var workflow = new Graph(
+ [
+ adminToCoderTransition,
+ coderToReviewerTransition,
+ reviewerToAdminTransition,
+ adminToRunnerTransition,
+ runnerToAdminTransition,
+ adminToUserTransition,
+ userToAdminTransition,
+ ]);
+
+ // create group chat
+ var groupChat = new GroupChat(
+ admin: groupAdmin,
+ members: [admin, coderAgent, runner, codeReviewAgent, userProxy],
+ workflow: workflow);
+
+ // task 1: retrieve the most recent pr from mlnet and save it in result.txt
+ var task = """
+ retrieve the most recent pr from mlnet and save it in result.txt
+ """;
+        var chatHistory = new List<IMessage>
+ {
+ new TextMessage(Role.Assistant, task)
+ {
+ From = userProxy.Name
+ }
+ };
+ await foreach (var message in groupChat.SendAsync(chatHistory, maxRound: 10))
+ {
+ if (message.From == admin.Name && message.GetContent().Contains("```summary"))
+ {
+ // Task complete!
+ break;
+ }
+ }
+
+ // check if the result file is created
+ var result = "result.txt";
+ File.Exists(result).Should().BeTrue();
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
new file mode 100644
index 00000000000..e8dd86474e7
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
@@ -0,0 +1,126 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example05_Dalle_And_GPT4V.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using OpenAI;
+using OpenAI.Images;
+
+public partial class Example05_Dalle_And_GPT4V
+{
+ private readonly OpenAIClient openAIClient;
+
+ public Example05_Dalle_And_GPT4V(OpenAIClient openAIClient)
+ {
+ this.openAIClient = openAIClient;
+ }
+
+    /// <summary>
+    /// Generate image from prompt using DALL-E.
+    /// </summary>
+    /// <param name="prompt">prompt with feedback</param>
+    /// <returns></returns>
+    [Function]
+    public async Task<string> GenerateImage(string prompt)
+ {
+ // TODO
+ // generate image from prompt using DALL-E
+ // and return url.
+ var option = new ImageGenerationOptions
+ {
+ Size = GeneratedImageSize.W1024xH1024,
+ Style = GeneratedImageStyle.Vivid,
+ };
+
+ var imageResponse = await openAIClient.GetImageClient("dall-e-3").GenerateImageAsync(prompt, option);
+ var imageUrl = imageResponse.Value.ImageUri.OriginalString;
+
+ return $@"// ignore this line [IMAGE_GENERATION]
+The image is generated from prompt {prompt}
+
+{imageUrl}";
+ }
+
+ public static async Task RunAsync()
+ {
+ // This example shows how to use DALL-E and GPT-4V to generate image from prompt and feedback.
+ // The DALL-E agent will generate image from prompt.
+ // The GPT-4V agent will provide feedback to DALL-E agent to help it generate better image.
+ // The conversation will be terminated when the image satisfies the condition.
+ // The image will be saved to image.jpg in current directory.
+
+ // get OpenAI Key and create config
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var openAIClient = new OpenAIClient(openAIKey);
+ var instance = new Example05_Dalle_And_GPT4V(openAIClient);
+ var imagePath = Path.Combine("resource", "images", "background.png");
+ if (File.Exists(imagePath))
+ {
+ File.Delete(imagePath);
+ }
+
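+ // wrap the GenerateImage tool in a FunctionCallMiddleware so the dalle agent can invoke it
+ // automatically whenever the model emits a matching tool call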
+ var generateImageFunctionMiddleware = new FunctionCallMiddleware(
+ functions: [instance.GenerateImageFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { nameof(GenerateImage), instance.GenerateImageWrapper },
+ });
+ var dalleAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient("gpt-4o-mini"),
+ name: "dalle",
+ systemMessage: "You are a DALL-E agent that generate image from prompt, when conversation is terminated, return the most recent image url")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(generateImageFunctionMiddleware)
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ if (msgs.Any(msg => msg.GetContent()?.ToLower().Contains("approve") is true))
+ {
+ return new TextMessage(Role.Assistant, $"The image satisfies the condition, conversation is terminated. {GroupChatExtension.TERMINATE}");
+ }
+
+ var msgsWithoutImage = msgs.Where(msg => msg is not ImageMessage).ToList();
+ var reply = await agent.GenerateReplyAsync(msgsWithoutImage, option, ct);
+
+ if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION"))
+ {
+ var imageUrl = content.Split("\n").Last();
+ var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From, mimeType: "image/png");
+
+ Console.WriteLine($"download image from {imageUrl} to {imagePath}");
+ var httpClient = new HttpClient();
+ var imageBytes = await httpClient.GetByteArrayAsync(imageUrl, ct);
+ File.WriteAllBytes(imagePath, imageBytes);
+
+ return imageMessage;
+ }
+ else
+ {
+ return reply;
+ }
+ })
+ .RegisterPrintMessage();
+
+ var gpt4VAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient("gpt-4o-mini"),
+ name: "gpt-4o-mini",
+ systemMessage: @"You are a critism that provide feedback to DALL-E agent.
+Carefully check the image generated by DALL-E agent and provide feedback.
+If the image satisfies the condition, then say [APPROVE].
+Otherwise, provide detailed feedback to DALL-E agent so it can generate better image.
+
+The image should satisfy the following conditions:
+- There should be a cat and a mouse in the image
+- The cat should be chasing after the mouse")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ await gpt4VAgent.InitiateChatAsync(
+ receiver: dalleAgent,
+ message: "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse",
+ maxRound: 10);
+
+ File.Exists(imagePath).Should().BeTrue();
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs
new file mode 100644
index 00000000000..e1349cb32a9
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example06_UserProxyAgent.cs
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+
+namespace AutoGen.BasicSample;
+
+public static class Example06_UserProxyAgent
+{
+ public static async Task RunAsync()
+ {
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+ var assistantAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ // set human input mode to ALWAYS so that user always provide input
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ // start the conversation
+ await userProxyAgent.InitiateChatAsync(
+ receiver: assistantAgent,
+ message: "Hey assistant, please help me to do some tasks.",
+ maxRound: 10);
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
new file mode 100644
index 00000000000..1f1315586a2
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
@@ -0,0 +1,377 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
+
+using System.Text;
+using System.Text.Json;
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using AutoGen.DotnetInteractive;
+using AutoGen.DotnetInteractive.Extension;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Microsoft.DotNet.Interactive;
+using OpenAI.Chat;
+
+public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci
+{
+ #region reviewer_function
+ public struct CodeReviewResult
+ {
+ public bool HasMultipleCodeBlocks { get; set; }
+ public bool IsTopLevelStatement { get; set; }
+ public bool IsDotnetCodeBlock { get; set; }
+ public bool IsPrintResultToConsole { get; set; }
+ }
+
+ /// <summary>
+ /// review code block
+ /// </summary>
+ /// <param name="hasMultipleCodeBlocks">true if there are multiple csharp code blocks</param>
+ /// <param name="isTopLevelStatement">true if the code is in top level statement</param>
+ /// <param name="isDotnetCodeBlock">true if the code block is a csharp code block</param>
+ /// <param name="isPrintResultToConsole">true if the code block prints the result to console</param>
+ [Function]
+ public async Task<string> ReviewCodeBlock(
+ bool hasMultipleCodeBlocks,
+ bool isTopLevelStatement,
+ bool isDotnetCodeBlock,
+ bool isPrintResultToConsole)
+ {
+ var obj = new CodeReviewResult
+ {
+ HasMultipleCodeBlocks = hasMultipleCodeBlocks,
+ IsTopLevelStatement = isTopLevelStatement,
+ IsDotnetCodeBlock = isDotnetCodeBlock,
+ IsPrintResultToConsole = isPrintResultToConsole,
+ };
+
+ return JsonSerializer.Serialize(obj);
+ }
+ #endregion reviewer_function
+
+ #region create_coder
+ public static async Task<IAgent> CreateCoderAgentAsync(ChatClient client)
+ {
+ var coder = new OpenAIChatAgent(
+ chatClient: client,
+ name: "coder",
+ systemMessage: @"You act as dotnet coder, you write dotnet code to resolve task. Once you finish writing code, ask runner to run the code for you.
+
+ Here're some rules to follow on writing dotnet code:
+ - put code between ```csharp and ```
+ - Avoid adding `using` keyword when creating disposable object. e.g `var httpClient = new HttpClient()`
+ - Try to use `var` instead of explicit type.
+ - Try avoid using external library, use .NET Core library instead.
+ - Use top level statement to write code.
+ - Always print out the result to console. Don't write code that doesn't print out anything.
+
+ If you need to install nuget packages, put nuget packages in the following format:
+ ```nuget
+ nuget_package_name
+ ```
+
+ If your code is incorrect, runner will tell you the error message. Fix the error and send the code again.",
+ temperature: 0.4f)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ return coder;
+ }
+ #endregion create_coder
+
+ #region create_runner
+ public static async Task<IAgent> CreateRunnerAgentAsync(Kernel kernel)
+ {
+ var runner = new DefaultReplyAgent(
+ name: "runner",
+ defaultReply: "No code available.")
+ .RegisterMiddleware(async (msgs, option, agent, _) =>
+ {
+ if (msgs.Count() == 0 || msgs.All(msg => msg.From != "coder"))
+ {
+ return new TextMessage(Role.Assistant, "No code available. Coder please write code");
+ }
+ else
+ {
+ var coderMsg = msgs.Last(msg => msg.From == "coder");
+ if (coderMsg.ExtractCodeBlock("```csharp", "```") is string code)
+ {
+ var codeResult = await kernel.RunSubmitCodeCommandAsync(code, "csharp");
+
+ codeResult = $"""
+ [RUNNER_RESULT]
+ {codeResult}
+ """;
+
+ return new TextMessage(Role.Assistant, codeResult)
+ {
+ From = "runner",
+ };
+ }
+ else
+ {
+ return new TextMessage(Role.Assistant, "No code available. Coder please write code");
+ }
+ }
+ })
+ .RegisterPrintMessage();
+
+ return runner;
+ }
+ #endregion create_runner
+
+ #region create_admin
+ public static async Task<IAgent> CreateAdminAsync(ChatClient client)
+ {
+ var admin = new OpenAIChatAgent(
+ chatClient: client,
+ name: "admin",
+ temperature: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ return admin;
+ }
+ #endregion create_admin
+
+ #region create_reviewer
+ public static async Task<IAgent> CreateReviewerAgentAsync(ChatClient chatClient)
+ {
+ var functions = new Example07_Dynamic_GroupChat_Calculate_Fibonacci();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [functions.ReviewCodeBlockFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>()
+ {
+ { nameof(functions.ReviewCodeBlock), functions.ReviewCodeBlockWrapper },
+ });
+ var reviewer = new OpenAIChatAgent(
+ chatClient: chatClient,
+ name: "code_reviewer",
+ systemMessage: @"You review code block from coder")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(functionCallMiddleware)
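+ // post-process the reply: if the model returned a ReviewCodeBlock tool call, turn the structured
+ // result into review comments; otherwise ask the inner agent (up to 3 retries) to convert its
+ // free-form reply into ReviewCodeBlock arguments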
+ .RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
+ {
+ var maxRetry = 3;
+ var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct);
+ while (maxRetry-- > 0)
+ {
+ if (reply.GetToolCalls() is var toolCalls && toolCalls.Count() == 1 && toolCalls[0].FunctionName == nameof(ReviewCodeBlock))
+ {
+ var toolCallResult = reply.GetContent();
+ var reviewResultObj = JsonSerializer.Deserialize<CodeReviewResult>(toolCallResult);
+ var reviews = new List<string>();
+ if (reviewResultObj.HasMultipleCodeBlocks)
+ {
+ var fixCodeBlockPrompt = @"There're multiple code blocks, please combine them into one code block";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviewResultObj.IsDotnetCodeBlock is false)
+ {
+ var fixCodeBlockPrompt = @"The code block is not csharp code block, please write dotnet code only";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviewResultObj.IsTopLevelStatement is false)
+ {
+ var fixCodeBlockPrompt = @"The code is not top level statement, please rewrite your dotnet code using top level statement";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviewResultObj.IsPrintResultToConsole is false)
+ {
+ var fixCodeBlockPrompt = @"The code doesn't print out result to console, please print out result to console";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviews.Count > 0)
+ {
+ var sb = new StringBuilder();
+ sb.AppendLine("There're some comments from code reviewer, please fix these comments");
+ foreach (var review in reviews)
+ {
+ sb.AppendLine($"- {review}");
+ }
+
+ return new TextMessage(Role.Assistant, sb.ToString(), from: "code_reviewer");
+ }
+ else
+ {
+ var msg = new TextMessage(Role.Assistant, "The code looks good, please ask runner to run the code for you.")
+ {
+ From = "code_reviewer",
+ };
+
+ return msg;
+ }
+ }
+ else
+ {
+ var originalContent = reply.GetContent();
+ var prompt = $@"Please convert the content to ReviewCodeBlock function arguments.
+
+ ## Original Content
+ {originalContent}";
+
+ reply = await innerAgent.SendAsync(prompt, msgs, ct);
+ }
+ }
+
+ throw new Exception("Failed to review code block");
+ })
+ .RegisterPrintMessage();
+
+ return reviewer;
+ }
+ #endregion create_reviewer
+
+ public static async Task RunWorkflowAsync()
+ {
+ long the39thFibonacciNumber = 63245986;
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .Build();
+
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+ #region create_workflow
+ var reviewer = await CreateReviewerAgentAsync(gpt4o);
+ var coder = await CreateCoderAgentAsync(gpt4o);
+ var runner = await CreateRunnerAgentAsync(kernel);
+ var admin = await CreateAdminAsync(gpt4o);
+
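+ // define the allowed speaker transitions: admin -> coder -> reviewer; reviewer -> runner (code approved)
+ // or reviewer -> coder (review comments); runner -> coder (runtime error) or runner -> admin (success)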
+ var admin2CoderTransition = Transition.Create(admin, coder);
+ var coder2ReviewerTransition = Transition.Create(coder, reviewer);
+ var reviewer2RunnerTransition = Transition.Create(
+ from: reviewer,
+ to: runner,
+ canTransitionAsync: async (from, to, messages) =>
+ {
+ var lastMessage = messages.Last();
+ if (lastMessage is TextMessage textMessage && textMessage.Content.ToLower().Contains("the code looks good, please ask runner to run the code for you.") is true)
+ {
+ // ask runner to run the code
+ return true;
+ }
+
+ return false;
+ });
+ var reviewer2CoderTransition = Transition.Create(
+ from: reviewer,
+ to: coder,
+ canTransitionAsync: async (from, to, messages) =>
+ {
+ var lastMessage = messages.Last();
+ if (lastMessage is TextMessage textMessage && textMessage.Content.ToLower().Contains("there're some comments from code reviewer, please fix these comments") is true)
+ {
+ // ask coder to fix the code based on reviewer's comments
+ return true;
+ }
+
+ return false;
+ });
+
+ var runner2CoderTransition = Transition.Create(
+ from: runner,
+ to: coder,
+ canTransitionAsync: async (from, to, messages) =>
+ {
+ var lastMessage = messages.Last();
+ if (lastMessage is TextMessage textMessage && textMessage.Content.ToLower().Contains("error") is true)
+ {
+ // ask coder to fix the error
+ return true;
+ }
+
+ return false;
+ });
+ var runner2AdminTransition = Transition.Create(runner, admin);
+
+ var workflow = new Graph(
+ [
+ admin2CoderTransition,
+ coder2ReviewerTransition,
+ reviewer2RunnerTransition,
+ reviewer2CoderTransition,
+ runner2CoderTransition,
+ runner2AdminTransition,
+ ]);
+ #endregion create_workflow
+
+ #region create_group_chat_with_workflow
+ var groupChat = new GroupChat(
+ admin: admin,
+ workflow: workflow,
+ members:
+ [
+ admin,
+ coder,
+ runner,
+ reviewer,
+ ]);
+ #endregion create_group_chat_with_workflow
+ admin.SendIntroduction("Welcome to my group, work together to resolve my task", groupChat);
+ coder.SendIntroduction("I will write dotnet code to resolve task", groupChat);
+ reviewer.SendIntroduction("I will review dotnet code", groupChat);
+ runner.SendIntroduction("I will run dotnet code once the review is done", groupChat);
+ var task = "What's the 39th of fibonacci number?";
+
+ var taskMessage = new TextMessage(Role.User, task, from: admin.Name);
+ await foreach (var message in groupChat.SendAsync([taskMessage], maxRound: 10))
+ {
+ // terminate the chat if the message is from runner and the code ran successfully
+ if (message.From == "runner" && message.GetContent().Contains(the39thFibonacciNumber.ToString()))
+ {
+ Console.WriteLine($"The 39th of fibonacci number is {the39thFibonacciNumber}");
+ break;
+ }
+ }
+ }
+
+ public static async Task RunAsync()
+ {
+ long the39thFibonacciNumber = 63245986;
+ var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
+ if (!Directory.Exists(workDir))
+ {
+ Directory.CreateDirectory(workDir);
+ }
+
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+ var kernel = DotnetInteractiveKernelBuilder
+ .CreateDefaultInProcessKernelBuilder()
+ .Build();
+ #region create_group_chat
+ var reviewer = await CreateReviewerAgentAsync(gpt4o);
+ var coder = await CreateCoderAgentAsync(gpt4o);
+ var runner = await CreateRunnerAgentAsync(kernel);
+ var admin = await CreateAdminAsync(gpt4o);
+ var groupChat = new GroupChat(
+ admin: admin,
+ members:
+ [
+ coder,
+ runner,
+ reviewer,
+ ]);
+
+ coder.SendIntroduction("I will write dotnet code to resolve task", groupChat);
+ reviewer.SendIntroduction("I will review dotnet code", groupChat);
+ runner.SendIntroduction("I will run dotnet code once the review is done", groupChat);
+
+ var task = "What's the 39th of fibonacci number?";
+ var taskMessage = new TextMessage(Role.User, task);
+ await foreach (var message in groupChat.SendAsync([taskMessage], maxRound: 10))
+ {
+ // terminate the chat if the message is from runner and the code ran successfully
+ if (message.From == "runner" && message.GetContent().Contains(the39thFibonacciNumber.ToString()))
+ {
+ Console.WriteLine($"The 39th of fibonacci number is {the39thFibonacciNumber}");
+ break;
+ }
+ }
+ #endregion create_group_chat
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs b/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs
new file mode 100644
index 00000000000..e58454fdb5f
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example08_LMStudio.cs
+
+#region lmstudio_using_statements
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
+#endregion lmstudio_using_statements
+
+namespace AutoGen.BasicSample;
+
+public class Example08_LMStudio
+{
+ public static async Task RunAsync()
+ {
+ #region lmstudio_example_1
+ var endpoint = "http://localhost:1234";
+ var openaiClient = new OpenAIClient("api-key", new OpenAIClientOptions
+ {
+ Endpoint = new Uri(endpoint),
+ });
+
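+ // LM Studio serves an OpenAI-compatible endpoint for whichever model is loaded locally,
+ // so the API key and the model name passed to GetChatClient below are placeholders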
+ var lmAgent = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(""),
+ name: "assistant")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ await lmAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+
+ // output from assistant (the output below is generated using llama-2-chat-7b, the output may vary depending on the model used)
+ //
+ // Of course! To calculate the 100th number in the Fibonacci sequence using C#, you can use the following code:```
+ // using System;
+ // class FibonacciSequence {
+ // static int Fibonacci(int n) {
+ // if (n <= 1) {
+ // return 1;
+ // } else {
+ // return Fibonacci(n - 1) + Fibonacci(n - 2);
+ // }
+ // }
+ // static void Main() {
+ // Console.WriteLine("The 100th number in the Fibonacci sequence is: " + Fibonacci(100));
+ // }
+ // }
+ // ```
+ // In this code, we define a function `Fibonacci` that takes an integer `n` as input and returns the `n`-th number in the Fibonacci sequence. The function uses a recursive approach to calculate the value of the sequence.
+ // The `Main` method simply calls the `Fibonacci` function with the argument `100`, and prints the result to the console.
+ // Note that this code will only work for positive integers `n`. If you want to calculate the Fibonacci sequence for other types of numbers, such as real or complex numbers, you will need to modify the code accordingly.
+ #endregion lmstudio_example_1
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs
new file mode 100644
index 00000000000..da7e54852f3
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs
@@ -0,0 +1,80 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example10_SemanticKernel.cs
+
+using System.ComponentModel;
+using AutoGen.Core;
+using AutoGen.SemanticKernel.Extension;
+using FluentAssertions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+namespace AutoGen.BasicSample;
+
+public class LightPlugin
+{
+ public bool IsOn { get; set; } = false;
+
+ [KernelFunction]
+ [Description("Gets the state of the light.")]
+ public string GetState() => this.IsOn ? "on" : "off";
+
+ [KernelFunction]
+ [Description("Changes the state of the light.'")]
+ public string ChangeState(bool newState)
+ {
+ this.IsOn = newState;
+ var state = this.GetState();
+
+ // Print the state to the console
+ Console.ForegroundColor = ConsoleColor.DarkBlue;
+ Console.WriteLine($"[Light is now {state}]");
+ Console.ResetColor();
+
+ return state;
+ }
+}
+
+public class Example10_SemanticKernel
+{
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-4o-mini";
+ var builder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ var kernel = builder.Build();
+ var settings = new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions,
+ };
+
+ kernel.Plugins.AddFromObject(new LightPlugin());
+ var skAgent = kernel
+ .ToSemanticKernelAgent(name: "assistant", systemMessage: "You control the light", settings);
+
+ // Send a message to the skAgent, the skAgent supports the following message types:
+ // - IMessage<ChatMessageContent>
+ // - (streaming) IMessage<StreamingChatMessageContent>
+ // You can create an IMessage<ChatMessageContent> using MessageEnvelope.Create
+ var chatMessageContent = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.User, "Toggle the light"));
+ var reply = await skAgent.SendAsync(chatMessageContent);
+ reply.Should().BeOfType<MessageEnvelope<ChatMessageContent>>();
+ Console.WriteLine((reply as IMessage<ChatMessageContent>).Content.Items[0].As<TextContent>().Text);
+
+ var skAgentWithMiddleware = skAgent
+ .RegisterMessageConnector() // Register the message connector to support more AutoGen built-in message types
+ .RegisterPrintMessage();
+
+ // Now the skAgentWithMiddleware supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage
+ // It also registers a print-message hook to print the message in a human readable format to the console
+ await skAgent.SendAsync(chatMessageContent);
+ await skAgentWithMiddleware.SendAsync(new TextMessage(Role.User, "Toggle the light"));
+
+ // The more message types an agent supports, the more flexible it is to use in different scenarios
+ // For example, since the TextMessage is supported, the skAgentWithMiddleware can be used with user proxy.
+ var userProxy = new UserProxyAgent("user");
+
+ await skAgentWithMiddleware.InitiateChatAsync(userProxy, "how can I help you today");
+ }
+
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs b/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs
new file mode 100644
index 00000000000..32aaa8c187b
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs
@@ -0,0 +1,88 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example11_Sequential_GroupChat_Example.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using AutoGen.SemanticKernel;
+using AutoGen.SemanticKernel.Extension;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Plugins.Web;
+using Microsoft.SemanticKernel.Plugins.Web.Bing;
+#endregion using_statement
+
+namespace AutoGen.BasicSample;
+
+public partial class Sequential_GroupChat_Example
+{
+ public static async Task<IAgent> CreateBingSearchAgentAsync()
+ {
+ #region CreateBingSearchAgent
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var apiKey = config.ApiKey;
+ var kernelBuilder = Kernel.CreateBuilder()
+ .AddAzureOpenAIChatCompletion(config.DeploymentName, config.Endpoint, apiKey);
+ var bingApiKey = Environment.GetEnvironmentVariable("BING_API_KEY") ?? throw new Exception("BING_API_KEY environment variable is not set");
+ var bingSearch = new BingConnector(bingApiKey);
+ var webSearchPlugin = new WebSearchEnginePlugin(bingSearch);
+ kernelBuilder.Plugins.AddFromObject(webSearchPlugin);
+
+ var kernel = kernelBuilder.Build();
+ var kernelAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "bing-search",
+ systemMessage: """
+ You search results from Bing and return them as-is.
+ You put the original search result between ```bing and ```
+
+ e.g.
+ ```bing
+ xxx
+ ```
+ """)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage(); // pretty print the message
+
+ return kernelAgent;
+ #endregion CreateBingSearchAgent
+ }
+
+ public static async Task<IAgent> CreateSummarizerAgentAsync()
+ {
+ #region CreateSummarizerAgent
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var openAIClientAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "summarizer",
+ systemMessage: "You summarize search result from bing in a short and concise manner");
+
+ return openAIClientAgent
+ .RegisterMessageConnector()
+ .RegisterPrintMessage(); // pretty print the message
+ #endregion CreateSummarizerAgent
+ }
+
+ public static async Task RunAsync()
+ {
+ #region Sequential_GroupChat_Example
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ var bingSearchAgent = await CreateBingSearchAgentAsync();
+ var summarizerAgent = await CreateSummarizerAgentAsync();
+
+ var groupChat = new RoundRobinGroupChat(
+ agents: [userProxyAgent, bingSearchAgent, summarizerAgent]);
+
+ var groupChatAgent = new GroupChatManager(groupChat);
+
+ var history = await userProxyAgent.InitiateChatAsync(
+ receiver: groupChatAgent,
+ message: "How to deploy an openai resource on azure",
+ maxRound: 10);
+ #endregion Sequential_GroupChat_Example
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs b/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs
new file mode 100644
index 00000000000..69c2121cd80
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs
@@ -0,0 +1,172 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example12_TwoAgent_Fill_Application.cs
+
+using System.Text;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+
+namespace AutoGen.BasicSample;
+
+public partial class TwoAgent_Fill_Application
+{
+ private string? name = null;
+ private string? email = null;
+ private string? phone = null;
+ private string? address = null;
+ private bool? receiveUpdates = null;
+
+ [Function]
+ public async Task<string> SaveProgress(
+ string name,
+ string email,
+ string phone,
+ string address,
+ bool? receiveUpdates)
+ {
+ this.name = !string.IsNullOrEmpty(name) ? name : this.name;
+ this.email = !string.IsNullOrEmpty(email) ? email : this.email;
+ this.phone = !string.IsNullOrEmpty(phone) ? phone : this.phone;
+ this.address = !string.IsNullOrEmpty(address) ? address : this.address;
+ this.receiveUpdates = receiveUpdates ?? this.receiveUpdates;
+
+ var missingInformationStringBuilder = new StringBuilder();
+ if (string.IsNullOrEmpty(this.name))
+ {
+ missingInformationStringBuilder.AppendLine("Name is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.email))
+ {
+ missingInformationStringBuilder.AppendLine("Email is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.phone))
+ {
+ missingInformationStringBuilder.AppendLine("Phone is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.address))
+ {
+ missingInformationStringBuilder.AppendLine("Address is missing.");
+ }
+
+ if (this.receiveUpdates == null)
+ {
+ missingInformationStringBuilder.AppendLine("ReceiveUpdates is missing.");
+ }
+
+ if (missingInformationStringBuilder.Length > 0)
+ {
+ return missingInformationStringBuilder.ToString();
+ }
+ else
+ {
+ return "Application information is saved to database.";
+ }
+ }
+
+ public static async Task<IAgent> CreateSaveProgressAgent()
+ {
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var instance = new TwoAgent_Fill_Application();
+ var functionCallConnector = new FunctionCallMiddleware(
+ functions: [instance.SaveProgressFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { instance.SaveProgressFunctionContract.Name, instance.SaveProgressWrapper },
+ });
+
+ var chatAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "application",
+ systemMessage: """You are a helpful application form assistant who saves progress while user fills application.""")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(functionCallConnector)
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ var lastUserMessage = msgs.Last() ?? throw new Exception("No user message found.");
+ var prompt = $"""
+ Save progress according to the most recent information provided by user.
+
+ ```user
+ {lastUserMessage.GetContent()}
+ ```
+ """;
+
+ return await agent.GenerateReplyAsync([lastUserMessage], option, ct);
+
+ });
+
+ return chatAgent;
+ }
+
+ public static async Task<IAgent> CreateAssistantAgent()
+ {
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var chatAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "assistant",
+ systemMessage: """You create polite prompt to ask user provide missing information""")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ return chatAgent;
+ }
+
+ public static async Task<IAgent> CreateUserAgent()
+ {
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var chatAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "user",
+ systemMessage: """
+ You are a user who is filling an application form. Simply provide the information as requested and answer the questions, don't do anything else.
+
+ here's some personal information about you:
+ - name: John Doe
+ - email: 1234567@gmail.com
+ - phone: 123-456-7890
+ - address: 1234 Main St, Redmond, WA 98052
+ - want to receive update? true
+ """)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ return chatAgent;
+ }
+
+ public static async Task RunAsync()
+ {
+ var applicationAgent = await CreateSaveProgressAgent();
+ var assistantAgent = await CreateAssistantAgent();
+ var userAgent = await CreateUserAgent();
+
+ var userToApplicationTransition = Transition.Create(userAgent, applicationAgent);
+ var applicationToAssistantTransition = Transition.Create(applicationAgent, assistantAgent);
+ var assistantToUserTransition = Transition.Create(assistantAgent, userAgent);
+
+ var workflow = new Graph(
+ [
+ userToApplicationTransition,
+ applicationToAssistantTransition,
+ assistantToUserTransition,
+ ]);
+
+ var groupChat = new GroupChat(
+ members: [userAgent, applicationAgent, assistantAgent],
+ workflow: workflow);
+
+ var groupChatManager = new GroupChatManager(groupChat);
+ var initialMessage = await assistantAgent.SendAsync("Generate a greeting meesage for user and start the conversation by asking what's their name.");
+
+ var chatHistory = new List<IMessage> { initialMessage };
+ await foreach (var msg in userAgent.SendAsync(groupChatManager, chatHistory, maxRound: 30))
+ {
+ if (msg.GetContent().ToLower().Contains("application information is saved to database.") is true)
+ {
+ break;
+ }
+ }
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
new file mode 100644
index 00000000000..596ab08d02a
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
@@ -0,0 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example13_OpenAIAgent_JsonMode.cs
+
+// this example has been moved to https://github.com/microsoft/autogen/blob/main/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
+
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example14_MistralClientAgent_TokenCount.cs b/dotnet/sample/AutoGen.BasicSamples/Example14_MistralClientAgent_TokenCount.cs
new file mode 100644
index 00000000000..4c8794de961
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example14_MistralClientAgent_TokenCount.cs
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example14_MistralClientAgent_TokenCount.cs
+
+#region using_statements
+using AutoGen.Core;
+using AutoGen.Mistral;
+#endregion using_statements
+using FluentAssertions;
+
+namespace AutoGen.BasicSample;
+
+public class Example14_MistralClientAgent_TokenCount
+{
+ #region token_counter_middleware
+ public class MistralAITokenCounterMiddleware : IMiddleware
+ {
+ private readonly List<ChatCompletionResponse> responses = new List<ChatCompletionResponse>();
+ public string? Name => nameof(MistralAITokenCounterMiddleware);
+
+ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
+ {
+ var reply = await agent.GenerateReplyAsync(context.Messages, context.Options, cancellationToken);
+
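+ // cache the raw Mistral response so completion-token usage can be summed up later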
+ if (reply is IMessage<ChatCompletionResponse> message)
+ {
+ responses.Add(message.Content);
+ }
+
+ return reply;
+ }
+
+ public int GetCompletionTokenCount()
+ {
+ return responses.Sum(r => r.Usage.CompletionTokens);
+ }
+ }
+ #endregion token_counter_middleware
+
+ public static async Task RunAsync()
+ {
+ #region create_mistral_client_agent
+ var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new Exception("Missing MISTRAL_API_KEY environment variable.");
+ var mistralClient = new MistralClient(apiKey);
+ var agent = new MistralClientAgent(
+ client: mistralClient,
+ name: "assistant",
+ model: MistralAIModelID.OPEN_MISTRAL_7B);
+ #endregion create_mistral_client_agent
+
+ #region register_middleware
+ var tokenCounterMiddleware = new MistralAITokenCounterMiddleware();
+ var mistralMessageConnector = new MistralChatMessageConnector();
+ var agentWithTokenCounter = agent
+ .RegisterMiddleware(tokenCounterMiddleware)
+ .RegisterMiddleware(mistralMessageConnector)
+ .RegisterPrintMessage();
+ #endregion register_middleware
+
+ #region chat_with_agent
+ await agentWithTokenCounter.SendAsync("write a long, tedious story");
+ Console.WriteLine($"Completion token count: {tokenCounterMiddleware.GetCompletionTokenCount()}");
+ tokenCounterMiddleware.GetCompletionTokenCount().Should().BeGreaterThan(0);
+ #endregion chat_with_agent
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
new file mode 100644
index 00000000000..4a4b10ae3d7
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
@@ -0,0 +1,66 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example15_GPT4V_BinaryDataImageMessage.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+
+namespace AutoGen.BasicSample;
+
+///
+/// This example shows usage of ImageMessage. The image is loaded as BinaryData and sent to GPT-4V
+///
+///
+/// Add additional images to the ImageResources to load and send more images to GPT-4V
+///
+public static class Example15_GPT4V_BinaryDataImageMessage
+{
+ private static readonly string ImageResourcePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "resource", "images");
+
+ private static Dictionary<string, string> _mediaTypeMappings = new()
+ {
+ { ".png", "image/png" },
+ { ".jpeg", "image/jpeg" },
+ { ".jpg", "image/jpeg" },
+ { ".gif", "image/gif" },
+ { ".webp", "image/webp" }
+ };
+
+ public static async Task RunAsync()
+ {
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+
+ var visionAgent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "gpt",
+ systemMessage: "You are a helpful AI assistant",
+ temperature: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ List<IMessage> messages =
+ [new TextMessage(Role.User, "What is this image?", from: "user")];
+ AddMessagesFromResource(ImageResourcePath, messages);
+
+ var multiModalMessage = new MultiModalMessage(Role.User, messages, from: "user");
+ var response = await visionAgent.SendAsync(multiModalMessage);
+ }
+
+ private static void AddMessagesFromResource(string imageResourcePath, List<IMessage> messages)
+ {
+ foreach (string file in Directory.GetFiles(imageResourcePath))
+ {
+ if (!_mediaTypeMappings.TryGetValue(Path.GetExtension(file).ToLowerInvariant(), out var mediaType))
+ {
+ continue;
+ }
+
+ using var fs = new FileStream(file, FileMode.Open, FileAccess.Read);
+ var ms = new MemoryStream();
+ fs.CopyTo(ms);
+ ms.Seek(0, SeekOrigin.Begin);
+ var imageData = BinaryData.FromStream(ms, mediaType);
+ messages.Add(new ImageMessage(Role.Assistant, imageData, from: "user"));
+ }
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs b/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
new file mode 100644
index 00000000000..969f7dc21c7
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
@@ -0,0 +1,4 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
+
+// this example has been moved to https://github.com/microsoft/autogen/blob/main/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs
new file mode 100644
index 00000000000..170736bf22e
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs
@@ -0,0 +1,184 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example17_ReActAgent.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen.BasicSample;
+
+public class OpenAIReActAgent : IAgent
+{
+ private readonly ChatClient _client;
+ private readonly FunctionContract[] tools;
+ private readonly Dictionary<string, Func<string, Task<string>>> toolExecutors = new();
+ private readonly IAgent reasoner;
+ private readonly IAgent actor;
+ private readonly IAgent helper;
+ private readonly int maxSteps = 10;
+
+ private const string ReActPrompt = @"Answer the following questions as best you can.
+You can invoke the following tools:
+{tools}
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Tool: the tool to invoke
+Tool Input: the input to the tool
+Observation: the invoke result of the tool
+... (this process can repeat multiple times)
+
+Once you have the final answer, provide the final answer in the following format:
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+Question: {input}";
+
+ public OpenAIReActAgent(ChatClient client, string name, FunctionContract[] tools, Dictionary<string, Func<string, Task<string>>> toolExecutors)
+ {
+ _client = client;
+ this.Name = name;
+ this.tools = tools;
+ this.toolExecutors = toolExecutors;
+ this.reasoner = CreateReasoner();
+ this.actor = CreateActor();
+ this.helper = new OpenAIChatAgent(client, "helper")
+ .RegisterMessageConnector();
+ }
+
+ public string Name { get; }
+
+ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ // step 1: extract the input question
+ var userQuestion = await helper.SendAsync("Extract the question from chat history", chatHistory: messages);
+ if (userQuestion.GetContent() is not string question)
+ {
+ return new TextMessage(Role.Assistant, "I couldn't find a question in the chat history. Please ask a question.", from: Name);
+ }
+ var reactPrompt = CreateReActPrompt(question);
+ var promptMessage = new TextMessage(Role.User, reactPrompt);
+ var chatHistory = new List<IMessage>() { promptMessage };
+
+ // step 2: ReAct
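+ // alternate between reasoning (pick a tool and its input) and acting (invoke the tool)
+ // until the reasoner declares the final answer or maxSteps is reached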
+ for (int i = 0; i != this.maxSteps; i++)
+ {
+ // reasoning
+ var reasoning = await reasoner.SendAsync(chatHistory: chatHistory);
+ if (reasoning.GetContent() is not string reasoningContent)
+ {
+ return new TextMessage(Role.Assistant, "I couldn't find a reasoning in the chat history. Please provide a reasoning.", from: Name);
+ }
+ if (reasoningContent.Contains("I now know the final answer"))
+ {
+ return new TextMessage(Role.Assistant, reasoningContent, from: Name);
+ }
+
+ chatHistory.Add(reasoning);
+
+ // action
+ var action = await actor.SendAsync(reasoning);
+ chatHistory.Add(action);
+ }
+
+ // fail to find the final answer
+ // return the summary of the chat history
+ var summary = await helper.SendAsync("Summarize the chat history and find out what's missing", chatHistory: chatHistory);
+ summary.From = Name;
+
+ return summary;
+ }
+
+ private string CreateReActPrompt(string input)
+ {
+ var toolPrompt = tools.Select(t => $"{t.Name}: {t.Description}").Aggregate((a, b) => $"{a}\n{b}");
+ var prompt = ReActPrompt.Replace("{tools}", toolPrompt);
+ prompt = prompt.Replace("{input}", input);
+ return prompt;
+ }
+
+ private IAgent CreateReasoner()
+ {
+ return new OpenAIChatAgent(
+ chatClient: _client,
+ name: "reasoner")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ }
+
+ private IAgent CreateActor()
+ {
+ var functionCallMiddleware = new FunctionCallMiddleware(tools, toolExecutors);
+ return new OpenAIChatAgent(
+ chatClient: _client,
+ name: "actor")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(functionCallMiddleware)
+ .RegisterPrintMessage();
+ }
+}
+
+public partial class Tools
+{
+ /// <summary>
+ /// Get weather report for a specific place on a specific date
+ /// </summary>
+ /// <param name="city">city</param>
+ /// <param name="date">date as DD/MM/YYYY</param>
+ [Function]
+ public async Task<string> WeatherReport(string city, string date)
+ {
+ return $"Weather report for {city} on {date} is sunny";
+ }
+
+ /// <summary>
+ /// Get current localization
+ /// </summary>
+ [Function]
+ public async Task<string> GetLocalization(string dummy)
+ {
+ return $"Paris";
+ }
+
+ /// <summary>
+ /// Get current date as DD/MM/YYYY
+ /// </summary>
+ [Function]
+ public async Task<string> GetDateToday(string dummy)
+ {
+ return $"27/05/2024";
+ }
+}
+
+public class Example17_ReActAgent
+{
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelName = "gpt-4-turbo";
+ var tools = new Tools();
+ var openAIClient = new OpenAIClient(openAIKey);
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var reactAgent = new OpenAIReActAgent(
+ client: openAIClient.GetChatClient(modelName),
+ name: "react-agent",
+ tools: [tools.GetLocalizationFunctionContract, tools.GetDateTodayFunctionContract, tools.WeatherReportFunctionContract],
+ toolExecutors: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { tools.GetLocalizationFunctionContract.Name, tools.GetLocalizationWrapper },
+ { tools.GetDateTodayFunctionContract.Name, tools.GetDateTodayWrapper },
+ { tools.WeatherReportFunctionContract.Name, tools.WeatherReportWrapper },
+ }
+ )
+ .RegisterPrintMessage();
+
+ var message = new TextMessage(Role.User, "What is the weather here", from: "user");
+
+ var response = await reactAgent.SendAsync(message);
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Agent_Middleware.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Agent_Middleware.cs
new file mode 100644
index 00000000000..cf97af13467
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Agent_Middleware.cs
@@ -0,0 +1,73 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Agent_Middleware.cs
+
+#region Using
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+#endregion Using
+using FluentAssertions;
+using OpenAI.Chat;
+
+namespace AutoGen.BasicSample;
+
+public class Agent_Middleware
+{
+ public static async Task RunTokenCountAsync()
+ {
+ #region Create_Agent
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var openaiMessageConnector = new OpenAIChatRequestMessageConnector();
+ var totalTokenCount = 0;
+ var agent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMiddleware(async (messages, option, innerAgent, ct) =>
+ {
+ var reply = await innerAgent.GenerateReplyAsync(messages, option, ct);
+ if (reply is MessageEnvelope<ChatCompletion> chatCompletions)
+ {
+ var tokenCount = chatCompletions.Content.Usage.TotalTokens;
+ totalTokenCount += tokenCount;
+ }
+ return reply;
+ })
+ .RegisterMiddleware(openaiMessageConnector);
+ #endregion Create_Agent
+
+ #region Chat_With_Agent
+ var reply = await agent.SendAsync("Tell me a joke");
+ Console.WriteLine($"Total token count: {totalTokenCount}");
+ #endregion Chat_With_Agent
+
+ #region verify_reply
+ reply.Should().BeOfType<TextMessage>();
+ totalTokenCount.Should().BeGreaterThan(0);
+ #endregion verify_reply
+ }
+
+ public static async Task RunRagTaskAsync()
+ {
+ #region Create_Agent
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var agent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
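+ // inject a system message carrying today's date before every call, a minimal form of
+ // augmenting the prompt with external context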
+ .RegisterMiddleware(async (messages, option, innerAgent, ct) =>
+ {
+ var today = DateTime.UtcNow;
+ var todayMessage = new TextMessage(Role.System, $"Today is {today:yyyy-MM-dd}");
+ messages = messages.Concat([todayMessage]);
+ return await innerAgent.GenerateReplyAsync(messages, option, ct);
+ })
+ .RegisterPrintMessage();
+ #endregion Create_Agent
+
+ #region Chat_With_Agent
+ var reply = await agent.SendAsync("what's the date today");
+ #endregion Chat_With_Agent
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Chat_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Chat_With_Agent.cs
new file mode 100644
index 00000000000..b2cc228496d
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Chat_With_Agent.cs
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Chat_With_Agent.cs
+
+#region Using
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+#endregion Using
+
+using FluentAssertions;
+
+namespace AutoGen.BasicSample;
+
+public class Chat_With_Agent
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Agent
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var agent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector(); // convert OpenAI message to AutoGen message
+ #endregion Create_Agent
+
+ #region Chat_With_Agent
+ var reply = await agent.SendAsync("Tell me a joke");
+ reply.Should().BeOfType<TextMessage>();
+ if (reply is TextMessage textMessage)
+ {
+ Console.WriteLine(textMessage.Content);
+ }
+ #endregion Chat_With_Agent
+
+ #region Chat_With_History
+ reply = await agent.SendAsync("summarize the conversation", chatHistory: [reply]);
+ #endregion Chat_With_History
+
+ #region Streaming_Chat
+ var question = new TextMessage(Role.User, "Tell me a long joke");
+ await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([question]))
+ {
+ if (streamingReply is TextMessageUpdate textMessageUpdate)
+ {
+ Console.WriteLine(textMessageUpdate.Content);
+ }
+ }
+ #endregion Streaming_Chat
+
+ #region verify_reply
+ reply.Should().BeOfType<TextMessage>();
+ #endregion verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
new file mode 100644
index 00000000000..dadc295e308
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Dynamic_Group_Chat.cs
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Dynamic_Group_Chat.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using AutoGen.SemanticKernel;
+using AutoGen.SemanticKernel.Extension;
+using Microsoft.SemanticKernel;
+using OpenAI;
+
+namespace AutoGen.BasicSample;
+
+public class Dynamic_Group_Chat
+{
+ public static async Task RunAsync()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+
+ #region Create_Coder
+ var openaiClient = new OpenAIClient(apiKey);
+ var coder = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(model),
+ name: "coder",
+ systemMessage: "You are a C# coder, when writing csharp code, please put the code between ```csharp and ```")
+ .RegisterMessageConnector() // convert OpenAI message to AutoGen message
+ .RegisterPrintMessage(); // print the message content
+ #endregion Create_Coder
+
+ #region Create_Commenter
+ var kernel = Kernel
+ .CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: model, apiKey: apiKey)
+ .Build();
+ var commenter = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "commenter",
+ systemMessage: "You write inline comments for the code snippet and add unit tests if necessary")
+ .RegisterMessageConnector() // register message connector so it supports AutoGen built-in message types like TextMessage.
+ .RegisterPrintMessage(); // pretty print the message to the console
+ #endregion Create_Commenter
+
+ #region Create_UserProxy
+ var userProxy = new DefaultReplyAgent("user", defaultReply: "END")
+ .RegisterPrintMessage(); // print the message content
+ #endregion Create_UserProxy
+
+ #region Create_Group
+ var admin = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(model),
+ name: "admin")
+ .RegisterMessageConnector(); // convert OpenAI message to AutoGen message
+
+ var group = new GroupChat(
+ members: [coder, commenter, userProxy],
+ admin: admin);
+ #endregion Create_Group
+
+ #region Chat_With_Group
+ var workflowInstruction = new TextMessage(
+ Role.User,
+ """
+ Here is the workflow of this group chat:
+ User{Ask a question} -> Coder{Write code}
+ Coder{Write code} -> Commenter{Add comments to the code}
+ Commenter{Add comments to the code} -> User{END}
+ """);
+
+ var question = new TextMessage(Role.User, "How to calculate the 100th Fibonacci number?");
+ var chatHistory = new List<IMessage> { workflowInstruction, question };
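+ // drive the chat one round at a time; the admin agent picks the next speaker based on the
+ // workflow instruction, and the loop ends once the user proxy (whose default reply is END) speaks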
+ while (true)
+ {
+ var replies = await group.CallAsync(chatHistory, maxRound: 1);
+ var lastReply = replies.Last();
+ chatHistory.Add(lastReply);
+
+ if (lastReply.From == userProxy.Name)
+ {
+ break;
+ }
+ }
+ #endregion Chat_With_Group
+
+ #region Summarize_Chat_History
+ var summary = await coder.SendAsync("summarize the conversation", chatHistory: chatHistory);
+ #endregion Summarize_Chat_History
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/FSM_Group_Chat.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/FSM_Group_Chat.cs
new file mode 100644
index 00000000000..093d0c77ce6
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/FSM_Group_Chat.cs
@@ -0,0 +1,189 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// FSM_Group_Chat.cs
+
+using System.Text;
+#region Using
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
+using OpenAI.Chat;
+#endregion Using
+
+namespace AutoGen.BasicSample;
+
+#region FillFormTool
+public partial class FillFormTool
+{
+ private string? name = null;
+ private string? email = null;
+ private string? phone = null;
+ private string? address = null;
+ private bool? receiveUpdates = null;
+
+ [Function]
+ public async Task<string> SaveProgress(
+ string name,
+ string email,
+ string phone,
+ string address,
+ bool? receiveUpdates)
+ {
+ this.name = !string.IsNullOrEmpty(name) ? name : this.name;
+ this.email = !string.IsNullOrEmpty(email) ? email : this.email;
+ this.phone = !string.IsNullOrEmpty(phone) ? phone : this.phone;
+ this.address = !string.IsNullOrEmpty(address) ? address : this.address;
+ this.receiveUpdates = receiveUpdates ?? this.receiveUpdates;
+
+ var missingInformationStringBuilder = new StringBuilder();
+ if (string.IsNullOrEmpty(this.name))
+ {
+ missingInformationStringBuilder.AppendLine("Name is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.email))
+ {
+ missingInformationStringBuilder.AppendLine("Email is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.phone))
+ {
+ missingInformationStringBuilder.AppendLine("Phone is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.address))
+ {
+ missingInformationStringBuilder.AppendLine("Address is missing.");
+ }
+
+ if (this.receiveUpdates == null)
+ {
+ missingInformationStringBuilder.AppendLine("ReceiveUpdates is missing.");
+ }
+
+ if (missingInformationStringBuilder.Length > 0)
+ {
+ return missingInformationStringBuilder.ToString();
+ }
+ else
+ {
+ return "Application information is saved to database.";
+ }
+ }
+}
+#endregion FillFormTool
+
+public class FSM_Group_Chat
+{
+ public static async Task<IAgent> CreateSaveProgressAgent(ChatClient client)
+ {
+ #region Create_Save_Progress_Agent
+ var tool = new FillFormTool();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [tool.SaveProgressFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { tool.SaveProgressFunctionContract.Name!, tool.SaveProgressWrapper },
+ });
+
+ var chatAgent = new OpenAIChatAgent(
+ chatClient: client,
+ name: "application",
+ systemMessage: """You are a helpful application form assistant who saves progress while user fills application.""")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(functionCallMiddleware)
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ var lastUserMessage = msgs.Last() ?? throw new Exception("No user message found.");
+ var prompt = $"""
+ Save progress according to the most recent information provided by user.
+
+ ```user
+ {lastUserMessage.GetContent()}
+ ```
+ """;
+
+ return await agent.GenerateReplyAsync([lastUserMessage], option, ct);
+
+ });
+ #endregion Create_Save_Progress_Agent
+
+ return chatAgent;
+ }
+
+ public static async Task<IAgent> CreateAssistantAgent(ChatClient chatClient)
+ {
+ #region Create_Assistant_Agent
+ var chatAgent = new OpenAIChatAgent(
+ chatClient: chatClient,
+ name: "assistant",
+ systemMessage: """You create polite prompt to ask user provide missing information""")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion Create_Assistant_Agent
+ return chatAgent;
+ }
+
+ public static async Task<IAgent> CreateUserAgent(ChatClient chatClient)
+ {
+ #region Create_User_Agent
+ var chatAgent = new OpenAIChatAgent(
+ chatClient: chatClient,
+ name: "user",
+ systemMessage: """
+ You are a user who is filling an application form. Simply provide the information as requested and answer the questions, don't do anything else.
+
+ here's some personal information about you:
+ - name: John Doe
+ - email: 1234567@gmail.com
+ - phone: 123-456-7890
+ - address: 1234 Main St, Redmond, WA 98052
+ - want to receive update? true
+ """)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion Create_User_Agent
+ return chatAgent;
+ }
+
+ public static async Task RunAsync()
+ {
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+ var openaiClient = new OpenAIClient(apiKey);
+ var chatClient = openaiClient.GetChatClient(model);
+ var applicationAgent = await CreateSaveProgressAgent(chatClient);
+ var assistantAgent = await CreateAssistantAgent(chatClient);
+ var userAgent = await CreateUserAgent(chatClient);
+
+ #region Create_Graph
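+ // constrain the speaking order to a fixed cycle: user -> application (save progress) -> assistant (ask for missing info) -> user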
+ var userToApplicationTransition = Transition.Create(userAgent, applicationAgent);
+ var applicationToAssistantTransition = Transition.Create(applicationAgent, assistantAgent);
+ var assistantToUserTransition = Transition.Create(assistantAgent, userAgent);
+
+ var workflow = new Graph(
+ [
+ userToApplicationTransition,
+ applicationToAssistantTransition,
+ assistantToUserTransition,
+ ]);
+ #endregion Create_Graph
+
+ #region Group_Chat
+ var groupChat = new GroupChat(
+ members: [userAgent, applicationAgent, assistantAgent],
+ workflow: workflow);
+ #endregion Group_Chat
+
+ var initialMessage = await assistantAgent.SendAsync("Generate a greeting meesage for user and start the conversation by asking what's their name.");
+
+ var chatHistory = new List<IMessage> { initialMessage };
+ await foreach (var msg in groupChat.SendAsync(chatHistory, maxRound: 30))
+ {
+ if (msg.GetContent().ToLower().Contains("application information is saved to database.") is true)
+ {
+ break;
+ }
+ }
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
new file mode 100644
index 00000000000..e993b3d51f1
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Image_Chat_With_Agent.cs
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Image_Chat_With_Agent.cs
+
+#region Using
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+#endregion Using
+using FluentAssertions;
+
+namespace AutoGen.BasicSample;
+
+public class Image_Chat_With_Agent
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Agent
+ var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
+ var agent = new OpenAIChatAgent(
+ chatClient: gpt4o,
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector() // convert OpenAI message to AutoGen message
+ .RegisterPrintMessage();
+ #endregion Create_Agent
+
+ #region Prepare_Image_Input
+ var backgroundImagePath = Path.Combine("resource", "images", "background.png");
+ var imageBytes = File.ReadAllBytes(backgroundImagePath);
+ var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(imageBytes, "image/png"));
+ #endregion Prepare_Image_Input
+
+ #region Prepare_Multimodal_Input
+ var textMessage = new TextMessage(Role.User, "what's in the picture");
+ var multimodalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);
+ #endregion Prepare_Multimodal_Input
+
+ #region Chat_With_Agent
+ var reply = await agent.SendAsync("what's in the picture", chatHistory: [imageMessage]);
+ // or use multimodal message to generate reply
+ reply = await agent.SendAsync(multimodalMessage);
+ #endregion Chat_With_Agent
+
+ #region verify_reply
+ reply.Should().BeOfType<TextMessage>();
+ #endregion verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs
new file mode 100644
index 00000000000..d5cb196f94f
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Streaming_Tool_Call.cs
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Streaming_Tool_Call.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using OpenAI;
+
+namespace AutoGen.BasicSample.GettingStart;
+
+internal class Streaming_Tool_Call
+{
+ public static async Task RunAsync()
+ {
+ #region Create_tools
+ var tools = new Tools();
+ #endregion Create_tools
+
+ #region Create_auto_invoke_middleware
+ var autoInvokeMiddleware = new FunctionCallMiddleware(
+ functions: [tools.GetWeatherFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>()
+ {
+ { tools.GetWeatherFunctionContract.Name, tools.GetWeatherWrapper },
+ });
+ #endregion Create_auto_invoke_middleware
+
+ #region Create_Agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+ var openaiClient = new OpenAIClient(apiKey);
+ var agent = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(model),
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterStreamingMiddleware(autoInvokeMiddleware)
+ .RegisterPrintMessage();
+ #endregion Create_Agent
+
+ IMessage? finalReply = null;
+ var question = new TextMessage(Role.User, "What's the weather in Seattle");
+
+ // In a streaming function call the function can only be invoked
+ // once all the chunks have been collected, so only a single
+ // ToolCallAggregateMessage chunk is returned here.
+ await foreach (var message in agent.GenerateStreamingReplyAsync([question]))
+ {
+ finalReply = message;
+ }
+
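+ // Illustrative sketch: inspect the aggregated tool-call reply (assumes the reply is a
+ // ToolCallAggregateMessage whose Message1 carries the original tool calls).
+ if (finalReply is ToolCallAggregateMessage aggregate)
+ {
+ Console.WriteLine($"tool invoked: {aggregate.Message1.ToolCalls.First().FunctionName}");
+ }
+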
+ finalReply?.GetContent().Should().Be("The weather in Seattle is sunny.");
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
new file mode 100644
index 00000000000..21a5df4c2ec
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GettingStart/Use_Tools_With_Agent.cs
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Tools_With_Agent.cs
+
+#region Using
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+#endregion Using
+using FluentAssertions;
+using OpenAI;
+
+namespace AutoGen.BasicSample;
+
+#region Tools
+public partial class Tools
+{
+ /// <summary>
+ /// Get the weather of the city.
+ /// </summary>
+ /// <param name="city"></param>
+ [Function]
+ public async Task<string> GetWeather(string city)
+ {
+ return $"The weather in {city} is sunny.";
+ }
+}
+#endregion Tools
+
+public class Use_Tools_With_Agent
+{
+ public static async Task RunAsync()
+ {
+ #region Create_tools
+ var tools = new Tools();
+ #endregion Create_tools
+
+ #region Create_auto_invoke_middleware
+ var autoInvokeMiddleware = new FunctionCallMiddleware(
+ functions: [tools.GetWeatherFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>()
+ {
+ { tools.GetWeatherFunctionContract.Name!, tools.GetWeatherWrapper },
+ });
+ #endregion Create_auto_invoke_middleware
+
+ #region Create_no_invoke_middleware
+ var noInvokeMiddleware = new FunctionCallMiddleware(
+ functions: [tools.GetWeatherFunctionContract]);
+ #endregion Create_no_invoke_middleware
+
+ #region Create_Agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+ var openaiClient = new OpenAIClient(apiKey);
+ var agent = new OpenAIChatAgent(
+ chatClient: openaiClient.GetChatClient(model),
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector(); // convert OpenAI message to AutoGen message
+ #endregion Create_Agent
+
+ #region Single_Turn_Auto_Invoke
+ var autoInvokeAgent = agent
+ .RegisterMiddleware(autoInvokeMiddleware) // pass function definition to agent.
+ .RegisterPrintMessage(); // print the message content
+ var question = new TextMessage(Role.User, "What is the weather in Seattle?");
+ var reply = await autoInvokeAgent.SendAsync(question);
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
+ #endregion Single_Turn_Auto_Invoke
+
+ #region Single_Turn_No_Invoke
+ var noInvokeAgent = agent
+ .RegisterMiddleware(noInvokeMiddleware) // pass function definition to agent.
+ .RegisterPrintMessage(); // print the message content
+
+ question = new TextMessage(Role.User, "What is the weather in Seattle?");
+ reply = await noInvokeAgent.SendAsync(question);
+ reply.Should().BeOfType<ToolCallMessage>();
+ #endregion Single_Turn_No_Invoke
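+
+ // With the auto-invoke middleware the reply bundles the tool call and its result
+ // (ToolCallAggregateMessage); with the no-invoke middleware the agent only proposes
+ // the tool call (ToolCallMessage) and leaves execution to the caller.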
+
+ #region Multi_Turn_Tool_Call
+ var finalReply = await agent.SendAsync(chatHistory: [question, reply]);
+ #endregion Multi_Turn_Tool_Call
+
+ #region verify_reply
+ finalReply.Should().BeOfType<TextMessage>();
+ #endregion verify_reply
+
+ #region parallel_tool_call
+ question = new TextMessage(Role.User, "What is the weather in Seattle, New York and Vancouver");
+ reply = await agent.SendAsync(question);
+ #endregion parallel_tool_call
+
+ #region verify_parallel_tool_call_reply
+ reply.Should().BeOfType<ToolCallAggregateMessage>();
+ (reply as ToolCallAggregateMessage)!.Message1.ToolCalls.Count().Should().Be(3);
+ #endregion verify_parallel_tool_call_reply
+
+ #region Multi_Turn_Parallel_Tool_Call
+ finalReply = await agent.SendAsync(chatHistory: [question, reply]);
+ finalReply.Should().BeOfType<ToolCallAggregateMessage>();
+ (finalReply as ToolCallAggregateMessage)!.Message1.ToolCalls.Count().Should().Be(3);
+ #endregion Multi_Turn_Parallel_Tool_Call
+ }
+
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs b/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs
new file mode 100644
index 00000000000..87b4ee0ab4c
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs
@@ -0,0 +1,3 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GlobalUsing.cs
+
diff --git a/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs
new file mode 100644
index 00000000000..26d9668792e
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// LLMConfiguration.cs
+
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen.BasicSample;
+
+internal static class LLMConfiguration
+{
+ public static ChatClient GetOpenAIGPT4o_mini()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-4o-mini";
+
+ return new OpenAIClient(openAIKey).GetChatClient(modelId);
+ }
+
+ public static AzureOpenAIConfig GetAzureOpenAIGPT3_5_Turbo(string? deployName = null)
+ {
+ var azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+ deployName = deployName ?? Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+ return new AzureOpenAIConfig(endpoint, deployName, azureOpenAIKey);
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs
new file mode 100644
index 00000000000..8817a3df36e
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+//await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync();
+
+using AutoGen.BasicSample;
+
+//Define allSamples collection for all examples
+List<Tuple<string, Func<Task>>> allSamples = new List<Tuple<string, Func<Task>>>();
+
+// When a new sample is created please add it to the allSamples collection
+allSamples.Add(new Tuple>("Assistant Agent", async () => { await Example01_AssistantAgent.RunAsync(); }));
+allSamples.Add(new Tuple>("Two-agent Math Chat", async () => { await Example02_TwoAgent_MathChat.RunAsync(); }));
+allSamples.Add(new Tuple>("Agent Function Call", async () => { await Example03_Agent_FunctionCall.RunAsync(); }));
+allSamples.Add(new Tuple>("Dynamic Group Chat Coding Task", async () => { await Example04_Dynamic_GroupChat_Coding_Task.RunAsync(); }));
+allSamples.Add(new Tuple>("DALL-E and GPT4v", async () => { await Example05_Dalle_And_GPT4V.RunAsync(); }));
+allSamples.Add(new Tuple>("User Proxy Agent", async () => { await Example06_UserProxyAgent.RunAsync(); }));
+allSamples.Add(new Tuple>("Dynamic Group Chat - Calculate Fibonacci", async () => { await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync(); }));
+allSamples.Add(new Tuple>("LM Studio", async () => { await Example08_LMStudio.RunAsync(); }));
+allSamples.Add(new Tuple>("Semantic Kernel", async () => { await Example10_SemanticKernel.RunAsync(); }));
+allSamples.Add(new Tuple>("Sequential Group Chat", async () => { await Sequential_GroupChat_Example.RunAsync(); }));
+allSamples.Add(new Tuple>("Two Agent - Fill Application", async () => { await TwoAgent_Fill_Application.RunAsync(); }));
+allSamples.Add(new Tuple>("Mistal Client Agent - Token Count", async () => { await Example14_MistralClientAgent_TokenCount.RunAsync(); }));
+allSamples.Add(new Tuple>("GPT4v - Binary Data Image", async () => { await Example15_GPT4V_BinaryDataImageMessage.RunAsync(); }));
+allSamples.Add(new Tuple>("ReAct Agent", async () => { await Example17_ReActAgent.RunAsync(); }));
+
+
+int idx = 1;
+Dictionary<int, Tuple<string, Func<Task>>> map = new Dictionary<int, Tuple<string, Func<Task>>>();
+Console.WriteLine("Available Examples:\n\n");
+foreach (Tuple<string, Func<Task>> sample in allSamples)
+{
+ map.Add(idx, sample);
+ Console.WriteLine("{0}. {1}", idx++, sample.Item1);
+}
+
+Console.WriteLine("\n\nEnter your selection:");
+
+while (true)
+{
+ var input = Console.ReadLine();
+ if (input == "exit")
+ {
+ break;
+ }
+ if (!int.TryParse(input, out int val) || !map.ContainsKey(val))
+ {
+ Console.WriteLine("Invalid choice");
+ }
+ else
+ {
+ Console.WriteLine("\nRunning {0}", map[val].Item1);
+ await map[val].Item2.Invoke();
+ }
+}
+
+
+
diff --git a/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj
new file mode 100644
index 00000000000..d1df8a8ed16
--- /dev/null
+++ b/dotnet/sample/AutoGen.Gemini.Sample/AutoGen.Gemini.Sample.csproj
@@ -0,0 +1,19 @@
+
+
+
+ Exe
+ $(TestTargetFrameworks)
+ enable
+ enable
+ true
+ True
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs
new file mode 100644
index 00000000000..356ae23ff00
--- /dev/null
+++ b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Google_Gemini.cs
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Chat_With_Google_Gemini.cs
+
+#region Using
+using AutoGen.Core;
+#endregion Using
+using FluentAssertions;
+
+namespace AutoGen.Gemini.Sample;
+
+public class Chat_With_Google_Gemini
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Gemini_Agent
+ var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY");
+
+ if (apiKey is null)
+ {
+ Console.WriteLine("Please set GOOGLE_GEMINI_API_KEY environment variable.");
+ return;
+ }
+
+ var geminiAgent = new GeminiChatAgent(
+ name: "gemini",
+ model: "gemini-1.5-flash-001",
+ apiKey: apiKey,
+ systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion Create_Gemini_Agent
+
+ #region Chat_With_Google_Gemini
+ var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion Chat_With_Google_Gemini
+
+ #region verify_reply
+ reply.Should().BeOfType<TextMessage>();
+ #endregion verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs
new file mode 100644
index 00000000000..5924ef7167a
--- /dev/null
+++ b/dotnet/sample/AutoGen.Gemini.Sample/Chat_With_Vertex_Gemini.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Chat_With_Vertex_Gemini.cs
+
+#region Using
+using AutoGen.Core;
+#endregion Using
+using FluentAssertions;
+
+namespace AutoGen.Gemini.Sample;
+
+public class Chat_With_Vertex_Gemini
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Gemini_Agent
+ var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
+
+ if (projectID is null)
+ {
+ Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
+ return;
+ }
+
+ var geminiAgent = new GeminiChatAgent(
+ name: "gemini",
+ model: "gemini-1.5-flash-001",
+ location: "us-east1",
+ project: projectID,
+ systemMessage: "You are a helpful C# engineer, put your code between ```csharp and ```, don't explain the code")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion Create_Gemini_Agent
+
+ #region Chat_With_Vertex_Gemini
+ var reply = await geminiAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion Chat_With_Vertex_Gemini
+
+ #region verify_reply
+ reply.Should().BeOfType<TextMessage>();
+ #endregion verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Function_Call_With_Gemini.cs b/dotnet/sample/AutoGen.Gemini.Sample/Function_Call_With_Gemini.cs
new file mode 100644
index 00000000000..db5068a7b91
--- /dev/null
+++ b/dotnet/sample/AutoGen.Gemini.Sample/Function_Call_With_Gemini.cs
@@ -0,0 +1,131 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Function_Call_With_Gemini.cs
+
+#region Using
+using AutoGen.Core;
+using Google.Cloud.AIPlatform.V1;
+#endregion Using
+using FluentAssertions;
+
+namespace AutoGen.Gemini.Sample;
+
+#region MovieFunction
+public partial class MovieFunction
+{
+ /// <summary>
+ /// find movie titles currently playing in theaters based on any description, genre, title words, etc.
+ /// </summary>
+ /// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
+ /// <param name="description">Any kind of description including category or genre, title words, attributes, etc.</param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> FindMovies(string location, string description)
+ {
+ // dummy implementation
+ var movies = new List { "Barbie", "Spiderman", "Batman" };
+ var result = $"Movies playing in {location} based on {description} are: {string.Join(", ", movies)}";
+
+ return result;
+ }
+
+ /// <summary>
+ /// find theaters based on location and optionally movie title which is currently playing in theaters
+ /// </summary>
+ /// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
+ /// <param name="movie">Any movie title</param>
+ [Function]
+ public async Task<string> FindTheaters(string location, string movie)
+ {
+ // dummy implementation
+ var theaters = new List { "AMC", "Regal", "Cinemark" };
+ var result = $"Theaters playing {movie} in {location} are: {string.Join(", ", theaters)}";
+
+ return result;
+ }
+
+ /// <summary>
+ /// Find the start times for movies playing in a specific theater
+ /// </summary>
+ /// <param name="location">The city and state, e.g. San Francisco, CA or a zip code e.g. 95616</param>
+ /// <param name="movie">Any movie title</param>
+ /// <param name="theater">Name of the theater</param>
+ /// <param name="date">Date for requested showtime</param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> GetShowtimes(string location, string movie, string theater, string date)
+ {
+ // dummy implementation
+ var showtimes = new List { "10:00 AM", "12:00 PM", "2:00 PM", "4:00 PM", "6:00 PM", "8:00 PM" };
+ var result = $"Showtimes for {movie} at {theater} in {location} are: {string.Join(", ", showtimes)}";
+
+ return result;
+ }
+}
+#endregion MovieFunction
+
+/// <summary>
+/// Modified from https://ai.google.dev/gemini-api/docs/function-calling
+/// </summary>
+public partial class Function_Call_With_Gemini
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Gemini_Agent
+ var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
+
+ if (projectID is null)
+ {
+ Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
+ return;
+ }
+
+ var movieFunction = new MovieFunction();
+ var functionMiddleware = new FunctionCallMiddleware(
+ functions: [
+ movieFunction.FindMoviesFunctionContract,
+ movieFunction.FindTheatersFunctionContract,
+ movieFunction.GetShowtimesFunctionContract
+ ],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { movieFunction.FindMoviesFunctionContract.Name!, movieFunction.FindMoviesWrapper },
+ { movieFunction.FindTheatersFunctionContract.Name!, movieFunction.FindTheatersWrapper },
+ { movieFunction.GetShowtimesFunctionContract.Name!, movieFunction.GetShowtimesWrapper },
+ });
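+
+ // With FunctionCallingConfig.Mode.Auto (below) Gemini decides when to emit a function call;
+ // the middleware above then runs the matching wrapper and returns the result automatically.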
+
+ var geminiAgent = new GeminiChatAgent(
+ name: "gemini",
+ model: "gemini-1.5-flash-001",
+ location: "us-central1",
+ project: projectID,
+ systemMessage: "You are a helpful AI assistant",
+ toolConfig: new ToolConfig()
+ {
+ FunctionCallingConfig = new FunctionCallingConfig()
+ {
+ Mode = FunctionCallingConfig.Types.Mode.Auto,
+ }
+ })
+ .RegisterMessageConnector()
+ .RegisterPrintMessage()
+ .RegisterStreamingMiddleware(functionMiddleware);
+ #endregion Create_Gemini_Agent
+
+ #region Single_turn
+ var question = new TextMessage(Role.User, "What movies are showing in North Seattle tonight?");
+ var functionCallReply = await geminiAgent.SendAsync(question);
+ #endregion Single_turn
+
+ #region Single_turn_verify_reply
+ functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
+ #endregion Single_turn_verify_reply
+
+ #region Multi_turn
+ var finalReply = await geminiAgent.SendAsync(chatHistory: [question, functionCallReply]);
+ #endregion Multi_turn
+
+ #region Multi_turn_verify_reply
+ finalReply.Should().BeOfType<TextMessage>();
+ #endregion Multi_turn_verify_reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Image_Chat_With_Vertex_Gemini.cs b/dotnet/sample/AutoGen.Gemini.Sample/Image_Chat_With_Vertex_Gemini.cs
new file mode 100644
index 00000000000..ad320e7c6fa
--- /dev/null
+++ b/dotnet/sample/AutoGen.Gemini.Sample/Image_Chat_With_Vertex_Gemini.cs
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Image_Chat_With_Vertex_Gemini.cs
+
+#region Using
+using AutoGen.Core;
+#endregion Using
+using FluentAssertions;
+
+namespace AutoGen.Gemini.Sample;
+
+public class Image_Chat_With_Vertex_Gemini
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Gemini_Agent
+ var projectID = Environment.GetEnvironmentVariable("GCP_VERTEX_PROJECT_ID");
+
+ if (projectID is null)
+ {
+ Console.WriteLine("Please set GCP_VERTEX_PROJECT_ID environment variable.");
+ return;
+ }
+
+ var geminiAgent = new GeminiChatAgent(
+ name: "gemini",
+ model: "gemini-1.5-flash-001",
+ location: "us-east4",
+ project: projectID,
+ systemMessage: "You explain image content to user")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion Create_Gemini_Agent
+
+ #region Send_Image_Request
+ var imagePath = Path.Combine("resource", "images", "background.png");
+ var image = await File.ReadAllBytesAsync(imagePath);
+ var imageMessage = new ImageMessage(Role.User, BinaryData.FromBytes(image, "image/png"));
+ var reply = await geminiAgent.SendAsync("what's in the image", [imageMessage]);
+ #endregion Send_Image_Request
+
+ #region Verify_Reply
+ reply.Should().BeOfType<TextMessage>();
+ #endregion Verify_Reply
+ }
+}
diff --git a/dotnet/sample/AutoGen.Gemini.Sample/Program.cs b/dotnet/sample/AutoGen.Gemini.Sample/Program.cs
new file mode 100644
index 00000000000..5e76942209a
--- /dev/null
+++ b/dotnet/sample/AutoGen.Gemini.Sample/Program.cs
@@ -0,0 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.Gemini.Sample;
+
+Image_Chat_With_Vertex_Gemini.RunAsync().Wait();
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj
new file mode 100644
index 00000000000..62c9d61633c
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/AutoGen.Ollama.Sample.csproj
@@ -0,0 +1,19 @@
+
+
+ Exe
+ $(TestTargetFrameworks)
+ enable
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+ true
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaMA.cs b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaMA.cs
new file mode 100644
index 00000000000..09df4a48de9
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaMA.cs
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Chat_With_LLaMA.cs
+
+#region Using
+using AutoGen.Core;
+using AutoGen.Ollama.Extension;
+#endregion Using
+
+namespace AutoGen.Ollama.Sample;
+
+public class Chat_With_LLaMA
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Ollama_Agent
+ using var httpClient = new HttpClient()
+ {
+ BaseAddress = new Uri("http://localhost:11434"),
+ };
+
+ var ollamaAgent = new OllamaAgent(
+ httpClient: httpClient,
+ name: "ollama",
+ modelName: "llama3:latest",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ var reply = await ollamaAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion Create_Ollama_Agent
+ }
+}
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs
new file mode 100644
index 00000000000..d9e38c886c2
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/Chat_With_LLaVA.cs
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Chat_With_LLaVA.cs
+
+#region Using
+using AutoGen.Core;
+using AutoGen.Ollama.Extension;
+#endregion Using
+
+namespace AutoGen.Ollama.Sample;
+
+public class Chat_With_LLaVA
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Ollama_Agent
+ using var httpClient = new HttpClient()
+ {
+ BaseAddress = new Uri("http://localhost:11434"),
+ };
+
+ var ollamaAgent = new OllamaAgent(
+ httpClient: httpClient,
+ name: "ollama",
+ modelName: "llava:latest",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion Create_Ollama_Agent
+
+ #region Send_Message
+ var image = Path.Combine("resource", "images", "background.png");
+ var binaryData = BinaryData.FromBytes(File.ReadAllBytes(image), "image/png");
+ var imageMessage = new ImageMessage(Role.User, binaryData);
+ var textMessage = new TextMessage(Role.User, "what's in this image?");
+ var reply = await ollamaAgent.SendAsync(chatHistory: [textMessage, imageMessage]);
+ #endregion Send_Message
+
+ #region Send_MultiModal_Message
+ // You can also use MultiModalMessage to put text and image together in one message.
+ // In this case, all the messages in the multi-modal message are combined into a single message,
+ // where the text is the concatenation of all the text messages separated by \n
+ // and the images are all the images in the multi-modal message.
+ var multiModalMessage = new MultiModalMessage(Role.User, [textMessage, imageMessage]);
+
+ reply = await ollamaAgent.SendAsync(chatHistory: [multiModalMessage]);
+ #endregion Send_MultiModal_Message
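+
+ // In both cases the reply is a plain text description of the image; RegisterPrintMessage
+ // above already prints it, and reply.GetContent() exposes the text programmatically.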
+ }
+}
diff --git a/dotnet/sample/AutoGen.Ollama.Sample/Program.cs b/dotnet/sample/AutoGen.Ollama.Sample/Program.cs
new file mode 100644
index 00000000000..62c92eebe7e
--- /dev/null
+++ b/dotnet/sample/AutoGen.Ollama.Sample/Program.cs
@@ -0,0 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.Ollama.Sample;
+
+await Chat_With_LLaVA.RunAsync();
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj
new file mode 100644
index 00000000000..fcbbb834fc6
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/AutoGen.OpenAI.Sample.csproj
@@ -0,0 +1,22 @@
+
+
+
+ Exe
+ $(TestTargetFrameworks)
+ enable
+ enable
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+ true
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Azure_OpenAI.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Azure_OpenAI.cs
new file mode 100644
index 00000000000..dafe2e31485
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Azure_OpenAI.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Connect_To_Azure_OpenAI.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using Azure;
+using Azure.AI.OpenAI;
+#endregion using_statement
+
+namespace AutoGen.OpenAI.Sample;
+
+public class Connect_To_Azure_OpenAI
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new InvalidOperationException("Please set environment variable AZURE_OPENAI_API_KEY");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new InvalidOperationException("Please set environment variable AZURE_OPENAI_ENDPOINT");
+ var model = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? "gpt-4o-mini";
+
+ // Use AzureOpenAIClient to connect to openai model deployed on azure.
+ // The AzureOpenAIClient comes from Azure.AI.OpenAI package
+ var openAIClient = new AzureOpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey));
+
+ var agent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region send_message
+ await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion send_message
+ }
+}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
new file mode 100644
index 00000000000..2bb10e97841
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Connect_To_Ollama.cs
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Connect_To_Ollama.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
+#endregion using_statement
+
+namespace AutoGen.OpenAI.Sample;
+
+public class Connect_To_Ollama
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ // api-key is not required for local server
+ // so you can use any string here
+ var openAIClient = new OpenAIClient("api-key", new OpenAIClientOptions
+ {
+ Endpoint = new Uri("http://localhost:11434/v1/"), // remember to add /v1/ at the end to connect to Ollama openai server
+ });
+ var model = "llama3";
+
+ var agent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region send_message
+ await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion send_message
+ }
+}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs
new file mode 100644
index 00000000000..c71f152d037
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Program.cs
@@ -0,0 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.OpenAI.Sample;
+
+Structural_Output.RunAsync().Wait();
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Structural_Output.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Structural_Output.cs
new file mode 100644
index 00000000000..e562d7223a6
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Structural_Output.cs
@@ -0,0 +1,90 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Structural_Output.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using Json.Schema;
+using Json.Schema.Generation;
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen.OpenAI.Sample;
+
+internal class Structural_Output
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+
+ var schemaBuilder = new JsonSchemaBuilder().FromType<Person>();
+ var schema = schemaBuilder.Build();
+
+ var personSchemaFormat = ChatResponseFormat.CreateJsonSchemaFormat(
+ name: "Person",
+ jsonSchema: BinaryData.FromObjectAsJson(schema),
+ description: "Person schema");
+
+ var openAIClient = new OpenAIClient(apiKey);
+ var openAIClientAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are a helpful assistant",
+ responseFormat: personSchemaFormat) // structural output by passing schema to response format
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region chat_with_agent
+ var reply = await openAIClientAgent.SendAsync("My name is John, I am 25 years old, and I live in Seattle. I like to play soccer and read books.");
+
+ var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
+ Console.WriteLine($"Name: {person.Name}");
+ Console.WriteLine($"Age: {person.Age}");
+
+ if (!string.IsNullOrEmpty(person.Address))
+ {
+ Console.WriteLine($"Address: {person.Address}");
+ }
+
+ Console.WriteLine("Done.");
+ #endregion chat_with_agent
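+
+ // For illustration, with the schema above the model is constrained to reply with JSON like
+ // {"name":"John","age":25,"city":"Seattle","hobbies":["soccer","reading books"]},
+ // which is why the assertions below can deserialize the content and check individual fields.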
+
+ person.Name.Should().Be("John");
+ person.Age.Should().Be(25);
+ person.Address.Should().BeNullOrEmpty();
+ person.City.Should().Be("Seattle");
+ person.Hobbies.Count.Should().Be(2);
+ }
+}
+
+#region person_class
+public class Person
+{
+ [JsonPropertyName("name")]
+ [Description("Name of the person")]
+ [Required]
+ public string Name { get; set; }
+
+ [JsonPropertyName("age")]
+ [Description("Age of the person")]
+ [Required]
+ public int Age { get; set; }
+
+ [JsonPropertyName("city")]
+ [Description("City of the person")]
+ public string? City { get; set; }
+
+ [JsonPropertyName("address")]
+ [Description("Address of the person")]
+ public string? Address { get; set; }
+
+ [JsonPropertyName("hobbies")]
+ [Description("Hobbies of the person")]
+ public List<string>? Hobbies { get; set; }
+}
+#endregion person_class
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs
new file mode 100644
index 00000000000..ed43c628a67
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Tool_Call_With_Ollama_And_LiteLLM.cs
@@ -0,0 +1,64 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tool_Call_With_Ollama_And_LiteLLM.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI.Extension;
+using OpenAI;
+
+namespace AutoGen.OpenAI.Sample;
+
+#region Function
+public partial class Function
+{
+ [Function]
+ public async Task<string> GetWeatherAsync(string city)
+ {
+ return await Task.FromResult("The weather in " + city + " is 72 degrees and sunny.");
+ }
+}
+#endregion Function
+
+public class Tool_Call_With_Ollama_And_LiteLLM
+{
+ public static async Task RunAsync()
+ {
+ // Before running this code, make sure you have
+ // - Ollama:
+ // - Install dolphincoder:latest in Ollama
+ // - Ollama running on http://localhost:11434
+ // - LiteLLM
+ // - Install LiteLLM
+ // - Start LiteLLM with the following command:
+ // - litellm --model ollama_chat/dolphincoder --port 4000
+
+ #region Create_tools
+ var functions = new Function();
+ var functionMiddleware = new FunctionCallMiddleware(
+ functions: [functions.GetWeatherAsyncFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { functions.GetWeatherAsyncFunctionContract.Name!, functions.GetWeatherAsyncWrapper },
+ });
+ #endregion Create_tools
+ #region Create_Agent
+ var liteLLMUrl = "http://localhost:4000";
+
+ // api-key is not required for local server
+ // so you can use any string here
+ var openAIClient = new OpenAIClient("api-key", new OpenAIClientOptions
+ {
+ Endpoint = new Uri("http://localhost:4000"),
+ });
+
+ var agent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient("dolphincoder:latest"),
+ name: "assistant",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(functionMiddleware)
+ .RegisterPrintMessage();
+
+ var reply = await agent.SendAsync("what's the weather in new york");
+ #endregion Create_Agent
+ }
+}
diff --git a/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs b/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
new file mode 100644
index 00000000000..392796d819f
--- /dev/null
+++ b/dotnet/sample/AutoGen.OpenAI.Sample/Use_Json_Mode.cs
@@ -0,0 +1,67 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Json_Mode.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using FluentAssertions;
+using OpenAI;
+using OpenAI.Chat;
+
+namespace AutoGen.BasicSample;
+
+public class Use_Json_Mode
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var model = "gpt-4o-mini";
+
+ var openAIClient = new OpenAIClient(apiKey);
+ var openAIClientAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(model),
+ name: "assistant",
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0, // explicitly set a seed to enable deterministic output
+ responseFormat: ChatResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region chat_with_agent
+ var reply = await openAIClientAgent.SendAsync("My name is John, I am 25 years old, and I live in Seattle.");
+
+ var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
+ Console.WriteLine($"Name: {person.Name}");
+ Console.WriteLine($"Age: {person.Age}");
+
+ if (!string.IsNullOrEmpty(person.Address))
+ {
+ Console.WriteLine($"Address: {person.Address}");
+ }
+
+ Console.WriteLine("Done.");
+ #endregion chat_with_agent
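+
+ // For illustration, JSON mode guarantees syntactically valid JSON (e.g. {"name":"John","age":25}),
+ // but unlike the structured-output sample it does not enforce a particular schema, so the
+ // prompt itself must describe the expected fields.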
+
+ person.Name.Should().Be("John");
+ person.Age.Should().Be(25);
+ person.Address.Should().BeNullOrEmpty();
+ }
+}
+
+#region person_class
+public class Person
+{
+ [JsonPropertyName("name")]
+ public string Name { get; set; }
+
+ [JsonPropertyName("age")]
+ public int Age { get; set; }
+
+ [JsonPropertyName("address")]
+ public string Address { get; set; }
+}
+#endregion person_class
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj
new file mode 100644
index 00000000000..45514431368
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/AutoGen.SemanticKernel.Sample.csproj
@@ -0,0 +1,18 @@
+
+
+
+ Exe
+ $(TestTargetFrameworks)
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+ enable
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Agent.cs
new file mode 100644
index 00000000000..3333cdd9ad9
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Agent.cs
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Semantic_Kernel_Agent.cs
+
+using AutoGen.Core;
+using AutoGen.SemanticKernel.Extension;
+using Microsoft.SemanticKernel;
+
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Create_Semantic_Kernel_Agent
+{
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var kernel = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey)
+ .Build();
+
+ var skAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector() // register message connector so it supports AutoGen built-in message types like TextMessage.
+ .RegisterPrintMessage(); // pretty print the message to the console
+
+ await skAgent.SendAsync("Hey tell me a long tedious joke");
+ }
+}
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Chat_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Chat_Agent.cs
new file mode 100644
index 00000000000..9b72a2e0fb1
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Create_Semantic_Kernel_Chat_Agent.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Create_Semantic_Kernel_Chat_Agent.cs
+
+#region Using
+using AutoGen.Core;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Agents;
+#endregion Using
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Create_Semantic_Kernel_Chat_Agent
+{
+ public static async Task RunAsync()
+ {
+ #region Create_Kernel
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var kernel = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey)
+ .Build();
+ #endregion Create_Kernel
+
+ #region Create_ChatCompletionAgent
+ // The built-in ChatCompletionAgent from semantic kernel.
+ var chatAgent = new ChatCompletionAgent()
+ {
+ Kernel = kernel,
+ Name = "assistant",
+ Description = "You are a helpful AI assistant",
+ };
+ #endregion Create_ChatCompletionAgent
+
+ #region Create_SemanticKernelChatCompletionAgent
+ var messageConnector = new SemanticKernelChatMessageContentConnector();
+ var skAgent = new SemanticKernelChatCompletionAgent(chatAgent)
+ .RegisterMiddleware(messageConnector) // register message connector so it supports AutoGen built-in message types like TextMessage.
+ .RegisterPrintMessage(); // pretty print the message to the console
+ #endregion Create_SemanticKernelChatCompletionAgent
+
+ #region Send_Message
+ await skAgent.SendAsync("Hey tell me a long tedious joke");
+ #endregion Send_Message
+ }
+}
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Program.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Program.cs
new file mode 100644
index 00000000000..5032f2d4330
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Program.cs
@@ -0,0 +1,6 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.SemanticKernel.Sample;
+
+await Use_Kernel_Functions_With_Other_Agent.RunAsync();
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Bing_Search_With_Semantic_Kernel_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Bing_Search_With_Semantic_Kernel_Agent.cs
new file mode 100644
index 00000000000..4cebc88291f
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Bing_Search_With_Semantic_Kernel_Agent.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Bing_Search_With_Semantic_Kernel_Agent.cs
+
+using AutoGen.Core;
+using AutoGen.SemanticKernel.Extension;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Plugins.Web;
+using Microsoft.SemanticKernel.Plugins.Web.Bing;
+
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Use_Bing_Search_With_Semantic_Kernel_Agent
+{
+ public static async Task RunAsync()
+ {
+ var bingApiKey = Environment.GetEnvironmentVariable("BING_API_KEY") ?? throw new Exception("BING_API_KEY environment variable is not set");
+ var bingSearch = new BingConnector(bingApiKey);
+ var webSearchPlugin = new WebSearchEnginePlugin(bingSearch);
+
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var kernelBuilder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ kernelBuilder.Plugins.AddFromObject(webSearchPlugin);
+
+ var kernel = kernelBuilder.Build();
+
+ var skAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are a helpful AI assistant")
+ .RegisterMessageConnector() // register message connector so it supports AutoGen built-in message types like TextMessage.
+ .RegisterPrintMessage(); // pretty print the message to the console
+
+ await skAgent.SendAsync("Tell me more about gpt-4-o");
+ }
+}
diff --git a/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs
new file mode 100644
index 00000000000..700bdfe75c7
--- /dev/null
+++ b/dotnet/sample/AutoGen.SemanticKernel.Sample/Use_Kernel_Functions_With_Other_Agent.cs
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Use_Kernel_Functions_With_Other_Agent.cs
+
+#region Using
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Microsoft.SemanticKernel;
+using OpenAI;
+#endregion Using
+
+namespace AutoGen.SemanticKernel.Sample;
+
+public class Use_Kernel_Functions_With_Other_Agent
+{
+ public static async Task RunAsync()
+ {
+ #region Create_plugin
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-4o-mini";
+ var kernelBuilder = Kernel.CreateBuilder();
+ var kernel = kernelBuilder.Build();
+ var getWeatherFunction = KernelFunctionFactory.CreateFromMethod(
+ method: (string location) => $"The weather in {location} is 75 degrees Fahrenheit.",
+ functionName: "GetWeather",
+ description: "Get the weather for a location.");
+ var plugin = kernel.CreatePluginFromFunctions("my_plugin", [getWeatherFunction]);
+ #endregion Create_plugin
+
+ #region Use_plugin
+ // Create a middleware to handle the plugin functions
+ var kernelPluginMiddleware = new KernelPluginMiddleware(kernel, plugin);
+
+ var openAIClient = new OpenAIClient(openAIKey);
+ var openAIAgent = new OpenAIChatAgent(
+ chatClient: openAIClient.GetChatClient(modelId),
+ name: "assistant")
+ .RegisterMessageConnector() // register message connector so it supports AutoGen built-in message types like TextMessage.
+ .RegisterMiddleware(kernelPluginMiddleware) // register the middleware to handle the plugin functions
+ .RegisterPrintMessage(); // pretty print the message to the console
+ #endregion Use_plugin
+
+ #region Send_message
+ var toolAggregateMessage = await openAIAgent.SendAsync("Tell me the weather in Seattle");
+
+ // The aggregate message will be converted to [ToolCallMessage, ToolCallResultMessage] when flowing into the agent
+ // send the aggregated message to llm to generate the final response
+ var finalReply = await openAIAgent.SendAsync(toolAggregateMessage);
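+ // finalReply is expected to be a plain text answer grounded in the tool result,
+ // e.g. along the lines of "The weather in Seattle is 75 degrees Fahrenheit." (illustrative).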
+ #endregion Send_message
+ }
+}
diff --git a/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj
new file mode 100644
index 00000000000..76675ba1234
--- /dev/null
+++ b/dotnet/sample/AutoGen.WebAPI.Sample/AutoGen.WebAPI.Sample.csproj
@@ -0,0 +1,13 @@
+
+
+
+ $(TestTargetFrameworks)
+ enable
+ enable
+
+
+
+
+
+
+
diff --git a/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs b/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs
new file mode 100644
index 00000000000..dbeb8494363
--- /dev/null
+++ b/dotnet/sample/AutoGen.WebAPI.Sample/Program.cs
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using System.Runtime.CompilerServices;
+using AutoGen.Core;
+using AutoGen.WebAPI;
+
+var alice = new DummyAgent("alice");
+var bob = new DummyAgent("bob");
+
+var builder = WebApplication.CreateBuilder(args);
+// Add services to the container.
+
+// run endpoint at port 5000
+builder.WebHost.UseUrls("http://localhost:5000");
+var app = builder.Build();
+
+app.UseAgentAsOpenAIChatCompletionEndpoint(alice);
+app.UseAgentAsOpenAIChatCompletionEndpoint(bob);
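+
+// Each agent is now exposed through an OpenAI-compatible chat completion endpoint on
+// http://localhost:5000; a standard OpenAI client can talk to it (assumption: the agent is
+// selected by passing its name, e.g. "alice", as the model name in the request).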
+
+app.Run();
+
+public class DummyAgent : IStreamingAgent
+{
+ public DummyAgent(string name = "dummy")
+ {
+ Name = name;
+ }
+
+ public string Name { get; }
+
+ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ return new TextMessage(Role.Assistant, $"I am dummy {this.Name}", this.Name);
+ }
+
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var reply = $"I am dummy {this.Name}";
+ foreach (var c in reply)
+ {
+ yield return new TextMessageUpdate(Role.Assistant, c.ToString(), this.Name);
+ };
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
new file mode 100644
index 00000000000..81fa8e6438a
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
@@ -0,0 +1,120 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicClientAgent.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.Anthropic.DTO;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic;
+
+public class AnthropicClientAgent : IStreamingAgent
+{
+ private readonly AnthropicClient _anthropicClient;
+ public string Name { get; }
+ private readonly string _modelName;
+ private readonly string _systemMessage;
+ private readonly decimal _temperature;
+ private readonly int _maxTokens;
+ private readonly Tool[]? _tools;
+ private readonly ToolChoice? _toolChoice;
+
+ public AnthropicClientAgent(
+ AnthropicClient anthropicClient,
+ string name,
+ string modelName,
+ string systemMessage = "You are a helpful AI assistant",
+ decimal temperature = 0.7m,
+ int maxTokens = 1024,
+ Tool[]? tools = null,
+ ToolChoice? toolChoice = null)
+ {
+ Name = name;
+ _anthropicClient = anthropicClient;
+ _modelName = modelName;
+ _systemMessage = systemMessage;
+ _temperature = temperature;
+ _maxTokens = maxTokens;
+ _tools = tools;
+ _toolChoice = toolChoice;
+ }
+
+ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ var response = await _anthropicClient.CreateChatCompletionsAsync(CreateParameters(messages, options, false), cancellationToken);
+ return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
+ }
+
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ await foreach (var message in _anthropicClient.StreamingChatCompletionsAsync(
+ CreateParameters(messages, options, true), cancellationToken))
+ {
+ yield return new MessageEnvelope<ChatCompletionResponse>(message, from: this.Name);
+ }
+ }
+
+ private ChatCompletionRequest CreateParameters(IEnumerable messages, GenerateReplyOptions? options, bool shouldStream)
+ {
+ var chatCompletionRequest = new ChatCompletionRequest()
+ {
+ SystemMessage = [new SystemMessage { Text = _systemMessage }],
+ MaxTokens = options?.MaxToken ?? _maxTokens,
+ Model = _modelName,
+ Stream = shouldStream,
+ Temperature = (decimal?)options?.Temperature ?? _temperature,
+ Tools = _tools?.ToList(),
+ ToolChoice = _toolChoice ?? (_tools is { Length: > 0 } ? ToolChoice.Auto : null),
+ StopSequences = options?.StopSequence?.ToArray(),
+ };
+
+ chatCompletionRequest.Messages = BuildMessages(messages);
+
+ return chatCompletionRequest;
+ }
+
+ private List<ChatMessage> BuildMessages(IEnumerable<IMessage> messages)
+ {
+ List<ChatMessage> chatMessages = new();
+ foreach (IMessage? message in messages)
+ {
+ switch (message)
+ {
+ case IMessage<ChatMessage> chatMessage when chatMessage.Content.Role == "system":
+ throw new InvalidOperationException(
+ "system message has already been set and only one system message is supported. \"system\" role for input messages in the Message");
+
+ case IMessage<ChatMessage> chatMessage:
+ chatMessages.Add(chatMessage.Content);
+ break;
+
+ default:
+ throw new ArgumentException($"Unexpected message type: {message?.GetType()}");
+ }
+ }
+
+ // merge messages with the same role
+ // fixing #2884
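+ // For example, two consecutive user messages (say a text block followed by an image block)
+ // are folded into a single user message whose Content list carries both blocks, since the
+ // Anthropic API expects alternating user/assistant roles.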
+ var mergedMessages = chatMessages.Aggregate(new List<ChatMessage>(), (acc, message) =>
+ {
+ if (acc.Count > 0 && acc.Last().Role == message.Role)
+ {
+ acc.Last().Content.AddRange(message.Content);
+ }
+ else
+ {
+ acc.Add(message);
+ }
+
+ return acc;
+ });
+
+ return mergedMessages;
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
new file mode 100644
index 00000000000..f106e08d35c
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
@@ -0,0 +1,202 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicClient.cs
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Net.Http;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.Anthropic.Converters;
+using AutoGen.Anthropic.DTO;
+
+namespace AutoGen.Anthropic;
+
+public sealed class AnthropicClient : IDisposable
+{
+ private readonly HttpClient _httpClient;
+ private readonly string _baseUrl;
+
+ private static readonly JsonSerializerOptions JsonSerializerOptions = new()
+ {
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+ Converters =
+ {
+ new ContentBaseConverter(),
+ new JsonPropertyNameEnumConverter(),
+ new JsonPropertyNameEnumConverter(),
+ new SystemMessageConverter(),
+ }
+ };
+
+ public AnthropicClient(HttpClient httpClient, string baseUrl, string apiKey)
+ {
+ _httpClient = httpClient;
+ _baseUrl = baseUrl;
+
+ _httpClient.DefaultRequestHeaders.Add("x-api-key", apiKey);
+ _httpClient.DefaultRequestHeaders.Add("anthropic-version", "2023-06-01");
+ }
+
+ public async Task<ChatCompletionResponse> CreateChatCompletionsAsync(ChatCompletionRequest chatCompletionRequest,
+ CancellationToken cancellationToken)
+ {
+ var httpResponseMessage = await SendRequestAsync(chatCompletionRequest, cancellationToken);
+ var responseStream = await httpResponseMessage.Content.ReadAsStreamAsync();
+
+ if (httpResponseMessage.IsSuccessStatusCode)
+ {
+ return await DeserializeResponseAsync<ChatCompletionResponse>(responseStream, cancellationToken);
+ }
+
+ ErrorResponse res = await DeserializeResponseAsync<ErrorResponse>(responseStream, cancellationToken);
+ throw new Exception(res.Error?.Message);
+ }
+
+ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAsync(
+ ChatCompletionRequest chatCompletionRequest, [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ var httpResponseMessage = await SendRequestAsync(chatCompletionRequest, cancellationToken);
+ using var reader = new StreamReader(await httpResponseMessage.Content.ReadAsStreamAsync());
+
+ var currentEvent = new SseEvent();
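+
+ // The response is a server-sent event stream; each event looks like
+ // event: content_block_delta
+ // data: {"type":"content_block_delta","index":0,"delta":{...}}
+ // followed by a blank line, which is what the parsing below keys on.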
+
+ while (await reader.ReadLineAsync() is { } line)
+ {
+ if (!string.IsNullOrEmpty(line))
+ {
+ if (line.StartsWith("event:"))
+ {
+ currentEvent.EventType = line.Substring("event:".Length).Trim();
+ }
+ else if (line.StartsWith("data:"))
+ {
+ currentEvent.Data = line.Substring("data:".Length).Trim();
+ }
+ }
+ else // an empty line indicates the end of an event
+ {
+ if (currentEvent.EventType == "content_block_start" && !string.IsNullOrEmpty(currentEvent.Data))
+ {
+ var dataBlock = JsonSerializer.Deserialize<DataBlock>(currentEvent.Data!);
+ if (dataBlock != null && dataBlock.ContentBlock?.Type == "tool_use")
+ {
+ currentEvent.ContentBlock = dataBlock.ContentBlock;
+ }
+ }
+
+ if (currentEvent.EventType is "message_start" or "content_block_delta" or "message_delta" && currentEvent.Data != null)
+ {
+ var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
+ new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
+ cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
+ if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
+ currentEvent.ContentBlock != null)
+ {
+ currentEvent.ContentBlock.AppendDeltaParameters(res.Delta.PartialJson!);
+ }
+ else if (res.Delta is { StopReason: "tool_use" } && currentEvent.ContentBlock != null)
+ {
+ if (res.Content == null)
+ {
+ res.Content = [currentEvent.ContentBlock.CreateToolUseContent()];
+ }
+ else
+ {
+ res.Content.Add(currentEvent.ContentBlock.CreateToolUseContent());
+ }
+
+ currentEvent = new SseEvent();
+ }
+
+ yield return res;
+ }
+ else if (currentEvent.EventType == "error" && currentEvent.Data != null)
+ {
+ var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
+ new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken);
+
+ throw new Exception(res?.Error?.Message);
+ }
+
+ if (currentEvent.ContentBlock == null)
+ {
+ currentEvent = new SseEvent();
+ }
+ }
+ }
+ }
+
+ private Task<HttpResponseMessage> SendRequestAsync<T>(T requestObject, CancellationToken cancellationToken)
+ {
+ var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _baseUrl);
+ var jsonRequest = JsonSerializer.Serialize(requestObject, JsonSerializerOptions);
+ httpRequestMessage.Content = new StringContent(jsonRequest, Encoding.UTF8, "application/json");
+ httpRequestMessage.Headers.Add("anthropic-beta", "prompt-caching-2024-07-31");
+ return _httpClient.SendAsync(httpRequestMessage, cancellationToken);
+ }
+
+ private async Task<T> DeserializeResponseAsync<T>(Stream responseStream, CancellationToken cancellationToken)
+ {
+ return await JsonSerializer.DeserializeAsync<T>(responseStream, JsonSerializerOptions, cancellationToken)
+ ?? throw new Exception("Failed to deserialize response");
+ }
+
+ public void Dispose()
+ {
+ _httpClient.Dispose();
+ }
+
+ private struct SseEvent
+ {
+ public string EventType { get; set; }
+ public string? Data { get; set; }
+ public ContentBlock? ContentBlock { get; set; }
+
+ public SseEvent(string eventType, string? data = null, ContentBlock? contentBlock = null)
+ {
+ EventType = eventType;
+ Data = data;
+ ContentBlock = contentBlock;
+ }
+ }
+
+ private class ContentBlock
+ {
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("input")]
+ public object? Input { get; set; }
+
+ public string? parameters { get; set; }
+
+ public void AppendDeltaParameters(string deltaParams)
+ {
+ StringBuilder sb = new StringBuilder(parameters);
+ sb.Append(deltaParams);
+ parameters = sb.ToString();
+ }
+
+ public ToolUseContent CreateToolUseContent()
+ {
+ return new ToolUseContent { Id = Id, Name = Name, Input = parameters };
+ }
+ }
+
+ private class DataBlock
+ {
+ [JsonPropertyName("content_block")]
+ public ContentBlock? ContentBlock { get; set; }
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj
new file mode 100644
index 00000000000..a4fd32e7e34
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj
@@ -0,0 +1,22 @@
+
+
+
+ $(PackageTargetFrameworks)
+ AutoGen.Anthropic
+
+
+
+
+
+
+ AutoGen.Anthropic
+
+ Provide support for consuming Anthropic models in AutoGen
+
+
+
+
+
+
+
+
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
new file mode 100644
index 00000000000..3e620f934c2
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ContentBaseConverter.cs
+
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Anthropic.DTO;
+namespace AutoGen.Anthropic.Converters;
+
+public sealed class ContentBaseConverter : JsonConverter<ContentBase>
+{
+ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ using var doc = JsonDocument.ParseValue(ref reader);
+ if (doc.RootElement.TryGetProperty("type", out JsonElement typeProperty) && !string.IsNullOrEmpty(typeProperty.GetString()))
+ {
+ string? type = typeProperty.GetString();
+ var text = doc.RootElement.GetRawText();
+ switch (type)
+ {
+ case "text":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ case "image":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ case "tool_use":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ case "tool_result":
+ return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException();
+ }
+ }
+
+ throw new JsonException("Unknown content type");
+ }
+
+ public override void Write(Utf8JsonWriter writer, ContentBase value, JsonSerializerOptions options)
+ {
+ JsonSerializer.Serialize(writer, value, value.GetType(), options);
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
new file mode 100644
index 00000000000..68b3c14bdee
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Converters/JsonPropertyNameEnumCoverter.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// JsonPropertyNameEnumCoverter.cs
+
+using System;
+using System.Reflection;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.Converters;
+
+internal class JsonPropertyNameEnumConverter<T> : JsonConverter<T> where T : struct, Enum
+{
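+ // Reads and writes enum values using their [JsonPropertyName] attribute instead of the member name.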
+ public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ string value = reader.GetString() ?? throw new JsonException("Value was null.");
+
+ foreach (var field in typeToConvert.GetFields())
+ {
+ var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
+ if (attribute?.Name == value)
+ {
+ return (T)Enum.Parse(typeToConvert, field.Name);
+ }
+ }
+
+ throw new JsonException($"Unable to convert \"{value}\" to enum {typeToConvert}.");
+ }
+
+ public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
+ {
+ var field = value.GetType().GetField(value.ToString());
+ var attribute = field?.GetCustomAttribute<JsonPropertyNameAttribute>();
+
+ if (attribute != null)
+ {
+ writer.WriteStringValue(attribute.Name);
+ }
+ else
+ {
+ writer.WriteStringValue(value.ToString());
+ }
+ }
+}
+
diff --git a/dotnet/src/AutoGen.Anthropic/Converters/SystemMessageConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/SystemMessageConverter.cs
new file mode 100644
index 00000000000..5bbe8a3a37f
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Converters/SystemMessageConverter.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// SystemMessageConverter.cs
+
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Anthropic.DTO;
+
+namespace AutoGen.Anthropic.Converters;
+
+public class SystemMessageConverter : JsonConverter<object>
+{
+ public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
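+ // The "system" field may be either a plain string or an array of system content blocks.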
+ if (reader.TokenType == JsonTokenType.String)
+ {
+ return reader.GetString() ?? string.Empty;
+ }
+ if (reader.TokenType == JsonTokenType.StartArray)
+ {
+ return JsonSerializer.Deserialize<SystemMessage[]>(ref reader, options) ?? throw new InvalidOperationException();
+ }
+
+ throw new JsonException();
+ }
+
+ public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options)
+ {
+ if (value is string stringValue)
+ {
+ writer.WriteStringValue(stringValue);
+ }
+ else if (value is SystemMessage[] arrayValue)
+ {
+ JsonSerializer.Serialize(writer, arrayValue, options);
+ }
+ else
+ {
+ throw new JsonException();
+ }
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
new file mode 100644
index 00000000000..dfb86ef0af5
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
@@ -0,0 +1,93 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatCompletionRequest.cs
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.DTO;
+
+public class ChatCompletionRequest
+{
+ [JsonPropertyName("model")]
+ public string? Model { get; set; }
+
+ [JsonPropertyName("messages")]
+ public List<ChatMessage> Messages { get; set; }
+
+ [JsonPropertyName("system")]
+ public SystemMessage[]? SystemMessage { get; set; }
+
+ [JsonPropertyName("max_tokens")]
+ public int MaxTokens { get; set; }
+
+ [JsonPropertyName("metadata")]
+ public object? Metadata { get; set; }
+
+ [JsonPropertyName("stop_sequences")]
+ public string[]? StopSequences { get; set; }
+
+ [JsonPropertyName("stream")]
+ public bool? Stream { get; set; }
+
+ [JsonPropertyName("temperature")]
+ public decimal? Temperature { get; set; }
+
+ [JsonPropertyName("top_k")]
+ public int? TopK { get; set; }
+
+ [JsonPropertyName("top_p")]
+ public decimal? TopP { get; set; }
+
+ [JsonPropertyName("tools")]
+ public List<Tool>? Tools { get; set; }
+
+ [JsonPropertyName("tool_choice")]
+ public ToolChoice? ToolChoice { get; set; }
+
+ public ChatCompletionRequest()
+ {
+ Messages = new List<ChatMessage>();
+ }
+}
+
+public class SystemMessage
+{
+ [JsonPropertyName("text")]
+ public string? Text { get; set; }
+
+ [JsonPropertyName("type")]
+ public string? Type { get; private set; } = "text";
+
+ [JsonPropertyName("cache_control")]
+ public CacheControl? CacheControl { get; set; }
+
+ public static SystemMessage CreateSystemMessage(string systemMessage) => new() { Text = systemMessage };
+
+ public static SystemMessage CreateSystemMessageWithCacheControl(string systemMessage) => new()
+ {
+ Text = systemMessage,
+ CacheControl = new CacheControl { Type = CacheControlType.Ephemeral }
+ };
+}
+
+public class ChatMessage
+{
+ [JsonPropertyName("role")]
+ public string Role { get; set; }
+
+ [JsonPropertyName("content")]
+ public List<ContentBase> Content { get; set; }
+
+ public ChatMessage(string role, string content)
+ {
+ Role = role;
+ Content = new List<ContentBase>() { new TextContent { Text = content } };
+ }
+
+ public ChatMessage(string role, List<ContentBase> content)
+ {
+ Role = role;
+ Content = content;
+ }
+
+ public void AddContent(ContentBase content) => Content.Add(content);
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
new file mode 100644
index 00000000000..a142f2feacc
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs
@@ -0,0 +1,97 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ChatCompletionResponse.cs
+
+
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.DTO;
+public class ChatCompletionResponse
+{
+ [JsonPropertyName("content")]
+ public List<ContentBase>? Content { get; set; }
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("model")]
+ public string? Model { get; set; }
+
+ [JsonPropertyName("role")]
+ public string? Role { get; set; }
+
+ [JsonPropertyName("stop_reason")]
+ public string? StopReason { get; set; }
+
+ [JsonPropertyName("stop_sequence")]
+ public object? StopSequence { get; set; }
+
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("usage")]
+ public Usage? Usage { get; set; }
+
+ [JsonPropertyName("delta")]
+ public Delta? Delta { get; set; }
+
+ [JsonPropertyName("message")]
+ public StreamingMessage? streamingMessage { get; set; }
+}
+
+public class StreamingMessage
+{
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("role")]
+ public string? Role { get; set; }
+
+ [JsonPropertyName("model")]
+ public string? Model { get; set; }
+
+ [JsonPropertyName("stop_reason")]
+ public object? StopReason { get; set; }
+
+ [JsonPropertyName("stop_sequence")]
+ public object? StopSequence { get; set; }
+
+ [JsonPropertyName("usage")]
+ public Usage? Usage { get; set; }
+}
+
+public class Usage
+{
+ [JsonPropertyName("input_tokens")]
+ public int InputTokens { get; set; }
+
+ [JsonPropertyName("output_tokens")]
+ public int OutputTokens { get; set; }
+
+ [JsonPropertyName("cache_creation_input_tokens")]
+ public int CacheCreationInputTokens { get; set; }
+
+ [JsonPropertyName("cache_read_input_tokens")]
+ public int CacheReadInputTokens { get; set; }
+}
+
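+// Incremental payload for streaming events; carries text deltas, partial tool-call JSON, and usage updates.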
+public class Delta
+{
+ [JsonPropertyName("stop_reason")]
+ public string? StopReason { get; set; }
+
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("text")]
+ public string? Text { get; set; }
+
+ [JsonPropertyName("partial_json")]
+ public string? PartialJson { get; set; }
+
+ [JsonPropertyName("usage")]
+ public Usage? Usage { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
new file mode 100644
index 00000000000..ade913b827c
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs
@@ -0,0 +1,95 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Content.cs
+
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
+using AutoGen.Anthropic.Converters;
+
+namespace AutoGen.Anthropic.DTO;
+
+public abstract class ContentBase
+{
+ [JsonPropertyName("type")]
+ public abstract string Type { get; }
+
+ [JsonPropertyName("cache_control")]
+ public CacheControl? CacheControl { get; set; }
+}
+
+public class TextContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "text";
+
+ [JsonPropertyName("text")]
+ public string? Text { get; set; }
+
+ public static TextContent CreateTextWithCacheControl(string text) => new()
+ {
+ Text = text,
+ CacheControl = new CacheControl { Type = CacheControlType.Ephemeral }
+ };
+}
+
+public class ImageContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "image";
+
+ [JsonPropertyName("source")]
+ public ImageSource? Source { get; set; }
+}
+
+public class ImageSource
+{
+ [JsonPropertyName("type")]
+ public string Type => "base64";
+
+ [JsonPropertyName("media_type")]
+ public string? MediaType { get; set; }
+
+ [JsonPropertyName("data")]
+ public string? Data { get; set; }
+}
+
+public class ToolUseContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_use";
+
+ [JsonPropertyName("id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("input")]
+ public JsonNode? Input { get; set; }
+}
+
+public class ToolResultContent : ContentBase
+{
+ [JsonPropertyName("type")]
+ public override string Type => "tool_result";
+
+ [JsonPropertyName("tool_use_id")]
+ public string? Id { get; set; }
+
+ [JsonPropertyName("content")]
+ public string? Content { get; set; }
+}
+
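+// Attach to a content block, tool, or system message to mark it for prompt caching; "ephemeral" is the only type defined here.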
+public class CacheControl
+{
+ [JsonPropertyName("type")]
+ public CacheControlType Type { get; set; }
+
+ public static CacheControl Create() => new CacheControl { Type = CacheControlType.Ephemeral };
+}
+
+[JsonConverter(typeof(JsonPropertyNameEnumConverter<CacheControlType>))]
+public enum CacheControlType
+{
+ [JsonPropertyName("ephemeral")]
+ Ephemeral
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
new file mode 100644
index 00000000000..1a94334c88f
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ErrorResponse.cs
+
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.DTO;
+
+public sealed class ErrorResponse
+{
+ [JsonPropertyName("error")]
+ public Error? Error { get; set; }
+}
+
+public sealed class Error
+{
+ [JsonPropertyName("Type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("message")]
+ public string? Message { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
new file mode 100644
index 00000000000..3845c444592
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/Tool.cs
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Tool.cs
+
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
+
+namespace AutoGen.Anthropic.DTO;
+
+public class Tool
+{
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+
+ [JsonPropertyName("input_schema")]
+ public InputSchema? InputSchema { get; set; }
+
+ [JsonPropertyName("cache_control")]
+ public CacheControl? CacheControl { get; set; }
+}
+
+public class InputSchema
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("properties")]
+ public Dictionary<string, SchemaProperty>? Properties { get; set; }
+
+ [JsonPropertyName("required")]
+ public List<string>? Required { get; set; }
+}
+
+public class SchemaProperty
+{
+ [JsonPropertyName("type")]
+ public string? Type { get; set; }
+
+ [JsonPropertyName("description")]
+ public string? Description { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs b/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs
new file mode 100644
index 00000000000..0a5c3790e1d
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/DTO/ToolChoice.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ToolChoice.cs
+
+using System.Text.Json.Serialization;
+using AutoGen.Anthropic.Converters;
+
+namespace AutoGen.Anthropic.DTO;
+
+[JsonConverter(typeof(JsonPropertyNameEnumConverter<ToolChoiceType>))]
+public enum ToolChoiceType
+{
+ [JsonPropertyName("auto")]
+ Auto, // Default behavior
+
+ [JsonPropertyName("any")]
+ Any, // Use any provided tool
+
+ [JsonPropertyName("tool")]
+ Tool // Force a specific tool
+}
+
+public class ToolChoice
+{
+ [JsonPropertyName("type")]
+ public ToolChoiceType Type { get; set; }
+
+ [JsonPropertyName("name")]
+ public string? Name { get; set; }
+
+ private ToolChoice(ToolChoiceType type, string? name = null)
+ {
+ Type = type;
+ Name = name;
+ }
+
+ public static ToolChoice Auto => new(ToolChoiceType.Auto);
+ public static ToolChoice Any => new(ToolChoiceType.Any);
+ public static ToolChoice ToolUse(string name) => new(ToolChoiceType.Tool, name);
+}
diff --git a/dotnet/src/AutoGen.Anthropic/Extensions/AnthropicAgentExtension.cs b/dotnet/src/AutoGen.Anthropic/Extensions/AnthropicAgentExtension.cs
new file mode 100644
index 00000000000..35ea8ed190a
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Extensions/AnthropicAgentExtension.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicAgentExtension.cs
+
+using AutoGen.Anthropic.Middleware;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic.Extensions;
+
+public static class AnthropicAgentExtension
+{
+ /// <summary>
+ /// Register an <see cref="AnthropicMessageConnector"/> to the <see cref="AnthropicClientAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="AnthropicMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<AnthropicClientAgent> RegisterMessageConnector(
+ this AnthropicClientAgent agent, AnthropicMessageConnector? connector = null)
+ {
+ connector ??= new AnthropicMessageConnector();
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+
+ /// <summary>
+ /// Register an <see cref="AnthropicMessageConnector"/> to the <see cref="MiddlewareStreamingAgent{T}"/> where T is <see cref="AnthropicClientAgent"/>
+ /// </summary>
+ /// <param name="connector">the connector to use. If null, a new instance of <see cref="AnthropicMessageConnector"/> will be created.</param>
+ public static MiddlewareStreamingAgent<AnthropicClientAgent> RegisterMessageConnector(
+ this MiddlewareStreamingAgent<AnthropicClientAgent> agent, AnthropicMessageConnector? connector = null)
+ {
+ connector ??= new AnthropicMessageConnector();
+
+ return agent.RegisterStreamingMiddleware(connector);
+ }
+}
diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
new file mode 100644
index 00000000000..af06a054784
--- /dev/null
+++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs
@@ -0,0 +1,285 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// AnthropicMessageConnector.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http;
+using System.Runtime.CompilerServices;
+using System.Text.Json.Nodes;
+using System.Threading;
+using System.Threading.Tasks;
+using AutoGen.Anthropic.DTO;
+using AutoGen.Core;
+
+namespace AutoGen.Anthropic.Middleware;
+
+public class AnthropicMessageConnector : IStreamingMiddleware
+{
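+ // Middleware that converts AutoGen message types into Anthropic ChatMessage DTOs before the call
+ // and maps the ChatCompletionResponse content back into AutoGen messages afterwards.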
+ public string? Name => nameof(AnthropicMessageConnector);
+
+ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
+ {
+ var messages = context.Messages;
+ var chatMessages = await ProcessMessageAsync(messages, agent);
+ var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);
+
+ return response is IMessage<ChatCompletionResponse> chatMessage
+ ? PostProcessMessage(chatMessage.Content, agent)
+ : response;
+ }
+
+ public async IAsyncEnumerable<IMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var messages = context.Messages;
+ var chatMessages = await ProcessMessageAsync(messages, agent);
+
+ await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
+ {
+ if (reply is IMessage<ChatCompletionResponse> chatMessage)
+ {
+ var response = ProcessChatCompletionResponse(chatMessage, agent);
+ if (response is not null)
+ {
+ yield return response;
+ }
+ }
+ else
+ {
+ yield return reply;
+ }
+ }
+ }
+
+ private IMessage? ProcessChatCompletionResponse(IMessage<ChatCompletionResponse> chatMessage,
+ IStreamingAgent agent)
+ {
+ if (chatMessage.Content.Content is { Count: 1 } &&
+ chatMessage.Content.Content[0] is ToolUseContent toolUseContent)
+ {
+ return new ToolCallMessage(
+ toolUseContent.Name ??
+ throw new InvalidOperationException($"Expected {nameof(toolUseContent.Name)} to be specified"),
+ toolUseContent.Input?.ToString() ??
+ throw new InvalidOperationException($"Expected {nameof(toolUseContent.Input)} to be specified"),
+ from: agent.Name);
+ }
+
+ var delta = chatMessage.Content.Delta;
+ return delta != null && !string.IsNullOrEmpty(delta.Text)
+ ? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name)
+ : null;
+ }
+
+ private async Task<IEnumerable<IMessage>> ProcessMessageAsync(IEnumerable<IMessage> messages, IAgent agent)
+ {
+ var processedMessages = new List<IMessage>();
+
+ foreach (var message in messages)
+ {
+ var processedMessage = message switch
+ {
+ TextMessage textMessage => ProcessTextMessage(textMessage, agent),
+
+ ImageMessage imageMessage =>
+ (MessageEnvelope[])[new MessageEnvelope