Skip to content

Commit

Permalink
Adding File-Like object support in CSV Agent Toolkit (langchain-ai#10409
Browse files Browse the repository at this point in the history
)

If loading a CSV from a direct or temporary source, loading the
file-like object (subclass of IOBase) directly allows the agent creation
process to succeed, instead of throwing a ValueError.

Added an additional elif and tweaked value error message.
Added test to validate this functionality.

Pandas from_csv supports this natively but this current implementation
only accepts strings or paths to files.
https://pandas.pydata.org/docs/user_guide/io.html#io-read-csv-table

---------

Co-authored-by: Harrison Chase <[email protected]>
Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
3 people authored Sep 11, 2023
1 parent 999163f commit 50128c8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
11 changes: 6 additions & 5 deletions libs/langchain/langchain/agents/agent_toolkits/csv/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from io import IOBase
from typing import Any, List, Optional, Union

from langchain.agents.agent import AgentExecutor
Expand All @@ -7,7 +8,7 @@

def create_csv_agent(
llm: BaseLanguageModel,
path: Union[str, List[str]],
path: Union[str, IOBase, List[Union[str, IOBase]]],
pandas_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> AgentExecutor:
Expand All @@ -20,14 +21,14 @@ def create_csv_agent(
)

_kwargs = pandas_kwargs or {}
if isinstance(path, str):
if isinstance(path, (str, IOBase)):
df = pd.read_csv(path, **_kwargs)
elif isinstance(path, list):
df = []
for item in path:
if not isinstance(item, str):
raise ValueError(f"Expected str, got {type(path)}")
if not isinstance(item, (str, IOBase)):
raise ValueError(f"Expected str or file-like object, got {type(path)}")
df.append(pd.read_csv(item, **_kwargs))
else:
raise ValueError(f"Expected str or list, got {type(path)}")
raise ValueError(f"Expected str, list, or file-like object, got {type(path)}")
return create_pandas_dataframe_agent(llm, df, **kwargs)
19 changes: 19 additions & 0 deletions libs/langchain/tests/integration_tests/agent/test_csv_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import re

import numpy as np
Expand Down Expand Up @@ -34,6 +35,15 @@ def csv_list(tmp_path_factory: TempPathFactory) -> DataFrame:
return [filename1, filename2]


@pytest.fixture(scope="module")
def csv_file_like(tmp_path_factory: TempPathFactory) -> io.BytesIO:
random_data = np.random.rand(4, 4)
df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
buffer = io.BytesIO()
df.to_pickle(buffer)
return buffer


def test_csv_agent_creation(csv: str) -> None:
agent = create_csv_agent(OpenAI(temperature=0), csv)
assert isinstance(agent, AgentExecutor)
Expand All @@ -55,3 +65,12 @@ def test_multi_csv(csv_list: list) -> None:
result = re.search(r".*(6).*", response)
assert result is not None
assert result.group(1) is not None


def test_file_like(file_like: io.BytesIO) -> None:
agent = create_csv_agent(OpenAI(temperature=0), file_like, verbose=True)
assert isinstance(agent, AgentExecutor)
response = agent.run("How many rows in the csv? Give me a number.")
result = re.search(r".*(4).*", response)
assert result is not None
assert result.group(1) is not None

0 comments on commit 50128c8

Please sign in to comment.