From 46b1c049ada9bac6cca711e086a374880afe72f3 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 13 Jun 2024 19:01:22 +0200 Subject: [PATCH 1/7] change sync dep direction --- packages/syft/src/syft/client/syncing.py | 2 +- packages/syft/src/syft/service/code/user_code.py | 11 ++++++++--- packages/syft/src/syft/service/request/request.py | 2 -- packages/syft/src/syft/service/sync/diff_state.py | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 25dea439e8f..9f2cfafd85e 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -143,7 +143,7 @@ def handle_sync_batch( obj_diff_batch.decision = decision sync_instructions = [] - for diff in obj_diff_batch.get_dependents(include_roots=True): + for diff in obj_diff_batch.get_dependencies(include_roots=True): # figure out the right verify key to share to # in case of a job with user code, share to user code owner # without user code, share to job owner diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index d87363a0020..b5749b6b591 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -125,6 +125,11 @@ class UserCodeStatusCollection(SyncableSyftObject): __repr_attrs__ = ["approved", "status_dict"] + # if len(output_history): {uid: approved}, + # if denied string is somewhere: {uid: denied} + # else: {uid: pending} + # - the object is completely different for l2/l0 + # - the interface is different (because we need context in backend to get output_history) status_dict: dict[NodeIdentity, tuple[UserCodeStatus, str]] = {} user_code_link: LinkedObject @@ -411,7 +416,7 @@ def user(self) -> UserView | SyftError: ) return api.services.user.get_by_verify_key(self.user_verify_key) - def _status_from_output_history( + def _compute_status_from_output_history( self, context: AuthedServiceContext | None = None ) -> UserCodeStatusCollection | SyftError: if context is None: @@ -458,7 +463,7 @@ def status(self) -> UserCodeStatusCollection | SyftError: return SyftError( message="Encountered a low side UserCode object with a status_link." ) - return self._status_from_output_history() + return self._compute_status_from_output_history() if self.status_link is None: return SyftError( @@ -475,7 +480,7 @@ def get_status( return SyftError( message="Encountered a low side UserCode object with a status_link." ) - return self._status_from_output_history(context) + return self._compute_status_from_output_history(context) if self.status_link is None: return SyftError( message="This UserCode does not have a status. Please contact the Admin." diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 2778e747225..4f6ab2c5ad5 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -866,11 +866,9 @@ def get_sync_dependencies( self, context: AuthedServiceContext ) -> list[UID] | SyftError: dependencies = [] - code_id = self.code_id if isinstance(code_id, SyftError): return code_id - dependencies.append(code_id) return dependencies diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 8781435e220..bb3840cb0bf 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -817,7 +817,7 @@ def _build_hierarchy_helper( global_diffs=obj_uid_to_diff, global_roots=root_ids, hierarchy_levels=levels, - dependencies=batch_dependencies, + dependencies=obj_dependencies, root_diff=obj_uid_to_diff[root_uid], low_node_uid=low_node_uid, high_node_uid=high_node_uid, From 0ac566cfcd37606a1e75bfe88e1d0a97551131f6 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 14 Jun 2024 17:29:36 +0200 Subject: [PATCH 2/7] add new type filters, fix widget checkbox naming --- packages/syft/src/syft/client/syncing.py | 23 +++++-- .../syft/src/syft/service/sync/diff_state.py | 61 +++++++++++-------- .../src/syft/service/sync/resolve_widget.py | 18 +++--- 3 files changed, 65 insertions(+), 37 deletions(-) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 9f2cfafd85e..103d1434126 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -1,5 +1,8 @@ # stdlib +# stdlib +from collections.abc import Sequence + # relative from ..abstract_node import NodeSideType from ..node.credentials import SyftVerifyKey @@ -24,7 +27,9 @@ def compare_states( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - filter_by_type: str | type | None = None, + include_types: Sequence[str | type] | None = None, + exclude_types: Sequence[str | type] | None = None, + _hide_usercode: bool = True, ) -> NodeDiff | SyftError: # NodeDiff if ( @@ -45,6 +50,11 @@ def compare_states( return SyftError( "Invalid node side types: can only compare a high and low node" ) + + if _hide_usercode: + exclude_types = exclude_types or [] + exclude_types.append("usercode") + return NodeDiff.from_sync_state( low_state=low_state, high_state=high_state, @@ -52,7 +62,8 @@ def compare_states( include_ignored=include_ignored, include_same=include_same, filter_by_email=filter_by_email, - filter_by_type=filter_by_type, + include_types=include_types, + exclude_types=exclude_types, ) @@ -62,7 +73,9 @@ def compare_clients( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - filter_by_type: type | None = None, + include_types: Sequence[str | type] | None = None, + exclude_types: Sequence[str | type] | None = None, + _hide_usercode: bool = True, ) -> NodeDiff | SyftError: from_state = from_client.get_sync_state() if isinstance(from_state, SyftError): @@ -78,7 +91,9 @@ def compare_clients( include_ignored=include_ignored, include_same=include_same, filter_by_email=filter_by_email, - filter_by_type=filter_by_type, + include_types=include_types, + exclude_types=exclude_types, + _hide_usercode=_hide_usercode, ) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index bb3840cb0bf..24d89af2fd6 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1,5 +1,6 @@ # stdlib from collections.abc import Callable +from collections.abc import Collection from collections.abc import Iterable from dataclasses import dataclass import enum @@ -14,7 +15,6 @@ # third party from loguru import logger import pandas as pd -from pydantic import model_validator from rich import box from rich.console import Console from rich.console import Group @@ -676,7 +676,7 @@ def status(self) -> str: return "NEW" batch_statuses = [ - diff.status for diff in self.get_dependents(include_roots=False) + diff.status for diff in self.get_dependencies(include_roots=False) ] if all(status == "SAME" for status in batch_statuses): return "SAME" @@ -765,6 +765,7 @@ def from_dependencies( cls, root_uid: UID, obj_dependencies: dict[UID, list[UID]], + obj_dependents: dict[UID, list[UID]], obj_uid_to_diff: dict[UID, ObjectDiff], root_ids: list[UID], low_node_uid: UID, @@ -809,15 +810,13 @@ def _build_hierarchy_helper( levels = [level for _, level in batch_uids] batch_uids = {uid for uid, _ in batch_uids} # type: ignore - batch_dependencies = { - uid: [d for d in obj_dependencies.get(uid, []) if d in batch_uids] - for uid in batch_uids - } + return cls( global_diffs=obj_uid_to_diff, global_roots=root_ids, hierarchy_levels=levels, dependencies=obj_dependencies, + dependents=obj_dependents, root_diff=obj_uid_to_diff[root_uid], low_node_uid=low_node_uid, high_node_uid=high_node_uid, @@ -910,15 +909,6 @@ def visual_hierarchy(self) -> tuple[type, dict]: else: raise ValueError(f"Unknown root type: {self.root.obj_type}") - @model_validator(mode="after") - def make_dependents(self) -> Self: - dependents: dict = {} - for parent, children in self.dependencies.items(): - for child in children: - dependents[child] = dependents.get(child, []) + [parent] - self.dependents = dependents - return self - @property def root(self) -> ObjectDiff: return self.root_diff @@ -1195,7 +1185,8 @@ def from_sync_state( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - filter_by_type: type | None = None, + include_types: Collection[type | str] | None = None, + exclude_types: Collection[type | str] | None = None, _include_node_status: bool = False, ) -> "NodeDiff": obj_uid_to_diff = {} @@ -1235,8 +1226,9 @@ def from_sync_state( ) obj_uid_to_diff[diff.object_id] = diff + # TODO move static methods to NodeDiff __init__ obj_dependencies = NodeDiff.dependencies_from_states(low_state, high_state) - all_batches = NodeDiff.hierarchies( + all_batches = NodeDiff._create_batches( low_state, high_state, obj_dependencies, @@ -1265,9 +1257,10 @@ def from_sync_state( res._filter( user_email=filter_by_email, - obj_type=filter_by_type, + include_types=include_types, include_ignored=include_ignored, include_same=include_same, + exclude_types=exclude_types, inplace=True, ) @@ -1400,7 +1393,7 @@ def _sort_batches(hierarchies: list[ObjectDiffBatch]) -> list[ObjectDiffBatch]: return sorted_hierarchies @staticmethod - def hierarchies( + def _create_batches( low_sync_state: SyncState, high_sync_state: SyncState, obj_dependencies: dict[UID, list[UID]], @@ -1424,10 +1417,17 @@ def hierarchies( ): root_ids.append(diff.object_id) # type: ignore + # Dependents are the reverse edges of the dependency graph + obj_dependents = {} + for parent, children in obj_dependencies.items(): + for child in children: + obj_dependents[child] = obj_dependencies.get(child, []) + [parent] + for root_uid in root_ids: batch = ObjectDiffBatch.from_dependencies( root_uid, obj_dependencies, + obj_dependents, obj_uid_to_diff, root_ids, low_sync_state.node_uid, @@ -1483,9 +1483,10 @@ def _apply_filters( def _filter( self, user_email: str | None = None, - obj_type: str | type | None = None, include_ignored: bool = False, include_same: bool = False, + include_types: Collection[str | type] | None = None, + exclude_types: Collection[type | str] | None = None, inplace: bool = True, ) -> Self: new_filters = [] @@ -1493,12 +1494,6 @@ def _filter( new_filters.append( NodeDiffFilter(FilterProperty.USER, user_email, operator.eq) ) - if obj_type is not None: - if isinstance(obj_type, type): - obj_type = obj_type.__name__ - new_filters.append( - NodeDiffFilter(FilterProperty.TYPE, obj_type, operator.eq) - ) if not include_ignored: new_filters.append( NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne) @@ -1507,6 +1502,20 @@ def _filter( new_filters.append( NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne) ) + if include_types is not None: + include_types_ = { + t.__name__ if isinstance(t, type) else t for t in include_types + } + new_filters.append( + NodeDiffFilter(FilterProperty.TYPE, include_types_, operator.contains) + ) + if exclude_types: + for exclude_type in exclude_types: + if isinstance(exclude_type, type): + exclude_type = exclude_type.__name__ + new_filters.append( + NodeDiffFilter(FilterProperty.TYPE, exclude_type, operator.ne) + ) return self._apply_filters(new_filters, inplace=inplace) diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index b9dadcb319e..09f35d53fcc 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -331,7 +331,7 @@ def create_accordion_css( def build_accordion( self, - accordion_body: widgets.Widget, + accordion_body: MainObjectDiffWidget, show_sync_checkbox: bool = True, show_share_private_checkbox: bool = True, ) -> VBox: @@ -368,8 +368,12 @@ def build_accordion( layout=Layout(flex="1"), ) + if isinstance(self.diff.non_empty_object, ActionObject): + share_data_description = "Share real data and approve" + else: + share_data_description = "Share real data" share_private_data_checkbox = Checkbox( - description="Sync Real Data", + description=share_data_description, layout=Layout(width="auto", margin="0 2px 0 0"), ) sync_checkbox = Checkbox( @@ -485,20 +489,20 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: return dependent_diff_widgets @property - def dependent_batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: + def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: dependencies = self.obj_diff_batch.get_dependencies( include_roots=True, include_batch_root=False ) other_roots = [ d for d in dependencies if d.object_id in self.obj_diff_batch.global_roots ] - dependent_root_diff_widgets = [ + widgets = [ CollapsableObjectDiffWidget( diff, direction=self.obj_diff_batch.sync_direction ) for diff in other_roots ] - return dependent_root_diff_widgets + return widgets @property def main_object_diff_widget(self) -> MainObjectDiffWidget: @@ -536,7 +540,7 @@ def build(self) -> VBox: self.id2widget = {} batch_diff_widgets = self.batch_diff_widgets - dependent_batch_diff_widgets = self.dependent_batch_diff_widgets + dependent_batch_diff_widgets = self.dependent_root_diff_widgets main_object_diff_widget = self.main_object_diff_widget self.id2widget[main_object_diff_widget.diff.object_id] = main_object_diff_widget @@ -572,7 +576,7 @@ def build(self) -> VBox: def sync_button(self) -> Button: sync_button = Button( - description="Sync Selected Changes", + description="Apply Selected Changes", style={ "text_color": "#464A91", "button_color": "transparent", From 066de7d58a91df22e0402e4bec4e05aa12a6a56e Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 17 Jun 2024 10:44:36 +0200 Subject: [PATCH 3/7] remove outdated test --- .../service/sync/sync_resolve_single_test.py | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index adc6346fd10..868f9f5203d 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -8,7 +8,6 @@ from syft.client.sync_decision import SyncDecision from syft.client.syncing import compare_clients from syft.client.syncing import resolve -from syft.service.code.user_code import UserCode from syft.service.job.job_stash import Job from syft.service.request.request import RequestStatus from syft.service.response import SyftError @@ -181,38 +180,6 @@ def compute() -> int: assert len(diff.all_batches) == 2 -def test_forget_usercode(low_worker, high_worker): - low_client = low_worker.root_client - client_low_ds = low_worker.guest_client - high_client = high_worker.root_client - - @sy.syft_function_single_use() - def compute() -> int: - print("computing...") - return 42 - - _ = client_low_ds.code.request_code_execution(compute) - - diff_before, diff_after = compare_and_resolve( - from_client=low_client, to_client=high_client - ) - - run_and_deposit_result(high_client) - - def skip_if_user_code(diff): - if diff.root_type is UserCode: - return SyncDecision.IGNORE - return SyncDecision.SKIP - - diff_before, diff_after = compare_and_resolve( - from_client=low_client, - to_client=high_client, - decision_callback=skip_if_user_code, - ) - assert not diff_before.is_same - assert len(diff_after.batches) == 1 - - def test_request_code_execution_multiple(low_worker, high_worker): low_client = low_worker.root_client client_low_ds = low_worker.guest_client From c3a8aa9f7af9422a181bc462e571c31a3740fb97 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 17 Jun 2024 11:22:40 +0200 Subject: [PATCH 4/7] fix typing --- packages/syft/src/syft/client/syncing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 103d1434126..903bcc6e515 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -1,7 +1,7 @@ # stdlib # stdlib -from collections.abc import Sequence +from collections.abc import Collection # relative from ..abstract_node import NodeSideType @@ -27,8 +27,8 @@ def compare_states( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - include_types: Sequence[str | type] | None = None, - exclude_types: Sequence[str | type] | None = None, + include_types: Collection[str | type] | None = None, + exclude_types: Collection[str | type] | None = None, _hide_usercode: bool = True, ) -> NodeDiff | SyftError: # NodeDiff @@ -73,8 +73,8 @@ def compare_clients( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - include_types: Sequence[str | type] | None = None, - exclude_types: Sequence[str | type] | None = None, + include_types: Collection[str | type] | None = None, + exclude_types: Collection[str | type] | None = None, _hide_usercode: bool = True, ) -> NodeDiff | SyftError: from_state = from_client.get_sync_state() From d4355475385abe17eae1884ed2d983abe6d16fe2 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 17 Jun 2024 15:36:33 +0200 Subject: [PATCH 5/7] add warning message --- packages/syft/src/syft/client/syncing.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 903bcc6e515..342df11b6d8 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -16,6 +16,7 @@ from ..service.sync.sync_state import SyncState from ..types.uid import UID from ..util.decorators import deprecated +from ..util.util import prompt_warning_message from .domain_client import DomainClient from .sync_decision import SyncDecision from .sync_decision import SyncDirection @@ -29,7 +30,7 @@ def compare_states( filter_by_email: str | None = None, include_types: Collection[str | type] | None = None, exclude_types: Collection[str | type] | None = None, - _hide_usercode: bool = True, + hide_usercode: bool = True, ) -> NodeDiff | SyftError: # NodeDiff if ( @@ -51,7 +52,11 @@ def compare_states( "Invalid node side types: can only compare a high and low node" ) - if _hide_usercode: + if hide_usercode: + prompt_warning_message( + "User code is hidden by default, as they are also part of the Request." + " If you want to include them, set hide_usercode=False." + ) exclude_types = exclude_types or [] exclude_types.append("usercode") @@ -75,7 +80,7 @@ def compare_clients( filter_by_email: str | None = None, include_types: Collection[str | type] | None = None, exclude_types: Collection[str | type] | None = None, - _hide_usercode: bool = True, + hide_usercode: bool = True, ) -> NodeDiff | SyftError: from_state = from_client.get_sync_state() if isinstance(from_state, SyftError): @@ -93,7 +98,7 @@ def compare_clients( filter_by_email=filter_by_email, include_types=include_types, exclude_types=exclude_types, - _hide_usercode=_hide_usercode, + hide_usercode=hide_usercode, ) From 44a1808e8b68b710802ed57141c28125d3d3a9b8 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 17 Jun 2024 18:03:23 +0200 Subject: [PATCH 6/7] cleanup --- packages/syft/src/syft/service/code/user_code.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b2612704b0f..f09c41e0768 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -124,12 +124,6 @@ class UserCodeStatusCollection(SyncableSyftObject): __version__ = SYFT_OBJECT_VERSION_1 __repr_attrs__ = ["approved", "status_dict"] - - # if len(output_history): {uid: approved}, - # if denied string is somewhere: {uid: denied} - # else: {uid: pending} - # - the object is completely different for l2/l0 - # - the interface is different (because we need context in backend to get output_history) status_dict: dict[NodeIdentity, tuple[UserCodeStatus, str]] = {} user_code_link: LinkedObject From db1e59bb70e8c122da3d4b8c5e547bcff57b4d78 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 18 Jun 2024 13:05:43 +0200 Subject: [PATCH 7/7] fix test --- .../tests/syft/service/sync/sync_resolve_single_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index 868f9f5203d..0bd022ae604 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -156,7 +156,7 @@ def compute() -> int: _ = client_low_ds.code.request_code_execution(compute) - diff = compare_clients(low_client, high_client) + diff = compare_clients(low_client, high_client, hide_usercode=False) assert len(diff.batches) == 2 # Request + UserCode assert len(diff.ignored_batches) == 0 @@ -165,7 +165,7 @@ def compute() -> int: res = diff[0].ignore() assert isinstance(res, SyftSuccess) - diff = compare_clients(low_client, high_client) + diff = compare_clients(low_client, high_client, hide_usercode=False) assert len(diff.batches) == 0 assert len(diff.ignored_batches) == 2 assert len(diff.all_batches) == 2 @@ -174,7 +174,7 @@ def compute() -> int: res = diff.ignored_batches[0].unignore() assert isinstance(res, SyftSuccess) - diff = compare_clients(low_client, high_client) + diff = compare_clients(low_client, high_client, hide_usercode=False) assert len(diff.batches) == 1 assert len(diff.ignored_batches) == 1 assert len(diff.all_batches) == 2