diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 156866b26ff..c0b4dd8196e 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -1,5 +1,8 @@ # stdlib +# stdlib +from collections.abc import Collection + # relative from ..abstract_node import NodeSideType from ..node.credentials import SyftVerifyKey @@ -13,6 +16,7 @@ from ..service.sync.sync_state import SyncState from ..types.uid import UID from ..util.decorators import deprecated +from ..util.util import prompt_warning_message from .domain_client import DomainClient from .sync_decision import SyncDecision from .sync_decision import SyncDirection @@ -24,7 +28,9 @@ def compare_states( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - filter_by_type: str | type | None = None, + include_types: Collection[str | type] | None = None, + exclude_types: Collection[str | type] | None = None, + hide_usercode: bool = True, ) -> NodeDiff | SyftError: # NodeDiff if ( @@ -45,6 +51,15 @@ def compare_states( return SyftError( "Invalid node side types: can only compare a high and low node" ) + + if hide_usercode: + prompt_warning_message( + "User code is hidden by default, as they are also part of the Request." + " If you want to include them, set hide_usercode=False." + ) + exclude_types = exclude_types or [] + exclude_types.append("usercode") + return NodeDiff.from_sync_state( low_state=low_state, high_state=high_state, @@ -52,7 +67,8 @@ def compare_states( include_ignored=include_ignored, include_same=include_same, filter_by_email=filter_by_email, - filter_by_type=filter_by_type, + include_types=include_types, + exclude_types=exclude_types, ) @@ -62,7 +78,9 @@ def compare_clients( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - filter_by_type: type | None = None, + include_types: Collection[str | type] | None = None, + exclude_types: Collection[str | type] | None = None, + hide_usercode: bool = True, ) -> NodeDiff | SyftError: from_state = from_client.get_sync_state() if isinstance(from_state, SyftError): @@ -78,7 +96,9 @@ def compare_clients( include_ignored=include_ignored, include_same=include_same, filter_by_email=filter_by_email, - filter_by_type=filter_by_type, + include_types=include_types, + exclude_types=exclude_types, + hide_usercode=hide_usercode, ) @@ -134,7 +154,7 @@ def handle_sync_batch( obj_diff_batch.decision = decision sync_instructions = [] - for diff in obj_diff_batch.get_dependents(include_roots=True): + for diff in obj_diff_batch.get_dependencies(include_roots=True): # figure out the right verify key to share to # in case of a job with user code, share to user code owner # without user code, share to job owner diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index f771cf9a9c5..b8b1b74b848 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -130,7 +130,6 @@ class UserCodeStatusCollection(SyncableSyftObject): __version__ = SYFT_OBJECT_VERSION_1 __repr_attrs__ = ["approved", "status_dict"] - status_dict: dict[NodeIdentity, tuple[UserCodeStatus, str]] = {} user_code_link: LinkedObject diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 6ec41fd1bdb..882dd243ec4 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -895,11 +895,9 @@ def get_sync_dependencies( self, context: AuthedServiceContext ) -> list[UID] | SyftError: dependencies = [] - code_id = self.code_id if isinstance(code_id, SyftError): return code_id - dependencies.append(code_id) return dependencies diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 8781435e220..24d89af2fd6 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1,5 +1,6 @@ # stdlib from collections.abc import Callable +from collections.abc import Collection from collections.abc import Iterable from dataclasses import dataclass import enum @@ -14,7 +15,6 @@ # third party from loguru import logger import pandas as pd -from pydantic import model_validator from rich import box from rich.console import Console from rich.console import Group @@ -676,7 +676,7 @@ def status(self) -> str: return "NEW" batch_statuses = [ - diff.status for diff in self.get_dependents(include_roots=False) + diff.status for diff in self.get_dependencies(include_roots=False) ] if all(status == "SAME" for status in batch_statuses): return "SAME" @@ -765,6 +765,7 @@ def from_dependencies( cls, root_uid: UID, obj_dependencies: dict[UID, list[UID]], + obj_dependents: dict[UID, list[UID]], obj_uid_to_diff: dict[UID, ObjectDiff], root_ids: list[UID], low_node_uid: UID, @@ -809,15 +810,13 @@ def _build_hierarchy_helper( levels = [level for _, level in batch_uids] batch_uids = {uid for uid, _ in batch_uids} # type: ignore - batch_dependencies = { - uid: [d for d in obj_dependencies.get(uid, []) if d in batch_uids] - for uid in batch_uids - } + return cls( global_diffs=obj_uid_to_diff, global_roots=root_ids, hierarchy_levels=levels, - dependencies=batch_dependencies, + dependencies=obj_dependencies, + dependents=obj_dependents, root_diff=obj_uid_to_diff[root_uid], low_node_uid=low_node_uid, high_node_uid=high_node_uid, @@ -910,15 +909,6 @@ def visual_hierarchy(self) -> tuple[type, dict]: else: raise ValueError(f"Unknown root type: {self.root.obj_type}") - @model_validator(mode="after") - def make_dependents(self) -> Self: - dependents: dict = {} - for parent, children in self.dependencies.items(): - for child in children: - dependents[child] = dependents.get(child, []) + [parent] - self.dependents = dependents - return self - @property def root(self) -> ObjectDiff: return self.root_diff @@ -1195,7 +1185,8 @@ def from_sync_state( include_ignored: bool = False, include_same: bool = False, filter_by_email: str | None = None, - filter_by_type: type | None = None, + include_types: Collection[type | str] | None = None, + exclude_types: Collection[type | str] | None = None, _include_node_status: bool = False, ) -> "NodeDiff": obj_uid_to_diff = {} @@ -1235,8 +1226,9 @@ def from_sync_state( ) obj_uid_to_diff[diff.object_id] = diff + # TODO move static methods to NodeDiff __init__ obj_dependencies = NodeDiff.dependencies_from_states(low_state, high_state) - all_batches = NodeDiff.hierarchies( + all_batches = NodeDiff._create_batches( low_state, high_state, obj_dependencies, @@ -1265,9 +1257,10 @@ def from_sync_state( res._filter( user_email=filter_by_email, - obj_type=filter_by_type, + include_types=include_types, include_ignored=include_ignored, include_same=include_same, + exclude_types=exclude_types, inplace=True, ) @@ -1400,7 +1393,7 @@ def _sort_batches(hierarchies: list[ObjectDiffBatch]) -> list[ObjectDiffBatch]: return sorted_hierarchies @staticmethod - def hierarchies( + def _create_batches( low_sync_state: SyncState, high_sync_state: SyncState, obj_dependencies: dict[UID, list[UID]], @@ -1424,10 +1417,17 @@ def hierarchies( ): root_ids.append(diff.object_id) # type: ignore + # Dependents are the reverse edges of the dependency graph + obj_dependents = {} + for parent, children in obj_dependencies.items(): + for child in children: + obj_dependents[child] = obj_dependencies.get(child, []) + [parent] + for root_uid in root_ids: batch = ObjectDiffBatch.from_dependencies( root_uid, obj_dependencies, + obj_dependents, obj_uid_to_diff, root_ids, low_sync_state.node_uid, @@ -1483,9 +1483,10 @@ def _apply_filters( def _filter( self, user_email: str | None = None, - obj_type: str | type | None = None, include_ignored: bool = False, include_same: bool = False, + include_types: Collection[str | type] | None = None, + exclude_types: Collection[type | str] | None = None, inplace: bool = True, ) -> Self: new_filters = [] @@ -1493,12 +1494,6 @@ def _filter( new_filters.append( NodeDiffFilter(FilterProperty.USER, user_email, operator.eq) ) - if obj_type is not None: - if isinstance(obj_type, type): - obj_type = obj_type.__name__ - new_filters.append( - NodeDiffFilter(FilterProperty.TYPE, obj_type, operator.eq) - ) if not include_ignored: new_filters.append( NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne) @@ -1507,6 +1502,20 @@ def _filter( new_filters.append( NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne) ) + if include_types is not None: + include_types_ = { + t.__name__ if isinstance(t, type) else t for t in include_types + } + new_filters.append( + NodeDiffFilter(FilterProperty.TYPE, include_types_, operator.contains) + ) + if exclude_types: + for exclude_type in exclude_types: + if isinstance(exclude_type, type): + exclude_type = exclude_type.__name__ + new_filters.append( + NodeDiffFilter(FilterProperty.TYPE, exclude_type, operator.ne) + ) return self._apply_filters(new_filters, inplace=inplace) diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index 45fe49e6f45..496fb7a65eb 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -331,7 +331,7 @@ def create_accordion_css( def build_accordion( self, - accordion_body: widgets.Widget, + accordion_body: MainObjectDiffWidget, show_sync_checkbox: bool = True, show_share_private_checkbox: bool = True, ) -> VBox: @@ -368,8 +368,12 @@ def build_accordion( layout=Layout(flex="1"), ) + if isinstance(self.diff.non_empty_object, ActionObject): + share_data_description = "Share real data and approve" + else: + share_data_description = "Share real data" share_private_data_checkbox = Checkbox( - description="Sync Real Data", + description=share_data_description, layout=Layout(width="auto", margin="0 2px 0 0"), ) sync_checkbox = Checkbox( @@ -485,20 +489,20 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: return dependent_diff_widgets @property - def dependent_batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: + def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: dependencies = self.obj_diff_batch.get_dependencies( include_roots=True, include_batch_root=False ) other_roots = [ d for d in dependencies if d.object_id in self.obj_diff_batch.global_roots ] - dependent_root_diff_widgets = [ + widgets = [ CollapsableObjectDiffWidget( diff, direction=self.obj_diff_batch.sync_direction ) for diff in other_roots ] - return dependent_root_diff_widgets + return widgets @property def main_object_diff_widget(self) -> MainObjectDiffWidget: @@ -536,7 +540,7 @@ def build(self) -> VBox: self.id2widget = {} batch_diff_widgets = self.batch_diff_widgets - dependent_batch_diff_widgets = self.dependent_batch_diff_widgets + dependent_batch_diff_widgets = self.dependent_root_diff_widgets main_object_diff_widget = self.main_object_diff_widget self.id2widget[main_object_diff_widget.diff.object_id] = main_object_diff_widget @@ -572,7 +576,7 @@ def build(self) -> VBox: def sync_button(self) -> Button: sync_button = Button( - description="Sync Selected Changes", + description="Apply Selected Changes", style={ "text_color": "#464A91", "button_color": "transparent", diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index adc6346fd10..0bd022ae604 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -8,7 +8,6 @@ from syft.client.sync_decision import SyncDecision from syft.client.syncing import compare_clients from syft.client.syncing import resolve -from syft.service.code.user_code import UserCode from syft.service.job.job_stash import Job from syft.service.request.request import RequestStatus from syft.service.response import SyftError @@ -157,7 +156,7 @@ def compute() -> int: _ = client_low_ds.code.request_code_execution(compute) - diff = compare_clients(low_client, high_client) + diff = compare_clients(low_client, high_client, hide_usercode=False) assert len(diff.batches) == 2 # Request + UserCode assert len(diff.ignored_batches) == 0 @@ -166,7 +165,7 @@ def compute() -> int: res = diff[0].ignore() assert isinstance(res, SyftSuccess) - diff = compare_clients(low_client, high_client) + diff = compare_clients(low_client, high_client, hide_usercode=False) assert len(diff.batches) == 0 assert len(diff.ignored_batches) == 2 assert len(diff.all_batches) == 2 @@ -175,44 +174,12 @@ def compute() -> int: res = diff.ignored_batches[0].unignore() assert isinstance(res, SyftSuccess) - diff = compare_clients(low_client, high_client) + diff = compare_clients(low_client, high_client, hide_usercode=False) assert len(diff.batches) == 1 assert len(diff.ignored_batches) == 1 assert len(diff.all_batches) == 2 -def test_forget_usercode(low_worker, high_worker): - low_client = low_worker.root_client - client_low_ds = low_worker.guest_client - high_client = high_worker.root_client - - @sy.syft_function_single_use() - def compute() -> int: - print("computing...") - return 42 - - _ = client_low_ds.code.request_code_execution(compute) - - diff_before, diff_after = compare_and_resolve( - from_client=low_client, to_client=high_client - ) - - run_and_deposit_result(high_client) - - def skip_if_user_code(diff): - if diff.root_type is UserCode: - return SyncDecision.IGNORE - return SyncDecision.SKIP - - diff_before, diff_after = compare_and_resolve( - from_client=low_client, - to_client=high_client, - decision_callback=skip_if_user_code, - ) - assert not diff_before.is_same - assert len(diff_after.batches) == 1 - - def test_request_code_execution_multiple(low_worker, high_worker): low_client = low_worker.root_client client_low_ds = low_worker.guest_client