Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Add Remote Server to Tests #3440

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 41 additions & 149 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import base64
import contextlib
import logging
import os
Expand All @@ -14,11 +13,14 @@
from fastapi.testclient import TestClient

import weave
from tests.conftest_lib.http_trace_server import (
FastAPIServer,
build_minimal_blind_authenticating_trace_server,
)
from tests.trace.util import DummyTestException
from weave.trace import autopatch, weave_client, weave_init
from weave.trace_server import (
clickhouse_trace_server_batched,
external_to_internal_trace_server_adapter,
sqlite_trace_server,
)
from weave.trace_server import environment as ts_env
Expand Down Expand Up @@ -217,124 +219,6 @@ def server_healthy(num_retries=1):
return server_healthy(num_retries=30)


class TwoWayMapping:
def __init__(self):
self._ext_to_int_map = {}
self._int_to_ext_map = {}

# Useful for testing to ensure caching is working
self.stats = {
"ext_to_int": {
"hits": 0,
"misses": 0,
},
"int_to_ext": {
"hits": 0,
"misses": 0,
},
}

def ext_to_int(self, key, default=None):
if key not in self._ext_to_int_map:
if default is None:
raise ValueError(f"Key {key} not found")
if default in self._int_to_ext_map:
raise ValueError(f"Default {default} already in use")
self._ext_to_int_map[key] = default
self._int_to_ext_map[default] = key
self.stats["ext_to_int"]["misses"] += 1
else:
self.stats["ext_to_int"]["hits"] += 1
return self._ext_to_int_map[key]

def int_to_ext(self, key, default):
if key not in self._int_to_ext_map:
if default is None:
raise ValueError(f"Key {key} not found")
if default in self._ext_to_int_map:
raise ValueError(f"Default {default} already in use")
self._int_to_ext_map[key] = default
self._ext_to_int_map[default] = key
self.stats["int_to_ext"]["misses"] += 1
else:
self.stats["int_to_ext"]["hits"] += 1
return self._int_to_ext_map[key]


def b64(s: str) -> str:
# Base64 encode the string
return base64.b64encode(s.encode("ascii")).decode("ascii")


class DummyIdConverter(external_to_internal_trace_server_adapter.IdConverter):
def __init__(self):
self._project_map = TwoWayMapping()
self._run_map = TwoWayMapping()
self._user_map = TwoWayMapping()

def ext_to_int_project_id(self, project_id: str) -> str:
return self._project_map.ext_to_int(project_id, b64(project_id))

def int_to_ext_project_id(self, project_id: str) -> typing.Optional[str]:
return self._project_map.int_to_ext(project_id, b64(project_id))

def ext_to_int_run_id(self, run_id: str) -> str:
return self._run_map.ext_to_int(run_id, b64(run_id) + ":" + run_id)

def int_to_ext_run_id(self, run_id: str) -> str:
exp = run_id.split(":")[1]
return self._run_map.int_to_ext(run_id, exp)

def ext_to_int_user_id(self, user_id: str) -> str:
return self._user_map.ext_to_int(user_id, b64(user_id))

def int_to_ext_user_id(self, user_id: str) -> str:
return self._user_map.int_to_ext(user_id, b64(user_id))


class TestOnlyUserInjectingExternalTraceServer(
external_to_internal_trace_server_adapter.ExternalTraceServer
):
def __init__(
self,
internal_trace_server: tsi.TraceServerInterface,
id_converter: external_to_internal_trace_server_adapter.IdConverter,
user_id: str,
):
super().__init__(internal_trace_server, id_converter)
self._user_id = user_id

def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes:
req.start.wb_user_id = self._user_id
return super().call_start(req)

def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
req.wb_user_id = self._user_id
return super().calls_delete(req)

def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
req.wb_user_id = self._user_id
return super().call_update(req)

def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
req.wb_user_id = self._user_id
return super().feedback_create(req)

def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
req.wb_user_id = self._user_id
return super().cost_create(req)

def actions_execute_batch(
self, req: tsi.ActionsExecuteBatchReq
) -> tsi.ActionsExecuteBatchRes:
req.wb_user_id = self._user_id
return super().actions_execute_batch(req)

def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
req.obj.wb_user_id = self._user_id
return super().obj_create(req)


# https://docs.pytest.org/en/7.1.x/example/simple.html#pytest-current-test-environment-variable
def get_test_name():
return os.environ.get("PYTEST_CURRENT_TEST", " ").split(" ")[0]
Expand Down Expand Up @@ -481,59 +365,71 @@ def __getattribute__(self, name):
return ServerRecorder(server)


from fastapi.testclient import TestClient


@contextlib.contextmanager
def create_client(
request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None
) -> weave_init.InitializedClient:
) -> typing.Generator[weave_init.InitializedClient, None, None]:
inited_client = None
webserver = None
weave_server_flag = request.config.getoption("--weave-server")
server: tsi.TraceServerInterface
entity = "shawn"
project = "test-project"

if weave_server_flag == "prod":
yield weave_init.init_weave("dev_testing")

url = ""
if weave_server_flag == "sqlite":
sqlite_server = sqlite_trace_server.SqliteTraceServer(
"file::memory:?cache=shared"
)
sqlite_server.drop_tables()
sqlite_server.setup_tables()
server = TestOnlyUserInjectingExternalTraceServer(
sqlite_server, DummyIdConverter(), entity
fast_api_app = build_minimal_blind_authenticating_trace_server(
sqlite_server, entity
)
webserver = FastAPIServer(fast_api_app)
webserver.start()
url = str(webserver.base_url)
elif weave_server_flag == "clickhouse":
ch_server = clickhouse_trace_server_batched.ClickHouseTraceServer.from_env()
ch_server.ch_client.command("DROP DATABASE IF EXISTS db_management")
ch_server.ch_client.command("DROP DATABASE IF EXISTS default")
ch_server._run_migrations()
server = TestOnlyUserInjectingExternalTraceServer(
ch_server, DummyIdConverter(), entity
fast_api_app = build_minimal_blind_authenticating_trace_server(
ch_server, entity
)
webserver = FastAPIServer(fast_api_app)
webserver.start()
url = str(webserver.base_url)
elif weave_server_flag.startswith("http"):
remote_server = remote_http_trace_server.RemoteHTTPTraceServer(
weave_server_flag
)
server = remote_server
elif weave_server_flag == ("prod"):
inited_client = weave_init.init_weave("dev_testing")

if inited_client is None:
client = TestOnlyFlushingWeaveClient(
entity, project, make_server_recorder(server)
)
inited_client = weave_init.InitializedClient(client)
autopatch.autopatch(autopatch_settings)
url = weave_server_flag
server = remote_http_trace_server.RemoteHTTPTraceServer(url)
client = TestOnlyFlushingWeaveClient(
entity, project, make_server_recorder(server), False
)
inited_client = weave_init.InitializedClient(client)
autopatch.autopatch(autopatch_settings)

return inited_client
try:
yield inited_client
finally:
if webserver:
webserver.stop()
inited_client.reset()
autopatch.reset_autopatch()


@pytest.fixture()
def client(request):
"""This is the standard fixture used everywhere in tests to test end to end
client functionality"""
inited_client = create_client(request)
try:
with create_client(request) as inited_client:
yield inited_client.client
finally:
inited_client.reset()
autopatch.reset_autopatch()


@pytest.fixture()
Expand All @@ -542,12 +438,8 @@ def client_creator(request):

@contextlib.contextmanager
def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None):
inited_client = create_client(request, autopatch_settings)
try:
with create_client(request, autopatch_settings) as inited_client:
yield inited_client.client
finally:
inited_client.reset()
autopatch.reset_autopatch()

yield client

Expand Down
Loading
Loading