From c41f4c5785da4009a3d5f9c186160ba2f3ea382b Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 16 Jan 2025 12:06:10 -0800 Subject: [PATCH 1/5] init --- weave/trace/weave_init.py | 4 +- weave/trace_server/recording_trace_server.py | 90 ++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 weave/trace_server/recording_trace_server.py diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index f51d42d5018d..afdf24f7e8bf 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -3,6 +3,7 @@ from weave.trace import autopatch, errors, init_message, trace_sentry, weave_client from weave.trace.context import weave_client_context as weave_client_context from weave.trace_server import sqlite_trace_server +from weave.trace_server.recording_trace_server import RecordingTraceServer from weave.trace_server_bindings import remote_http_trace_server @@ -102,11 +103,12 @@ def init_weave( api_key = wandb_context.api_key remote_server = init_weave_get_server(api_key) + recording_server = RecordingTraceServer(remote_server) # from weave.trace_server.clickhouse_trace_server_batched import ClickHouseTraceServer # server = ClickHouseTraceServer(host="localhost") client = weave_client.WeaveClient( - entity_name, project_name, remote_server, ensure_project_exists + entity_name, project_name, recording_server, ensure_project_exists ) # If the project name was formatted by init, update the project name project_name = client.project diff --git a/weave/trace_server/recording_trace_server.py b/weave/trace_server/recording_trace_server.py new file mode 100644 index 000000000000..274f64cb368f --- /dev/null +++ b/weave/trace_server/recording_trace_server.py @@ -0,0 +1,90 @@ + + +import threading +import time +from collections import defaultdict +from datetime import datetime +from functools import wraps +from typing import Optional, TypedDict + +from weave.trace_server.trace_server_interface import TraceServerInterface + + +class LogRecord(TypedDict): + timestamp: datetime + name: str + duration: float + error: Optional[str] + +class RecordingTraceServer(TraceServerInterface): + _next_ts: TraceServerInterface + _log: list[LogRecord] + + def __init__(self, next_ts: TraceServerInterface): + self._next_ts = next_ts + self._log: list[LogRecord] = [] + self._log_lock = threading.Lock() + + def __getattribute__(self, name): + protected_names = ["_next_ts", "_log", "_log_lock", "_thread_safe_log", "get_log", "summarize_logs", "reset_log"] + if name in protected_names: + return super().__getattribute__(name) + attr = self._next_ts.__getattribute__(name) + + if name.startswith("_") or not callable(attr): + return attr + + @wraps(attr) + def wrapper(*args, **kwargs): + now = datetime.now() + start = time.perf_counter() + try: + if name == "file_create": + print(args[0].name) + res = attr(*args, **kwargs) + end = time.perf_counter() + self._thread_safe_log(LogRecord(timestamp=now, name=name, duration=end - start)) + except Exception as e: + end = time.perf_counter() + self._thread_safe_log(LogRecord(timestamp=now, name=name, duration=end - start, error=str(e))) + raise e + return res + + return wrapper + + def _thread_safe_log(self, log: LogRecord): + with self._log_lock: + self._log.append(log) + + def get_log(self) -> list[LogRecord]: + # if isinstance(self._next_ts, RecordingTraceServer): + # next_log = self._next_ts.get_log() + # else: + # next_log = {} + return self._log + # return { + # "name": self._name, + # "log": self._log, + # # "next": next_log, + # } + 
+ def reset_log(self): + with self._log_lock: + self._log = [] + + def summarize_logs(self) -> dict: + log_groups = defaultdict(list) + for log in self._log: + log_groups[log["name"]].append(log) + groups = {} + for name, logs in log_groups.items(): + total_duration = sum(log["duration"] for log in logs) + count = len(logs) + error_count = sum(1 for log in logs if log["error"] is not None) + groups[name] = { + "total_duration": total_duration, + "count": count, + "average_duration": total_duration / count, + "error_count": error_count, + } + return groups From af0c62e6c8a9961b34fd03e666cdbfdde2cb7edd Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 16 Jan 2025 12:06:23 -0800 Subject: [PATCH 2/5] init --- weave/trace_server/recording_trace_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave/trace_server/recording_trace_server.py b/weave/trace_server/recording_trace_server.py index 274f64cb368f..ce27592e8e35 100644 --- a/weave/trace_server/recording_trace_server.py +++ b/weave/trace_server/recording_trace_server.py @@ -80,7 +80,7 @@ def summarize_logs(self) -> dict: for name, logs in log_groups.items(): total_duration = sum(log["duration"] for log in logs) count = len(logs) - error_count = sum(1 for log in logs if log["error"] is not None) + error_count = sum(1 for log in logs if log.get("error") is not None) groups[name] = { "total_duration": total_duration, "count": count, From b0d988fb3de575d20455c2c5821bed2339f8ef10 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 16 Jan 2025 13:37:30 -0800 Subject: [PATCH 3/5] adding some testing --- tests/trace/test_remote_caching.py | 117 +++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 tests/trace/test_remote_caching.py diff --git a/tests/trace/test_remote_caching.py b/tests/trace/test_remote_caching.py new file mode 100644 index 000000000000..9ea14ddcb9a9 --- /dev/null +++ b/tests/trace/test_remote_caching.py @@ -0,0 +1,117 @@ +import random +import time +from pprint import pprint + +import PIL +import pytest + +import weave +from weave.trace_server.recording_trace_server import RecordingTraceServer + + +def make_random_image(width: int, height: int) -> PIL.Image: + image = PIL.Image.new("RGB", (width, height)) + image.putdata([(int(255*random.random()), int(255*random.random()), int(255*random.random())) for _ in range(width*height)]) + return image + + +@pytest.mark.asyncio +async def test_dataset_perf(): + timings = {} + clock = time.perf_counter() + + client = weave.init("test-dataset-perf") + timings["init"] = time.perf_counter() - clock + + # Create the dataset + rows = [{ + "id": i, + "image_0": make_random_image(1024, 1024), + "image_1": make_random_image(1024, 1024), + "truth": i % 2, + } for i in range(5)] + + clock = time.perf_counter() + dataset = weave.Dataset(rows=rows) + timings["dataset_create"] = time.perf_counter() - clock + + clock = time.perf_counter() + ref = weave.publish(dataset, "image_comparison_ds") + timings["publish"] = time.perf_counter() - clock + + # Next, load the dataset + uri_str = ref.uri() + ds_ref = weave.ref(uri_str) + clock = time.perf_counter() + ds = ds_ref.get() + timings["get"] = time.perf_counter() - clock + + # Next, construct the evaluation + class SimpleScorer(weave.Scorer): + @weave.op + def score(self, truth: int, output: int) -> int: + return truth == output + + # if isinstance(client.server, RecordingTraceServer): + # client._flush() + # pprint(client.server.get_log()) + # pprint(client.server.summarize_logs()) + # 
print("Pre-Eval CREATE; Resetting log") + # client.server.reset_log() + + # ds.rows = list(ds.rows) + + eval = weave.Evaluation( + dataset=ds, + scorers=[SimpleScorer()], + ) + + # Next, construct the model + class SimpleModel(weave.Model): + @weave.op() + async def invoke(self, image_0: PIL.Image, image_1: PIL.Image) -> dict: + # download images... + await self.play_matching_game( + image_a=image_0, image_b=image_1 + ) + + + @weave.op() + async def play_matching_game( + self, image_a: PIL.Image, image_b: PIL.Image + ) -> tuple[bool, str | None, str | None]: + model_a_num_similar = await self.get_similar_images(image_a) + + + async def get_similar_images(self, image: PIL.Image): + set_images = split_thumbail(thumbnail_image=image) + + + def split_thumbail(thumbnail_image: PIL.Image) -> list[PIL.Image]: + image = thumbnail_image.crop((10, 10, 10, 10)) + # @weave.op + # async def invoke(self, image_0: PIL.Image, image_1: PIL.Image) -> int: + # image_0.crop((0, 0, 10, 10)) + # image_1.crop((0, 0, 10, 10)) + # return 1 if random.random() > 0.5 else 0 + + + # if isinstance(client.server, RecordingTraceServer): + # client._flush() + # pprint(client.server.get_log()) + # pprint(client.server.summarize_logs()) + # print("Pre-Eval RUN; Resetting log") + # client.server.reset_log() + + # Next run the eval + clock = time.perf_counter() + res = await eval.evaluate(model=SimpleModel()) + timings["eval"] = time.perf_counter() - clock + pprint(res) + + if isinstance(client.server, RecordingTraceServer): + # pprint(client.server.get_log()) + pprint(client.server.summarize_logs()) + + print(timings) + assert False From e554ec493f7d73dc8b352f5dd363adfd85596ec5 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 16 Jan 2025 15:17:15 -0800 Subject: [PATCH 4/5] added caching layer --- pyproject.toml | 1 + tests/trace/test_remote_caching.py | 34 +-- weave/trace/weave_init.py | 6 +- weave/trace_server/recording_trace_server.py | 25 +- .../caching_middleware_trace_server.py | 222 ++++++++++++++++++ 5 files changed, 267 insertions(+), 21 deletions(-) create mode 100644 weave/trace_server_bindings/caching_middleware_trace_server.py diff --git a/pyproject.toml b/pyproject.toml index 78171a820193..adaa823c0a28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "rich", # Used for special formatting of tables (should be made optional) "gql[aiohttp,requests]", # Used exclusively in wandb_api.py "jsonschema>=4.23.0", # Used by scorers for field validation + "diskcache==5.6.3", # Used for data caching ] [project.optional-dependencies] diff --git a/tests/trace/test_remote_caching.py b/tests/trace/test_remote_caching.py index 9ea14ddcb9a9..775668c9ed0a 100644 --- a/tests/trace/test_remote_caching.py +++ b/tests/trace/test_remote_caching.py @@ -11,7 +11,16 @@ def make_random_image(width: int, height: int) -> PIL.Image: image = PIL.Image.new("RGB", (width, height)) - image.putdata([(int(255*random.random()), int(255*random.random()), int(255*random.random())) for _ in range(width*height)]) + image.putdata( + [ + ( + int(255 * random.random()), + int(255 * random.random()), + int(255 * random.random()), + ) + for _ in range(width * height) + ] + ) return image @@ -24,12 +33,15 @@ async def test_dataset_perf(): timings["init"] = time.perf_counter() - clock # Create the dataset - rows = [{ - "id": i, - "image_0": make_random_image(1024, 1024), - "image_1": make_random_image(1024, 1024), - "truth": i % 2, - } for i in range(5)] + rows = [ + { + "id": i, + "image_0": make_random_image(1024, 
1024), + "image_1": make_random_image(1024, 1024), + "truth": i % 2, + } + for i in range(5) + ] clock = time.perf_counter() dataset = weave.Dataset(rows=rows) @@ -71,10 +83,7 @@ class SimpleModel(weave.Model): @weave.op() async def invoke(self, image_0: PIL.Image, image_1: PIL.Image) -> dict: # download images... - await self.play_matching_game( - image_a=image_0, image_b=image_1 - ) - + await self.play_matching_game(image_a=image_0, image_b=image_1) @weave.op() async def play_matching_game( @@ -82,11 +91,9 @@ async def play_matching_game( ) -> tuple[bool, str | None, str | None]: model_a_num_similar = await self.get_similar_images(image_a) - async def get_similar_images(self, image: PIL.Image): set_images = split_thumbail(thumbnail_image=image) - def split_thumbail(thumbnail_image: PIL.Image) -> list[PIL.Image]: image = thumbnail_image.crop((10, 10, 10, 10)) # @weave.op @@ -95,7 +102,6 @@ def split_thumbail(thumbnail_image: PIL.Image) -> list[PIL.Image]: # image_1.crop((0, 0, 10, 10)) # return 1 if random.random() > 0.5 else 0 - # if isinstance(client.server, RecordingTraceServer): # client._flush() # pprint(client.server.get_log()) diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index afdf24f7e8bf..93898243100b 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -5,6 +5,9 @@ from weave.trace_server import sqlite_trace_server from weave.trace_server.recording_trace_server import RecordingTraceServer from weave.trace_server_bindings import remote_http_trace_server +from weave.trace_server_bindings.caching_middleware_trace_server import ( + CachingMiddlewareTraceServer, +) class InitializedClient: @@ -104,11 +107,12 @@ def init_weave( remote_server = init_weave_get_server(api_key) recording_server = RecordingTraceServer(remote_server) + caching_server = CachingMiddlewareTraceServer(recording_server) # from weave.trace_server.clickhouse_trace_server_batched import ClickHouseTraceServer # server = ClickHouseTraceServer(host="localhost") client = weave_client.WeaveClient( - entity_name, project_name, recording_server, ensure_project_exists + entity_name, project_name, caching_server, ensure_project_exists ) # If the project name was formatted by init, update the project name project_name = client.project diff --git a/weave/trace_server/recording_trace_server.py b/weave/trace_server/recording_trace_server.py index ce27592e8e35..829bb05f9f10 100644 --- a/weave/trace_server/recording_trace_server.py +++ b/weave/trace_server/recording_trace_server.py @@ -1,5 +1,3 @@ - - import threading import time from collections import defaultdict @@ -16,6 +14,7 @@ class LogRecord(TypedDict): duration: float error: Optional[str] + class RecordingTraceServer(TraceServerInterface): _next_ts: TraceServerInterface _log: list[LogRecord] @@ -26,7 +25,15 @@ def __init__(self, next_ts: TraceServerInterface): self._log_lock = threading.Lock() def __getattribute__(self, name): - protected_names = ["_next_ts", "_log", "_log_lock", "_thread_safe_log", "get_log", "summarize_logs", "reset_log"] + protected_names = [ + "_next_ts", + "_log", + "_log_lock", + "_thread_safe_log", + "get_log", + "summarize_logs", + "reset_log", + ] if name in protected_names: return super().__getattribute__(name) attr = self._next_ts.__getattribute__(name) @@ -43,10 +50,16 @@ def wrapper(*args, **kwargs): print(args[0].name) res = attr(*args, **kwargs) end = time.perf_counter() - self._thread_safe_log(LogRecord(timestamp=now, name=name, duration=end - start)) + self._thread_safe_log( + 
LogRecord(timestamp=now, name=name, duration=end - start) + ) except Exception as e: end = time.perf_counter() - self._thread_safe_log(LogRecord(timestamp=now, name=name, duration=end - start, error=str(e))) + self._thread_safe_log( + LogRecord( + timestamp=now, name=name, duration=end - start, error=str(e) + ) + ) raise e return res @@ -80,7 +93,7 @@ def summarize_logs(self) -> dict: for name, logs in log_groups.items(): total_duration = sum(log["duration"] for log in logs) count = len(logs) - error_count = sum(1 for log in logs if log.get("error") is not None) + error_count = sum(1 for log in logs if log.get("error") is not None) groups[name] = { "total_duration": total_duration, "count": count, diff --git a/weave/trace_server_bindings/caching_middleware_trace_server.py b/weave/trace_server_bindings/caching_middleware_trace_server.py new file mode 100644 index 000000000000..c93067efa4f4 --- /dev/null +++ b/weave/trace_server_bindings/caching_middleware_trace_server.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import logging +from collections.abc import Iterator +from pathlib import Path +from typing import Any, Callable + +import diskcache + +from weave.trace_server import trace_server_interface as tsi + +logger = logging.getLogger(__name__) + + +class CachingMiddlewareTraceServer(tsi.TraceServerInterface): + _next_trace_server: tsi.TraceServerInterface + _cache_prefix: str + + def __init__( + self, + next_trace_server: tsi.TraceServerInterface, + cache_dir: Path | None = None, # todo make this configurable + size_limit: int = 1_000_000_000, # 1GB - todo make this configurable + ): + self._next_trace_server = next_trace_server + + self._cache = diskcache.Cache(cache_dir, size_limit=size_limit) + + def _with_cache( + self, + namespace: str, + make_cache_key: Callable[[Any], str], + func: Callable[[Any], Any], + req: Any, + serialize: Callable[[Any], str], + deserialize: Callable[[Any], Any], + ) -> Any: + try: + cache_key = f"{namespace}_{make_cache_key(req)}" + except Exception as e: + logger.exception(f"Error creating cache key: {e}") + return func(req) + try: + cached_json_value = self._cache.get(cache_key) + if cached_json_value: + return deserialize(cached_json_value) + except Exception as e: + logger.exception(f"Error validating cached value: {e}") + self._cache.delete(cache_key) + res = func(req) + try: + json_value_to_cache = serialize(res) + self._cache.set(cache_key, json_value_to_cache) + except Exception as e: + logger.exception(f"Error caching value: {e}") + return res + + def _with_cache_generic(self, func, req, res_type: Type[tsi.BaseModel]): + return self._with_cache( + func.__name__, + lambda req: req.model_dump_json(), + func, + req, + lambda res: res.model_dump_json(), + lambda json_value: res_type.model_validate_json(json_value), + ) + + # Cacheable Methods: + def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: + return self._with_cache_generic( + self._next_trace_server.obj_read, req, tsi.ObjReadRes + ) + + def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: + return self._with_cache_generic( + self._next_trace_server.table_query, req, tsi.TableQueryRes + ) + + def table_query_stream( + self, req: tsi.TableQueryReq + ) -> Iterator[tsi.TableRowSchema]: + # I am not sure the best way to cache the iterator here. 
TODO
+        return self._next_trace_server.table_query_stream(req)
+
+    def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes:
+        return self._with_cache_generic(
+            self._next_trace_server.table_query_stats, req, tsi.TableQueryStatsRes
+        )
+
+    def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes:
+        # This is a special case because we want to cache individual refs and only
+        # query for the ones that are not in the cache.
+
+        # 1. Find the refs that are not in the cache
+        # 2. Query for the refs that are not in the cache
+        # 3. Cache the refs that are not in the cache
+        # 4. Return the re-composed response.
+
+        final_results = [None] * len(req.refs)
+        needed_refs: list[str] = []
+        needed_indices: list[int] = []
+
+        for i, ref in enumerate(req.refs):
+            existing_result = None
+            try:
+                existing_result = self._cache.get(ref)
+            except Exception as e:
+                logger.exception(f"Error getting cached value: {e}")
+            if existing_result:
+                final_results[i] = existing_result
+            else:
+                needed_refs.append(ref)
+                needed_indices.append(i)
+
+        if needed_refs:
+            new_req = tsi.RefsReadBatchReq(refs=needed_refs)
+            needed_results = self._next_trace_server.refs_read_batch(new_req)
+            for i, needed_ref, val in zip(needed_indices, needed_refs, needed_results.vals):
+                final_results[i] = val
+                try:
+                    self._cache.set(needed_ref, val)
+                except Exception as e:
+                    logger.exception(f"Error caching values: {e}")
+
+        return tsi.RefsReadBatchRes(vals=final_results)
+
+    def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
+        return self._with_cache_generic(
+            self._next_trace_server.file_content_read, req, tsi.FileContentReadRes
+        )
+
+    # Remaining Un-cacheable Methods:
+
+    # Call API
+    def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
+        return self._next_trace_server.call_start(req)
+
+    def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes:
+        return self._next_trace_server.call_end(req)
+
+    def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes:
+        return self._next_trace_server.call_read(req)
+
+    def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes:
+        return self._next_trace_server.calls_query(req)
+
+    def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]:
+        return self._next_trace_server.calls_query_stream(req)
+
+    def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
+        return self._next_trace_server.calls_delete(req)
+
+    def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes:
+        return self._next_trace_server.calls_query_stats(req)
+
+    def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
+        return self._next_trace_server.call_update(req)
+
+    # Op API
+    def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes:
+        return self._next_trace_server.op_create(req)
+
+    def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes:
+        return self._next_trace_server.op_read(req)
+
+    def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes:
+        return self._next_trace_server.ops_query(req)
+
+    # Cost API
+    def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
+        return self._next_trace_server.cost_create(req)
+
+    def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes:
+        return self._next_trace_server.cost_query(req)
+
+    def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes:
+        return self._next_trace_server.cost_purge(req)
+
+    # Obj API
+    def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
+        return self._next_trace_server.obj_create(req)
+
+    def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:
+        return self._next_trace_server.objs_query(req)
+
+    def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes:
+        return self._next_trace_server.obj_delete(req)
+
+    # Table API
+    def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes:
+        return self._next_trace_server.table_create(req)
+
+    def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes:
+        return self._next_trace_server.table_update(req)
+
+    # File API
+    def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
+        return self._next_trace_server.file_create(req)
+
+    def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
+        return self._next_trace_server.feedback_create(req)
+
+    def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes:
+        return self._next_trace_server.feedback_query(req)
+
+    def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes:
+        return self._next_trace_server.feedback_purge(req)
+
+    def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes:
+        return self._next_trace_server.feedback_replace(req)
+
+    # Action API
+    def actions_execute_batch(
+        self, req: tsi.ActionsExecuteBatchReq
+    ) -> tsi.ActionsExecuteBatchRes:
+        return self._next_trace_server.actions_execute_batch(req)
+
+    # Execute LLM API
+    def completions_create(
+        self, req: tsi.CompletionsCreateReq
+    ) -> tsi.CompletionsCreateRes:
+        return self._next_trace_server.completions_create(req)

From 9b86079816f7ae0dd368862c767b249e1987aac3 Mon Sep 17 00:00:00 2001
From: Tim Sweeney
Date: Thu, 16 Jan 2025 15:29:48 -0800
Subject: [PATCH 5/5] added caching layer 2

---
 scripts/cache_test.py                         | 62 +++++++++++++++++++
 .../caching_middleware_trace_server.py        | 44 ++++++++++---
 2 files changed, 98 insertions(+), 8 deletions(-)
 create mode 100644 scripts/cache_test.py

diff --git a/scripts/cache_test.py b/scripts/cache_test.py
new file mode 100644
index 000000000000..c650cdcf8cd3
--- /dev/null
+++ b/scripts/cache_test.py
@@ -0,0 +1,62 @@
+import random
+import time
+
+import PIL
+
+import weave
+
+weave.init("cache_test")
+
+
+def make_random_image(width: int, height: int) -> PIL.Image:
+    image = PIL.Image.new("RGB", (width, height))
+    image.putdata(
+        [
+            (
+                int(255 * random.random()),
+                int(255 * random.random()),
+                int(255 * random.random()),
+            )
+            for _ in range(width * height)
+        ]
+    )
+    return image
+
+
+def make_random_dataset(num_rows: int) -> weave.Dataset:
+    # Create the dataset
+    rows = [
+        {
+            "id": i,
+            "image_0": make_random_image(1024, 1024),
+            "truth": i % 2,
+        }
+        for i in range(num_rows)
+    ]
+
+    return weave.Dataset(rows=rows)
+
+
+def do_experiment():
+    ref = weave.publish(make_random_dataset(10), "test_dataset")
+    uri_str = ref.uri()
+    ds_ref = weave.ref(uri_str)
+
+    for i in range(3):
+        clock = time.perf_counter()
+        ds = ds_ref.get()
+        print(f"Got dataset {i} in {time.perf_counter() - clock} seconds")
+        clock = time.perf_counter()
+        images = [r["image_0"] for r in ds.rows]
+        print(f"Got {len(images)} images in {time.perf_counter() - clock} seconds")
+
+
+do_experiment()
+
+print("Changing")
+
+import os
+
+os.environ["WEAVE_USE_SERVER_CACHE"] = "false"
+
+do_experiment()
diff --git a/weave/trace_server_bindings/caching_middleware_trace_server.py b/weave/trace_server_bindings/caching_middleware_trace_server.py
index c93067efa4f4..31d18c46cbee 100644
--- a/weave/trace_server_bindings/caching_middleware_trace_server.py
+++ 
b/weave/trace_server_bindings/caching_middleware_trace_server.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import os
 from collections.abc import Iterator
 from pathlib import Path
 from typing import Any, Callable
@@ -26,6 +27,28 @@ def __init__(
 
         self._cache = diskcache.Cache(cache_dir, size_limit=size_limit)
 
+    def _safe_cache_get(self, key: str) -> Any:
+        try:
+            use_cache = os.getenv("WEAVE_USE_SERVER_CACHE", "true").lower() == "true"
+            if not use_cache:
+                return None
+            return self._cache.get(key)
+        except Exception as e:
+            logger.exception(f"Error getting cached value: {e}")
+            return None
+
+    def _safe_cache_set(self, key: str, value: Any) -> None:
+        try:
+            return self._cache.set(key, value)
+        except Exception as e:
+            logger.exception(f"Error caching value: {e}")
+
+    def _safe_cache_delete(self, key: str) -> None:
+        try:
+            self._cache.delete(key)
+        except Exception as e:
+            logger.exception(f"Error deleting cached value: {e}")
+
     def _with_cache(
         self,
         namespace: str,
@@ -41,21 +64,21 @@ def _with_cache(
             logger.exception(f"Error creating cache key: {e}")
             return func(req)
         try:
-            cached_json_value = self._cache.get(cache_key)
+            cached_json_value = self._safe_cache_get(cache_key)
             if cached_json_value:
                 return deserialize(cached_json_value)
         except Exception as e:
             logger.exception(f"Error validating cached value: {e}")
-            self._cache.delete(cache_key)
+            self._safe_cache_delete(cache_key)
         res = func(req)
         try:
             json_value_to_cache = serialize(res)
-            self._cache.set(cache_key, json_value_to_cache)
+            self._safe_cache_set(cache_key, json_value_to_cache)
         except Exception as e:
             logger.exception(f"Error caching value: {e}")
         return res
 
-    def _with_cache_generic(self, func, req, res_type: Type[tsi.BaseModel]):
+    def _with_cache_generic(self, func, req, res_type: type[tsi.BaseModel]):
         return self._with_cache(
             func.__name__,
             lambda req: req.model_dump_json(),
@@ -103,7 +126,7 @@ def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes:
         for i, ref in enumerate(req.refs):
             existing_result = None
             try:
-                existing_result = self._cache.get(ref)
+                existing_result = self._safe_cache_get(ref)
             except Exception as e:
                 logger.exception(f"Error getting cached value: {e}")
             if existing_result:
@@ -118,15 +141,20 @@ def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes:
             for i, needed_ref, val in zip(needed_indices, needed_refs, needed_results.vals):
                 final_results[i] = val
                 try:
-                    self._cache.set(needed_ref, val)
+                    self._safe_cache_set(needed_ref, val)
                 except Exception as e:
                     logger.exception(f"Error caching values: {e}")
 
         return tsi.RefsReadBatchRes(vals=final_results)
 
     def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
-        return self._with_cache_generic(
-            self._next_trace_server.file_content_read, req, tsi.FileContentReadRes
+        return self._with_cache(
+            "file_content_read",
+            lambda req: req.model_dump_json(),
+            self._next_trace_server.file_content_read,
+            req,
+            lambda res: res.content,
+            lambda content: tsi.FileContentReadRes(content=content),
         )
 
     # Remaining Un-cacheable Methods:
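
A note on the table_query_stream TODO in caching_middleware_trace_server.py:
one option is to materialize the stream and cache the full row list under a
key derived from the request. The sketch below is a possible drop-in body for
that method, not implemented behavior; it relies on the imports already
present in that file plus json, and assumes tsi.TableRowSchema round-trips
through model_dump_json / model_validate_json like the other pydantic response
types. Buffering defeats streaming for very large tables, so a row-count or
size cutoff would likely be needed in practice.

import json

def table_query_stream(
    self, req: tsi.TableQueryReq
) -> Iterator[tsi.TableRowSchema]:
    cache_key = f"table_query_stream_{req.model_dump_json()}"
    cached = self._safe_cache_get(cache_key)
    if cached:
        # Replay previously materialized rows from the cache.
        for row_json in json.loads(cached):
            yield tsi.TableRowSchema.model_validate_json(row_json)
        return
    # Materialize the upstream stream so the whole result can be cached.
    rows = list(self._next_trace_server.table_query_stream(req))
    self._safe_cache_set(
        cache_key, json.dumps([row.model_dump_json() for row in rows])
    )
    yield from rows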
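
The constructor still carries two "todo make this configurable" markers
(cache_dir and size_limit). A minimal sketch of env-based configuration,
where WEAVE_CACHE_DIR and WEAVE_CACHE_SIZE_LIMIT are hypothetical variable
names, not existing weave settings; diskcache itself falls back to a
temporary directory when no path is supplied:

import os
from pathlib import Path

def resolve_cache_settings() -> tuple[Path | None, int]:
    # Hypothetical env vars; defaults mirror the hard-coded values above.
    cache_dir = os.getenv("WEAVE_CACHE_DIR")
    size_limit = int(os.getenv("WEAVE_CACHE_SIZE_LIMIT", str(1_000_000_000)))
    return (Path(cache_dir) if cache_dir else None, size_limit)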
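
Also worth noting: after PATCH 4, client.server is the caching layer, so the
isinstance(client.server, RecordingTraceServer) checks in
test_remote_caching.py no longer match; the recording layer now sits one
level down. A short sketch of unwrapping the stack to read the recorded
per-endpoint stats, assuming the wiring from weave_init.py above (the project
name is a placeholder):

import weave
from weave.trace_server.recording_trace_server import RecordingTraceServer
from weave.trace_server_bindings.caching_middleware_trace_server import (
    CachingMiddlewareTraceServer,
)

client = weave.init("cache-test-notes")  # hypothetical project name
# ... run some traced work here ...

server = client.server
if isinstance(server, CachingMiddlewareTraceServer):
    # Unwrap the caching middleware to reach the recording layer beneath it.
    server = server._next_trace_server
if isinstance(server, RecordingTraceServer):
    # Maps endpoint name to total_duration, count, average_duration,
    # and error_count, as built by summarize_logs above.
    print(server.summarize_logs())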