Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft: Tim/cache remote #3430

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
"rich", # Used for special formatting of tables (should be made optional)
"gql[aiohttp,requests]", # Used exclusively in wandb_api.py
"jsonschema>=4.23.0", # Used by scorers for field validation
"diskcache==5.6.3", # Used for data caching
]

[project.optional-dependencies]
Expand Down
62 changes: 62 additions & 0 deletions scripts/cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import random
import time

import PIL

import weave

weave.init("cache_test")


def make_random_image(width: int, height: int) -> PIL.Image:
image = PIL.Image.new("RGB", (width, height))
image.putdata(
[
(
int(255 * random.random()),
int(255 * random.random()),
int(255 * random.random()),
)
for _ in range(width * height)
]
)
return image


def make_random_dataset(num_rows: int) -> weave.Dataset:
# Create the dataset
rows = [
{
"id": i,
"image_0": make_random_image(1024, 1024),
"truth": i % 2,
}
for i in range(5)
]

return weave.Dataset(rows=rows)


def do_experiment():
ref = weave.publish(make_random_dataset(10), "test_dataset")
uri_str = ref.uri()
ds_ref = weave.ref(uri_str)

for i in range(3):
clock = time.perf_counter()
ds = ds_ref.get()
print(f"Got dataset {i} in {time.perf_counter() - clock} seconds")
clock = time.perf_counter()
images = [r["image_0"] for r in ds.rows]
print(f"Got {len(images)} images in {time.perf_counter() - clock} seconds")


do_experiment()

print("Changing")

import os

os.environ["WEAVE_USE_SERVER_CACHE"] = "false"

do_experiment()
123 changes: 123 additions & 0 deletions tests/trace/test_remote_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import random
import time
from pprint import pprint

import PIL
import pytest

import weave
from weave.trace_server.recording_trace_server import RecordingTraceServer


def make_random_image(width: int, height: int) -> PIL.Image:
image = PIL.Image.new("RGB", (width, height))
image.putdata(
[
(
int(255 * random.random()),
int(255 * random.random()),
int(255 * random.random()),
)
for _ in range(width * height)
]
)
return image


@pytest.mark.asyncio
async def test_dataset_perf():
timings = {}
clock = time.perf_counter()

client = weave.init("test-dataset-perf")
timings["init"] = time.perf_counter() - clock

# Create the dataset
rows = [
{
"id": i,
"image_0": make_random_image(1024, 1024),
"image_1": make_random_image(1024, 1024),
"truth": i % 2,
}
for i in range(5)
]

clock = time.perf_counter()
dataset = weave.Dataset(rows=rows)
timings["dataset_create"] = time.perf_counter() - clock

clock = time.perf_counter()
ref = weave.publish(dataset, "image_comparison_ds")
timings["publish"] = time.perf_counter() - clock

# Next, load the dataset
uri_str = ref.uri()
ds_ref = weave.ref(uri_str)
clock = time.perf_counter()
ds = ds_ref.get()
timings["get"] = time.perf_counter() - clock

# Next, construct the evaluation
class SimpleScorer(weave.Scorer):
@weave.op
def score(self, truth: int, output: int) -> int:
return truth == output

# if isinstance(client.server, RecordingTraceServer):
# client._flush()
# pprint(client.server.get_log())
# pprint(client.server.summarize_logs())
# print("Pre-Eval CREATE; Resetting log")
# client.server.reset_log()

# ds.rows = list(ds.rows)

eval = weave.Evaluation(
dataset=ds,
scorers=[SimpleScorer()],
)

# Next, construct the model
class SimpleModel(weave.Model):
@weave.op()
async def invoke(self, image_0: PIL.Image, image_1: PIL.Image) -> dict:
# download images...
await self.play_matching_game(image_a=image_0, image_b=image_1)

@weave.op()
async def play_matching_game(
self, image_a: PIL.Image, image_b: PIL.Image
) -> tuple[bool, str | None, str | None]:
model_a_num_similar = await self.get_similar_images(image_a)

async def get_similar_images(self, image: PIL.Image):
set_images = split_thumbail(thumbnail_image=image)

def split_thumbail(thumbnail_image: PIL.Image) -> list[PIL.Image]:
image = thumbnail_image.crop((10, 10, 10, 10))
# @weave.op
# async def invoke(self, image_0: PIL.Image, image_1: PIL.Image) -> int:
# image_0.crop((0, 0, 10, 10))
# image_1.crop((0, 0, 10, 10))
# return 1 if random.random() > 0.5 else 0

# if isinstance(client.server, RecordingTraceServer):
# client._flush()
# pprint(client.server.get_log())
# pprint(client.server.summarize_logs())
# print("Pre-Eval RUN; Resetting log")
# client.server.reset_log()

# Next run the eval
clock = time.perf_counter()
res = await eval.evaluate(model=SimpleModel())
timings["eval"] = time.perf_counter() - clock
pprint(res)

if isinstance(client.server, RecordingTraceServer):
# pprint(client.server.get_log())
pprint(client.server.summarize_logs())

print(timings)
assert False
8 changes: 7 additions & 1 deletion weave/trace/weave_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from weave.trace import autopatch, errors, init_message, trace_sentry, weave_client
from weave.trace.context import weave_client_context as weave_client_context
from weave.trace_server import sqlite_trace_server
from weave.trace_server.recording_trace_server import RecordingTraceServer
from weave.trace_server_bindings import remote_http_trace_server
from weave.trace_server_bindings.caching_middleware_trace_server import (
CachingMiddlewareTraceServer,
)


class InitializedClient:
Expand Down Expand Up @@ -102,11 +106,13 @@ def init_weave(
api_key = wandb_context.api_key

remote_server = init_weave_get_server(api_key)
recording_server = RecordingTraceServer(remote_server)
caching_server = CachingMiddlewareTraceServer(recording_server)
# from weave.trace_server.clickhouse_trace_server_batched import ClickHouseTraceServer

# server = ClickHouseTraceServer(host="localhost")
client = weave_client.WeaveClient(
entity_name, project_name, remote_server, ensure_project_exists
entity_name, project_name, caching_server, ensure_project_exists
)
# If the project name was formatted by init, update the project name
project_name = client.project
Expand Down
103 changes: 103 additions & 0 deletions weave/trace_server/recording_trace_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import threading
import time
from collections import defaultdict
from datetime import datetime
from functools import wraps
from typing import Optional, TypedDict

from weave.trace_server.trace_server_interface import TraceServerInterface


class LogRecord(TypedDict):
timestamp: datetime
name: str
duration: float
error: Optional[str]


class RecordingTraceServer(TraceServerInterface):
_next_ts: TraceServerInterface
_log: list[LogRecord]

def __init__(self, next_ts: TraceServerInterface):
self._next_ts = next_ts
self._log: list[LogRecord] = []
self._log_lock = threading.Lock()

def __getattribute__(self, name):
protected_names = [
"_next_ts",
"_log",
"_log_lock",
"_thread_safe_log",
"get_log",
"summarize_logs",
"reset_log",
]
if name in protected_names:
return super().__getattribute__(name)
attr = self._next_ts.__getattribute__(name)

if name.startswith("_") or not callable(attr):
return attr

@wraps(attr)
def wrapper(*args, **kwargs):
now = datetime.now()
start = time.perf_counter()
try:
if name == "file_create":
print(args[0].name)
res = attr(*args, **kwargs)
end = time.perf_counter()
self._thread_safe_log(
LogRecord(timestamp=now, name=name, duration=end - start)
)
except Exception as e:
end = time.perf_counter()
self._thread_safe_log(
LogRecord(
timestamp=now, name=name, duration=end - start, error=str(e)
)
)
raise e
return res

return wrapper

def _thread_safe_log(self, log: LogRecord):
with self._log_lock:
self._log.append(log)

def get_log(self) -> list[LogRecord]:
# if isinstance(self._next_ts, RecordingTraceServer):
# next_log = self._next_ts.get_log()
# else:
# next_log = {}
return self._log
# return {
# "name": self._name,
# "log": self._log,
# # "next": next_log,
# }

def reset_log(self):
with self._log_lock:
self._log = []

def summarize_logs(self) -> dict:
log_groups = defaultdict(list)
for log in self._log:
log_groups[log["name"]].append(log)
groups = {}
for name, logs in log_groups.items():
total_duration = sum(log["duration"] for log in logs)
count = len(logs)
error_count = sum(1 for log in logs if log.get("error") is not None)
groups[name] = {
"total_duration": total_duration,
"count": count,
"average_duration": total_duration / count,
"error_count": error_count,
}
return groups
Loading
Loading