Skip to content

Commit

Permalink
Merge branch 'dev' into tqdm-progress-bar-green
Browse files Browse the repository at this point in the history
  • Loading branch information
khoaguin authored May 21, 2024
2 parents 904f9c9 + 9939eaa commit e3dbf3c
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 32 deletions.
22 changes: 15 additions & 7 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from ..service.job.job_service import JobService
from ..service.job.job_stash import Job
from ..service.job.job_stash import JobStash
from ..service.job.job_stash import JobType
from ..service.log.log_service import LogService
from ..service.metadata.metadata_service import MetadataService
from ..service.metadata.node_metadata import NodeMetadataV3
Expand Down Expand Up @@ -1288,10 +1289,10 @@ def add_api_endpoint_execution_to_queue(

action = Action.from_api_endpoint_execution()
return self.add_queueitem_to_queue(
queue_item,
credentials,
action,
None,
queue_item=queue_item,
credentials=credentials,
action=action,
job_type=JobType.TWINAPIJOB,
)

def get_worker_pool_ref_by_name(
Expand Down Expand Up @@ -1360,16 +1361,22 @@ def add_action_to_queue(
)

return self.add_queueitem_to_queue(
queue_item, credentials, action, parent_job_id, user_id
queue_item=queue_item,
credentials=credentials,
action=action,
parent_job_id=parent_job_id,
user_id=user_id,
)

def add_queueitem_to_queue(
self,
*,
queue_item: QueueItem,
credentials: SyftVerifyKey,
action: Action | None = None,
parent_job_id: UID | None = None,
user_id: UID | None = None,
job_type: JobType = JobType.JOB,
) -> Job | SyftError:
log_id = UID()
role = self.get_role_for_credentials(credentials=credentials)
Expand Down Expand Up @@ -1403,6 +1410,7 @@ def add_queueitem_to_queue(
parent_job_id=parent_job_id,
action=action,
requested_by=user_id,
job_type=job_type,
)

# 🟡 TODO 36: Needs distributed lock
Expand Down Expand Up @@ -1505,8 +1513,8 @@ def add_api_call_to_queue(
worker_pool=worker_pool_ref,
)
return self.add_queueitem_to_queue(
queue_item,
api_call.credentials,
queue_item=queue_item,
credentials=api_call.credentials,
action=None,
parent_job_id=parent_job_id,
)
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
"hash": "6a7cc7c2bb4dd234c1508b0af4d3b403cd3b7b427578a775bf80dc36891923ed",
"action": "remove"
},
"5": {
"version": 5,
"hash": "82ee08442b09797ed7a3710c31de633bb308b1d2215f51b58a3e01a4c201055d",
"6": {
"version": 6,
"hash": "865a2ed791b8abd20d76e9a6bfae7ae7dad51b5ebfd8ff728aab25af93fa5570",
"action": "add"
}
},
Expand Down
13 changes: 13 additions & 0 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from ..response import SyftNotReady
from ..response import SyftSuccess
from ..response import SyftWarning
from ..user.user import UserView
from .code_parse import GlobalsVisitor
from .code_parse import LaunchJobVisitor
from .unparse import unparse
Expand Down Expand Up @@ -348,6 +349,18 @@ def _coll_repr_(self) -> dict[str, Any]:
"Submit time": str(self.submit_time),
}

@property
def user(self) -> UserView | SyftError:
api = APIRegistry.api_for(
node_uid=self.syft_node_location,
user_verify_key=self.user_verify_key,
)
if api is None:
return SyftError(
message=f"Can't access Syft API. You must login to {self.syft_node_location}"
)
return api.services.user.get_current_user()

@property
def status(self) -> UserCodeStatusCollection | SyftError:
# Clientside only
Expand Down
33 changes: 18 additions & 15 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
import random
from string import Template
Expand Down Expand Up @@ -28,8 +29,9 @@
from ...store.document_store import QueryKeys
from ...store.document_store import UIDPartitionKey
from ...types.datetime import DateTime
from ...types.datetime import format_timedelta
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_5
from ...types.syft_object import SYFT_OBJECT_VERSION_6
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
from ...types.uid import UID
Expand Down Expand Up @@ -73,10 +75,19 @@ def center_content(text: Any) -> str:
return center_div


@serializable()
class JobType(str, Enum):
JOB = "job"
TWINAPIJOB = "twinapijob"

def __str__(self) -> str:
return self.value


@serializable()
class Job(SyncableSyftObject):
__canonical_name__ = "JobItem"
__version__ = SYFT_OBJECT_VERSION_5
__version__ = SYFT_OBJECT_VERSION_6

id: UID
node_uid: UID
Expand All @@ -87,13 +98,16 @@ class Job(SyncableSyftObject):
parent_job_id: UID | None = None
n_iters: int | None = 0
current_iter: int | None = None
creation_time: str | None = Field(default_factory=lambda: str(datetime.now()))
creation_time: str | None = Field(
default_factory=lambda: str(datetime.now(tz=timezone.utc))
)
action: Action | None = None
job_pid: int | None = None
job_worker_id: UID | None = None
updated_at: DateTime | None = None
user_code_id: UID | None = None
requested_by: UID | None = None
job_type: JobType = JobType.JOB

__attr_searchable__ = ["parent_job_id", "job_worker_id", "status", "user_code_id"]
__repr_attrs__ = [
Expand Down Expand Up @@ -191,18 +205,7 @@ def eta_string(self) -> str | None:
):
return None

def format_timedelta(local_timedelta: timedelta) -> str:
total_seconds = int(local_timedelta.total_seconds())
hours, leftover = divmod(total_seconds, 3600)
minutes, seconds = divmod(leftover, 60)

hours_string = f"{hours}:" if hours != 0 else ""
minutes_string = f"{minutes}:".zfill(3)
seconds_string = f"{seconds}".zfill(2)

return f"{hours_string}{minutes_string}{seconds_string}"

now = datetime.now()
now = datetime.now(tz=timezone.utc)
time_passed = now - datetime.fromisoformat(self.creation_time)
iter_duration_seconds: float = time_passed.total_seconds() / self.current_iter
iters_remaining = self.n_iters - self.current_iter
Expand Down
8 changes: 7 additions & 1 deletion packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ..code.user_code import UserCode
from ..code.user_code import UserCodeStatusCollection
from ..job.job_stash import Job
from ..job.job_stash import JobType
from ..log.log import SyftLog
from ..output.output_service import ExecutionOutput
from ..request.request import Request
Expand Down Expand Up @@ -1288,7 +1289,12 @@ def hierarchies(
# TODO: Figure out nested user codes, do we even need that?

root_ids.append(diff.object_id) # type: ignore
elif isinstance(diff_obj, Job) and diff_obj.parent_job_id is None: # type: ignore
elif (
isinstance(diff_obj, Job) # type: ignore
and diff_obj.parent_job_id is None
# ignore Job objects created by TwinAPIEndpoint
and diff_obj.job_type != JobType.TWINAPIJOB
):
root_ids.append(diff.object_id) # type: ignore

for root_uid in root_ids:
Expand Down
32 changes: 32 additions & 0 deletions packages/syft/src/syft/types/datetime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
from datetime import datetime
from datetime import timedelta
from functools import total_ordering
import re
from typing import Any
Expand Down Expand Up @@ -57,3 +58,34 @@ def __eq__(self, other: Any) -> bool:

def __lt__(self, other: Self) -> bool:
return self.utc_timestamp < other.utc_timestamp

def timedelta(self, other: "DateTime") -> timedelta:
utc_timestamp_delta = self.utc_timestamp - other.utc_timestamp
return timedelta(seconds=utc_timestamp_delta)


def format_timedelta(local_timedelta: timedelta) -> str:
total_seconds = int(local_timedelta.total_seconds())
hours, leftover = divmod(total_seconds, 3600)
minutes, seconds = divmod(leftover, 60)

hours_string = f"{hours}:" if hours != 0 else ""
minutes_string = f"{minutes}:".zfill(3)
seconds_string = f"{seconds}".zfill(2)

return f"{hours_string}{minutes_string}{seconds_string}"


def format_timedelta_human_readable(local_timedelta: timedelta) -> str:
# Returns a human-readable string representing the timedelta
units = [("day", 86400), ("hour", 3600), ("minute", 60), ("second", 1)]
total_seconds = int(local_timedelta.total_seconds())

for unit_name, unit_seconds in units:
unit_value, total_seconds = divmod(total_seconds, unit_seconds)
if unit_value > 0:
if unit_value == 1:
return f"{unit_value} {unit_name}"
else:
return f"{unit_value} {unit_name}s"
return "0 seconds"
2 changes: 2 additions & 0 deletions packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@
SYFT_OBJECT_VERSION_3 = 3
SYFT_OBJECT_VERSION_4 = 4
SYFT_OBJECT_VERSION_5 = 5
SYFT_OBJECT_VERSION_6 = 6

supported_object_versions = [
SYFT_OBJECT_VERSION_1,
SYFT_OBJECT_VERSION_2,
SYFT_OBJECT_VERSION_3,
SYFT_OBJECT_VERSION_4,
SYFT_OBJECT_VERSION_5,
SYFT_OBJECT_VERSION_6,
]

HIGHEST_SYFT_OBJECT_VERSION = max(supported_object_versions)
Expand Down
52 changes: 48 additions & 4 deletions packages/syft/src/syft/util/notebook_ui/components/sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
import datetime
from typing import Any

# third party
Expand All @@ -9,6 +10,10 @@
from ....service.code.user_code import UserCode
from ....service.job.job_stash import Job
from ....service.request.request import Request
from ....service.response import SyftError
from ....service.user.user import UserView
from ....types.datetime import DateTime
from ....types.datetime import format_timedelta_human_readable
from ....types.syft_object import SYFT_OBJECT_VERSION_1
from ....types.syft_object import SyftObject
from ..icons import Icon
Expand Down Expand Up @@ -101,6 +106,43 @@ def get_status_str(self) -> str:
return status.value
return "" # type: ignore

def get_updated_by(self) -> str:
# TODO replace with centralized SyftObject created/updated by attribute
if isinstance(self.object, Request):
email = self.object.requesting_user_email
if email is not None:
return f"Requested by {email}"

user_view: UserView | SyftError | None = None
if isinstance(self.object, UserCode):
user_view = self.object.user

if isinstance(user_view, UserView):
return f"Created by {user_view.email}"
return ""

def get_updated_delta_str(self) -> str:
# TODO replace with centralized SyftObject created/updated by attribute
if isinstance(self.object, Job):
# NOTE Job is not using DateTime for creation_time, so we need to handle it separately
time_str = self.object.creation_time
if time_str is not None:
t = datetime.datetime.fromisoformat(time_str)
delta = datetime.datetime.now(datetime.timezone.utc) - t
return f"{format_timedelta_human_readable(delta)} ago"

dt: DateTime | None = None
if isinstance(self.object, Request):
dt = self.object.request_time
if isinstance(self.object, UserCode):
dt = self.object.submit_time
if dt is not None:
delta = DateTime.now().timedelta(dt)
delta_str = format_timedelta_human_readable(delta)
return f"{delta_str} ago"

return ""

def to_html(self) -> str:
type_html = TypeLabel(object=self.object).to_html()

Expand All @@ -110,10 +152,12 @@ def to_html(self) -> str:
copy_text=str(self.object.id.id), max_width=60
).to_html()

updated_delta_str = "29m ago"
updated_by = "[email protected]"
updated_delta_str = self.get_updated_delta_str()
updated_by = self.get_updated_by()
status_str = self.get_status_str()
status_seperator = " • " if len(status_str) else ""
status_row = " • ".join(
s for s in [status_str, updated_by, updated_delta_str] if s
)
summary_html = f"""
<div style="display: flex; gap: 8px; justify-content: space-between; width: 100%; overflow: hidden; align-items: center;">
<div style="display: flex; gap: 8px; justify-content: start; align-items: center;">
Expand All @@ -123,7 +167,7 @@ def to_html(self) -> str:
</div>
<div style="display: table-row">
<span class='syncstate-col-footer'>
{status_str}{status_seperator}Updated by {updated_by} {updated_delta_str}
{status_row}
</span>
</div>
""" # noqa: E501
Expand Down
3 changes: 2 additions & 1 deletion packages/syft/tests/syft/service/jobs/job_stash_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from datetime import datetime
from datetime import timedelta
from datetime import timezone

# third party
import pytest
Expand Down Expand Up @@ -33,7 +34,7 @@ def test_eta_string(current_iter, n_iters, status, creation_time_delta, expected
node_uid=UID(),
n_iters=n_iters,
current_iter=current_iter,
creation_time=(datetime.now() - creation_time_delta).isoformat(),
creation_time=(datetime.now(tz=timezone.utc) - creation_time_delta).isoformat(),
status=status,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/local/twin_api_sync_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compute(query):

# verify that ds cannot access private job
assert client_low_ds.api.services.job.get(private_job_id) is None
assert low_client.api.services.job.get(private_job_id) is not None
assert low_client.api.services.job.get(private_job_id) is None

# we only sync the mock function, we never sync the private function to the low side
mock_res = low_client.api.services.testapi.query.mock()
Expand Down

0 comments on commit e3dbf3c

Please sign in to comment.