Skip to content

Commit

Permalink
[agent][property type] Change allowed_tools to Set as Duplicate doesn…
Browse files Browse the repository at this point in the history
…’t make sense (langchain-ai#3840)

- ActionAgent has a property called, `allowed_tools`, which is declared
as `List`. It stores all provided tools which is available to use during
agent action.
- This collection shouldn’t allow duplicates. The original datatype List
doesn’t make sense. Each tool should be unique. Even when there are
variants (assuming in the future), it would be named differently in
load_tools.


Test:
- confirm the functionality in an example by initializing an agent with
a list of 2 tools and confirm everything works.
```python3
def test_agent_chain_chat_bot():
	from langchain.agents import load_tools
	from langchain.agents import initialize_agent
	from langchain.agents import AgentType
	from langchain.chat_models import ChatOpenAI
	from langchain.llms import OpenAI
	from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper

	chat = ChatOpenAI(temperature=0)
	llm = OpenAI(temperature=0)
	tools = load_tools(["ddg-search", "llm-math"], llm=llm)

	agent = initialize_agent(tools, chat, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
	agent.run("Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?")
test_agent_chain_chat_bot()
```
Result:
<img width="863" alt="Screenshot 2023-05-01 at 7 58 11 PM"
src="https://user-images.githubusercontent.com/62768671/235572157-0937594c-ddfb-4760-acb2-aea4cacacd89.png">
  • Loading branch information
skcoirz authored May 2, 2023
1 parent c5cc09d commit ec21b71
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
25 changes: 12 additions & 13 deletions langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

import yaml
from pydantic import BaseModel, root_validator
Expand Down Expand Up @@ -46,8 +46,8 @@ def return_values(self) -> List[str]:
"""Return values of the agent."""
return ["output"]

def get_allowed_tools(self) -> Optional[List[str]]:
return None
def get_allowed_tools(self) -> Set[str]:
return set()

@abstractmethod
def plan(
Expand Down Expand Up @@ -178,8 +178,8 @@ def return_values(self) -> List[str]:
"""Return values of the agent."""
return ["output"]

def get_allowed_tools(self) -> Optional[List[str]]:
return None
def get_allowed_tools(self) -> Set[str]:
return set()

@abstractmethod
def plan(
Expand Down Expand Up @@ -372,9 +372,9 @@ class Agent(BaseSingleActionAgent):

llm_chain: LLMChain
output_parser: AgentOutputParser
allowed_tools: Optional[List[str]] = None
allowed_tools: Set[str] = set()

def get_allowed_tools(self) -> Optional[List[str]]:
def get_allowed_tools(self) -> Set[str]:
return self.allowed_tools

@property
Expand Down Expand Up @@ -607,12 +607,11 @@ def validate_tools(cls, values: Dict) -> Dict:
agent = values["agent"]
tools = values["tools"]
allowed_tools = agent.get_allowed_tools()
if allowed_tools is not None:
if set(allowed_tools) != set([tool.name for tool in tools]):
raise ValueError(
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
if allowed_tools != set([tool.name for tool in tools]):
raise ValueError(
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
return values

@root_validator()
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All integration tests for agent."""
16 changes: 16 additions & 0 deletions tests/integration_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from langchain.agents.chat.base import ChatAgent
from langchain.llms.openai import OpenAI
from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun


class TestAgent:
def test_agent_generation(self) -> None:
web_search = DuckDuckGoSearchRun()
tools = [web_search]
agent = ChatAgent.from_llm_and_tools(
ai_name="Tom",
ai_role="Assistant",
tools=tools,
llm=OpenAI(maxTokens=10),
)
assert agent.allowed_tools == set([web_search.name])

0 comments on commit ec21b71

Please sign in to comment.