Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Better Log Messages #3417

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions scripts/long_running_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import weave

# Initialise Weave under the name "long_running_job"; the returned client is
# what the script later calls live_status() on.
client = weave.init("long_running_job")


@weave.op
def add_one(x: int) -> int:
    """Return ``x`` incremented by one (registered as a Weave op)."""
    incremented = x + 1
    return incremented


# Drive the traced op 100 times while client.live_status(sec=1) is active,
# and print "done" to mark completion.
# NOTE(review): indentation is not preserved in this view, so whether the
# print runs inside or after the `with` block cannot be confirmed here.
with client.live_status(sec=1):
for i in range(100):
add_one(i)
print("done")
56 changes: 55 additions & 1 deletion weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
import platform
import re
import sys
from collections.abc import Iterator, Sequence
import threading
import time
from collections.abc import Generator, Iterator, Sequence
from concurrent.futures import Future
from contextlib import contextmanager
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -56,6 +59,7 @@
from weave.trace.table import Table
from weave.trace.util import deprecated, log_once
from weave.trace.vals import WeaveObject, WeaveTable, make_trace_obj
from weave.trace.weave_client_status import WeaveClientStatusState
from weave.trace_server.constants import MAX_DISPLAY_NAME_LENGTH, MAX_OBJECT_NAME_LENGTH
from weave.trace_server.ids import generate_id
from weave.trace_server.interface.feedback_types import RUNNABLE_FEEDBACK_TYPE_PREFIX
Expand Down Expand Up @@ -1277,6 +1281,55 @@ def query_costs(
)
return res.results

def get_status_state(self) -> WeaveClientStatusState:
    """Take a snapshot of the client's background-work counters.

    Returns:
        A ``WeaveClientStatusState`` stamped with the current UTC time. It
        reports the future executor's active-future count and worker limit,
        plus the call-processor queue size, call-processor worker count and
        per-URL request counter. The last three stay at ``0`` / ``0`` /
        ``{}`` unless ``self.server`` is a ``RemoteHTTPTraceServer``.
    """
    # Defaults used when the server is not a remote HTTP trace server.
    queue_size = 0
    processor_workers = 0
    request_counts: dict[str, int] = {}

    server = self.server
    if isinstance(server, RemoteHTTPTraceServer):
        # Only a batching server has a call queue to measure; a
        # non-batching one is reported with an empty queue.
        queue_size = server.call_processor.queue.qsize() if server.should_batch else 0
        # One call-processor worker is reported in both modes.
        processor_workers = 1
        request_counts = server.remote_request_counter

    executor = self.future_executor
    return WeaveClientStatusState(
        moment=datetime.datetime.now(datetime.timezone.utc),
        future_exec_queue_size=len(executor._active_futures),
        # The underlying pool may be absent, in which case report 0 workers.
        future_exec_worker_count=(
            executor._executor._max_workers if executor._executor else 0
        ),
        call_processor_queue_size=queue_size,
        call_processor_worker_count=processor_workers,
        remote_request_counter=request_counts,
    )

@contextmanager
def live_status(self, sec: float = 1) -> Generator[None, None, None]:
    """Print the client status periodically while the ``with`` body runs.

    A daemon thread prints ``self.get_status_state()`` as indented JSON and
    then waits ``sec`` seconds, repeating until the context exits. The wait
    uses ``Event.wait`` rather than ``time.sleep`` so that leaving the
    context wakes the thread at once instead of blocking in ``join`` for up
    to ``sec`` seconds.

    Args:
        sec: Number of seconds between status updates.

    Yields:
        None. Control returns to the ``with`` body while the printer runs.
    """
    stop_event = threading.Event()

    def _print_status() -> None:
        while not stop_event.is_set():
            status = self.get_status_state()
            print(status.model_dump_json(indent=2))
            # Interruptible pause: returns as soon as stop_event is set.
            stop_event.wait(sec)

    thread = threading.Thread(target=_print_status, daemon=True)
    thread.start()

    try:
        yield
    finally:
        # Signal the printer to stop and wait for it to finish.
        stop_event.set()
        thread.join()

@trace_sentry.global_trace_sentry.watch()
def _send_score_call(
self,
Expand Down Expand Up @@ -1671,6 +1724,7 @@ def _ref_uri(self, name: str, version: str, path: str) -> str:
return ObjectRef(self.entity, self.project, name, version).uri()

def _flush(self) -> None:
# TODO: make this an env var
# Used to wait until all currently enqueued jobs are processed
if not self.future_executor._in_thread_context.get():
self.future_executor.flush()
Expand Down
12 changes: 12 additions & 0 deletions weave/trace/weave_client_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from datetime import datetime

from pydantic import BaseModel


class WeaveClientStatusState(BaseModel):
    """Point-in-time snapshot of a Weave client's background activity."""

    # When the snapshot was taken.
    moment: datetime

    # Future executor: number of active futures and its worker limit.
    future_exec_queue_size: int
    future_exec_worker_count: int

    # Call processor: queued items and number of workers.
    call_processor_queue_size: int
    call_processor_worker_count: int

    # Count of POST requests issued, keyed by request URL.
    remote_request_counter: dict[str, int]
14 changes: 10 additions & 4 deletions weave/trace_server_bindings/remote_http_trace_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from typing import Any, Optional, Union, cast

Expand Down Expand Up @@ -100,6 +101,11 @@ def __init__(
self.call_processor = AsyncBatchProcessor(self._flush_calls)
self._auth: Optional[tuple[str, str]] = None
self.remote_request_bytes_limit = remote_request_bytes_limit
self.remote_request_counter: dict[str, int] = defaultdict(int)

def counted_post(self, url: str, *args: Any, **kwargs: Any) -> requests.Response:
    """POST to ``url`` via ``requests.post`` after recording the request.

    Bumps the tally for ``url`` in ``self.remote_request_counter`` and then
    forwards every positional and keyword argument unchanged.

    Returns:
        The ``requests.Response`` produced by ``requests.post``.
    """
    tally = self.remote_request_counter
    tally[url] = tally[url] + 1
    return requests.post(url, *args, **kwargs)

def ensure_project_exists(
self, entity: str, project: str
Expand Down Expand Up @@ -157,7 +163,7 @@ def _flush_calls(
self._flush_calls(batch[split_idx:], _should_update_batch_size=False)
return

r = requests.post(
r = self.counted_post(
self.trace_server_url + "/call/upsert_batch",
data=encoded_data,
auth=self._auth,
Expand All @@ -184,7 +190,7 @@ def _generic_request_executor(
req: BaseModel,
stream: bool = False,
) -> requests.Response:
r = requests.post(
r = self.counted_post(
self.trace_server_url + url,
# `by_alias` is required since we have Mongo-style properties in the
# query models that are aliased to conform to start with `$`. Without
Expand Down Expand Up @@ -472,7 +478,7 @@ def refs_read_batch(
reraise=True,
)
def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
r = requests.post(
r = self.counted_post(
self.trace_server_url + "/files/create",
auth=self._auth,
data={"project_id": req.project_id},
Expand All @@ -492,7 +498,7 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
reraise=True,
)
def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
r = requests.post(
r = self.counted_post(
self.trace_server_url + "/files/content",
json={"project_id": req.project_id, "digest": req.digest},
auth=self._auth,
Expand Down
Loading