Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Support creating a dataset from ops #3385

Open
wants to merge 11 commits into
base: andrew/ds-frame
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion docs/docs/guides/core-types/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This guide will show you how to:
- Download the latest version
- Iterate over examples

## Sample code
## Quickstart

<Tabs groupId="programming-language" queryString>
<TabItem value="python" label="Python" default>
Expand Down Expand Up @@ -68,3 +68,71 @@ This guide will show you how to:

</TabItem>
</Tabs>

## Alternate constructors

<Tabs groupId="programming-language" queryString>
<TabItem value="python" label="Python" default>
Datasets can also be constructed from common Weave objects like `Call`s, and from popular Python objects like `pandas.DataFrame`s.
<Tabs groupId="use-case">
<TabItem value="from-calls" label="From Calls">
This can be useful if you want to create a dataset from specific examples.

```python
@weave.op
def model(task: str) -> str:
return f"Now working on {task}"

res1, call1 = model.call(task="fetch")
res2, call2 = model.call(task="parse")

dataset = Dataset.from_calls([call1, call2])
# Now you can use the dataset to evaluate the model, etc.
```
</TabItem>

<TabItem value="from-op" label="From Op">
You can construct a `Dataset` using an `Op`. This will return a dataset containing all calls to that `Op`, which can be useful if you want to evaluate or monitor the `Op` over time.

```python
@weave.op
def model(task: str) -> str:
return f"Now working on {task}"

model(task="fetch")
model(task="parse")

dataset = Dataset.from_op(model) # Contains the two calls, fetch and parse
```
</TabItem>

<TabItem value="from-pandas" label="From Pandas">
You can also freely convert between `Dataset`s and `pandas.DataFrame`s.

```python
import pandas as pd

df = pd.DataFrame([
{'id': '0', 'sentence': "He no likes ice cream.", 'correction': "He doesn't like ice cream."},
{'id': '1', 'sentence': "She goed to the store.", 'correction': "She went to the store."},
{'id': '2', 'sentence': "They plays video games all day.", 'correction': "They play video games all day."}
])
dataset = Dataset.from_pandas(df)
df2 = dataset.to_pandas()

assert df.equals(df2)
```

</TabItem>

</Tabs>

</TabItem>
<TabItem value="typescript" label="TypeScript">

```typescript
This feature is not available in TypeScript yet. Stay tuned!
```

</TabItem>
</Tabs>
36 changes: 36 additions & 0 deletions tests/integrations/pandas-test/test_pandas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd

import weave
from weave import Dataset


def test_op_save_with_global_df(client):
Expand All @@ -20,3 +21,38 @@ def my_op(a: str) -> str:
call = list(my_op.calls())[0]
assert call.inputs == {"a": "d"}
assert call.output == "a"


def test_dataset(client):
    """Round-trip the same records through ``Dataset`` and ``pandas.DataFrame``."""
    records = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]

    # Dataset -> DataFrame: every column keeps its values in row order.
    dataset_from_rows = Dataset(rows=records)
    frame_from_dataset = dataset_from_rows.to_pandas()
    for column, expected_values in (("a", [1, 3, 5]), ("b", [2, 4, 6])):
        assert frame_from_dataset[column].tolist() == expected_values

    # DataFrame -> Dataset: the rows match the original records.
    frame_from_rows = pd.DataFrame(records)
    dataset_from_frame = Dataset.from_pandas(frame_from_rows)
    assert dataset_from_frame.rows == records

    # Both construction paths agree with each other.
    assert frame_from_dataset.equals(frame_from_rows)
    assert dataset_from_rows.rows == dataset_from_frame.rows


def test_calls_to_dataframe(client):
    """Calls of an op can be turned into a dataset and then into a DataFrame."""

    @weave.op
    def greet(name: str, age: int) -> str:
        return f"Hello, {name}! You are {age} years old."

    greet("Alice", 30)
    greet("Bob", 25)

    frame = Dataset.from_calls(greet.calls()).to_pandas()

    expected_inputs = [
        {"name": "Alice", "age": 30},
        {"name": "Bob", "age": 25},
    ]
    expected_outputs = [
        "Hello, Alice! You are 30 years old.",
        "Hello, Bob! You are 25 years old.",
    ]
    assert frame["inputs"].tolist() == expected_inputs
    assert frame["output"].tolist() == expected_outputs
41 changes: 41 additions & 0 deletions tests/trace/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,44 @@ def test_dataset_iteration(client):
# Test that we can iterate multiple times
rows2 = list(dataset)
assert rows2 == rows


def test_dataset_from_calls(client):
    """``Dataset.from_calls`` yields one row per call with its inputs and output."""

    @weave.op
    def greet(name: str, age: int) -> str:
        return f"Hello {name}, you are {age}!"

    greet("Alice", 30)
    greet("Bob", 25)

    rows = list(weave.Dataset.from_calls(client.get_calls()).rows)
    assert len(rows) == 2

    # (name, age, output) expected for each row, in call order.
    expectations = [
        ("Alice", 30, "Hello Alice, you are 30!"),
        ("Bob", 25, "Hello Bob, you are 25!"),
    ]
    for row, (name, age, output) in zip(rows, expectations):
        assert row["inputs"]["name"] == name
        assert row["inputs"]["age"] == age
        assert row["output"] == output


def test_dataset_from_op(client):
    """``Dataset.from_op`` yields one row per call made to the op."""

    @weave.op
    def greet(name: str, age: int) -> str:
        return f"Hello {name}, you are {age}!"

    greet("Alice", 30)
    greet("Bob", 25)

    rows = list(weave.Dataset.from_op(greet).rows)
    assert len(rows) == 2

    # (name, age, output) expected for each row, in call order.
    expectations = [
        ("Alice", 30, "Hello Alice, you are 30!"),
        ("Bob", 25, "Hello Bob, you are 25!"),
    ]
    for row, (name, age, output) in zip(rows, expectations):
        assert row["inputs"]["name"] == name
        assert row["inputs"]["age"] == age
        assert row["output"] == output
33 changes: 31 additions & 2 deletions weave/flow/dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from collections.abc import Iterator
from typing import Any
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING, Any

from pydantic import field_validator
from typing_extensions import Self

import weave
from weave.flow.obj import Object
from weave.trace.op import Op
from weave.trace.vals import WeaveTable
from weave.trace.weave_client import Call

if TYPE_CHECKING:
import pandas as pd


def short_str(obj: Any, limit: int = 25) -> str:
Expand Down Expand Up @@ -42,6 +48,29 @@ class Dataset(Object):

rows: weave.Table

@classmethod
def from_op(cls, op: Op) -> Self:
    """Build a dataset from the calls returned by ``op.calls()``.

    The calls are handed to ``from_calls``, which performs the conversion.
    """
    return cls.from_calls(op.calls())

@classmethod
def from_calls(cls, calls: Iterable[Call]) -> Self:
    """Build a dataset with one row per call.

    Each call is converted to a row via its ``to_dict()`` method.
    """
    return cls(rows=[each_call.to_dict() for each_call in calls])

@classmethod
def from_pandas(cls, df: "pd.DataFrame") -> Self:
    """Build a dataset from a DataFrame, one row per record.

    Rows come from ``df.to_dict(orient="records")``.
    """
    records = df.to_dict(orient="records")
    return cls(rows=records)

def to_pandas(self) -> "pd.DataFrame":
    """Return the dataset's rows as a ``pandas.DataFrame``.

    pandas is imported inside the method, so it is only required when this
    method is actually called.

    Returns:
        The result of ``pd.DataFrame(self.rows)``.

    Raises:
        ImportError: If pandas cannot be imported.
    """
    try:
        import pandas as pd
    except ImportError as e:
        # Chain explicitly so the underlying import failure stays visible
        # in the traceback as the direct cause.
        raise ImportError("pandas is required to use this method") from e

    return pd.DataFrame(self.rows)

@field_validator("rows", mode="before")
def convert_to_table(cls, rows: Any) -> weave.Table:
if not isinstance(rows, weave.Table):
Expand Down
7 changes: 7 additions & 0 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,13 @@ class ApplyScorerSuccess:
wc._send_score_call(self, score_call, scorer_ref_uri)
return apply_scorer_result

def to_dict(self) -> dict:
    """Return this object's public dataclass fields as a plain dict.

    Fields whose names start with an underscore are left out; the values of
    the ``op_name`` and ``display_name`` attributes are then added under
    those keys.
    """
    public_fields = {}
    for field_name, field_value in dataclasses.asdict(self).items():
        if field_name.startswith("_"):
            continue
        public_fields[field_name] = field_value

    public_fields.update(op_name=self.op_name, display_name=self.display_name)
    return public_fields


def make_client_call(
entity: str, project: str, server_call: CallSchema, server: TraceServerInterface
Expand Down
Loading