Skip to content

Commit

Permalink
Merge pull request #8910 from OpenMined/eelco/sync-dependencies-direc…
Browse files Browse the repository at this point in the history
…tion

Fix sync dependencies direction
  • Loading branch information
eelcovdw authored Jun 18, 2024
2 parents 6ad280b + 25a53f5 commit 724df08
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 78 deletions.
30 changes: 25 additions & 5 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# stdlib

# stdlib
from collections.abc import Collection

# relative
from ..abstract_node import NodeSideType
from ..node.credentials import SyftVerifyKey
Expand All @@ -13,6 +16,7 @@
from ..service.sync.sync_state import SyncState
from ..types.uid import UID
from ..util.decorators import deprecated
from ..util.util import prompt_warning_message
from .domain_client import DomainClient
from .sync_decision import SyncDecision
from .sync_decision import SyncDirection
Expand All @@ -24,7 +28,9 @@ def compare_states(
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: str | type | None = None,
include_types: Collection[str | type] | None = None,
exclude_types: Collection[str | type] | None = None,
hide_usercode: bool = True,
) -> NodeDiff | SyftError:
# NodeDiff
if (
Expand All @@ -45,14 +51,24 @@ def compare_states(
return SyftError(
"Invalid node side types: can only compare a high and low node"
)

if hide_usercode:
prompt_warning_message(
"User code is hidden by default, as they are also part of the Request."
" If you want to include them, set hide_usercode=False."
)
exclude_types = exclude_types or []
exclude_types.append("usercode")

return NodeDiff.from_sync_state(
low_state=low_state,
high_state=high_state,
direction=direction,
include_ignored=include_ignored,
include_same=include_same,
filter_by_email=filter_by_email,
filter_by_type=filter_by_type,
include_types=include_types,
exclude_types=exclude_types,
)


Expand All @@ -62,7 +78,9 @@ def compare_clients(
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: type | None = None,
include_types: Collection[str | type] | None = None,
exclude_types: Collection[str | type] | None = None,
hide_usercode: bool = True,
) -> NodeDiff | SyftError:
from_state = from_client.get_sync_state()
if isinstance(from_state, SyftError):
Expand All @@ -78,7 +96,9 @@ def compare_clients(
include_ignored=include_ignored,
include_same=include_same,
filter_by_email=filter_by_email,
filter_by_type=filter_by_type,
include_types=include_types,
exclude_types=exclude_types,
hide_usercode=hide_usercode,
)


Expand Down Expand Up @@ -134,7 +154,7 @@ def handle_sync_batch(
obj_diff_batch.decision = decision

sync_instructions = []
for diff in obj_diff_batch.get_dependents(include_roots=True):
for diff in obj_diff_batch.get_dependencies(include_roots=True):
# figure out the right verify key to share to
# in case of a job with user code, share to user code owner
# without user code, share to job owner
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ class UserCodeStatusCollection(SyncableSyftObject):
__version__ = SYFT_OBJECT_VERSION_1

__repr_attrs__ = ["approved", "status_dict"]

status_dict: dict[NodeIdentity, tuple[UserCodeStatus, str]] = {}
user_code_link: LinkedObject

Expand Down
2 changes: 0 additions & 2 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,11 +895,9 @@ def get_sync_dependencies(
self, context: AuthedServiceContext
) -> list[UID] | SyftError:
dependencies = []

code_id = self.code_id
if isinstance(code_id, SyftError):
return code_id

dependencies.append(code_id)

return dependencies
Expand Down
63 changes: 36 additions & 27 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
from collections.abc import Callable
from collections.abc import Collection
from collections.abc import Iterable
from dataclasses import dataclass
import enum
Expand All @@ -14,7 +15,6 @@
# third party
from loguru import logger
import pandas as pd
from pydantic import model_validator
from rich import box
from rich.console import Console
from rich.console import Group
Expand Down Expand Up @@ -676,7 +676,7 @@ def status(self) -> str:
return "NEW"

batch_statuses = [
diff.status for diff in self.get_dependents(include_roots=False)
diff.status for diff in self.get_dependencies(include_roots=False)
]
if all(status == "SAME" for status in batch_statuses):
return "SAME"
Expand Down Expand Up @@ -765,6 +765,7 @@ def from_dependencies(
cls,
root_uid: UID,
obj_dependencies: dict[UID, list[UID]],
obj_dependents: dict[UID, list[UID]],
obj_uid_to_diff: dict[UID, ObjectDiff],
root_ids: list[UID],
low_node_uid: UID,
Expand Down Expand Up @@ -809,15 +810,13 @@ def _build_hierarchy_helper(
levels = [level for _, level in batch_uids]

batch_uids = {uid for uid, _ in batch_uids} # type: ignore
batch_dependencies = {
uid: [d for d in obj_dependencies.get(uid, []) if d in batch_uids]
for uid in batch_uids
}

return cls(
global_diffs=obj_uid_to_diff,
global_roots=root_ids,
hierarchy_levels=levels,
dependencies=batch_dependencies,
dependencies=obj_dependencies,
dependents=obj_dependents,
root_diff=obj_uid_to_diff[root_uid],
low_node_uid=low_node_uid,
high_node_uid=high_node_uid,
Expand Down Expand Up @@ -910,15 +909,6 @@ def visual_hierarchy(self) -> tuple[type, dict]:
else:
raise ValueError(f"Unknown root type: {self.root.obj_type}")

@model_validator(mode="after")
def make_dependents(self) -> Self:
dependents: dict = {}
for parent, children in self.dependencies.items():
for child in children:
dependents[child] = dependents.get(child, []) + [parent]
self.dependents = dependents
return self

@property
def root(self) -> ObjectDiff:
return self.root_diff
Expand Down Expand Up @@ -1195,7 +1185,8 @@ def from_sync_state(
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: type | None = None,
include_types: Collection[type | str] | None = None,
exclude_types: Collection[type | str] | None = None,
_include_node_status: bool = False,
) -> "NodeDiff":
obj_uid_to_diff = {}
Expand Down Expand Up @@ -1235,8 +1226,9 @@ def from_sync_state(
)
obj_uid_to_diff[diff.object_id] = diff

# TODO move static methods to NodeDiff __init__
obj_dependencies = NodeDiff.dependencies_from_states(low_state, high_state)
all_batches = NodeDiff.hierarchies(
all_batches = NodeDiff._create_batches(
low_state,
high_state,
obj_dependencies,
Expand Down Expand Up @@ -1265,9 +1257,10 @@ def from_sync_state(

res._filter(
user_email=filter_by_email,
obj_type=filter_by_type,
include_types=include_types,
include_ignored=include_ignored,
include_same=include_same,
exclude_types=exclude_types,
inplace=True,
)

Expand Down Expand Up @@ -1400,7 +1393,7 @@ def _sort_batches(hierarchies: list[ObjectDiffBatch]) -> list[ObjectDiffBatch]:
return sorted_hierarchies

@staticmethod
def hierarchies(
def _create_batches(
low_sync_state: SyncState,
high_sync_state: SyncState,
obj_dependencies: dict[UID, list[UID]],
Expand All @@ -1424,10 +1417,17 @@ def hierarchies(
):
root_ids.append(diff.object_id) # type: ignore

# Dependents are the reverse edges of the dependency graph
obj_dependents = {}
for parent, children in obj_dependencies.items():
for child in children:
obj_dependents[child] = obj_dependencies.get(child, []) + [parent]

for root_uid in root_ids:
batch = ObjectDiffBatch.from_dependencies(
root_uid,
obj_dependencies,
obj_dependents,
obj_uid_to_diff,
root_ids,
low_sync_state.node_uid,
Expand Down Expand Up @@ -1483,22 +1483,17 @@ def _apply_filters(
def _filter(
self,
user_email: str | None = None,
obj_type: str | type | None = None,
include_ignored: bool = False,
include_same: bool = False,
include_types: Collection[str | type] | None = None,
exclude_types: Collection[type | str] | None = None,
inplace: bool = True,
) -> Self:
new_filters = []
if user_email is not None:
new_filters.append(
NodeDiffFilter(FilterProperty.USER, user_email, operator.eq)
)
if obj_type is not None:
if isinstance(obj_type, type):
obj_type = obj_type.__name__
new_filters.append(
NodeDiffFilter(FilterProperty.TYPE, obj_type, operator.eq)
)
if not include_ignored:
new_filters.append(
NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)
Expand All @@ -1507,6 +1502,20 @@ def _filter(
new_filters.append(
NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)
)
if include_types is not None:
include_types_ = {
t.__name__ if isinstance(t, type) else t for t in include_types
}
new_filters.append(
NodeDiffFilter(FilterProperty.TYPE, include_types_, operator.contains)
)
if exclude_types:
for exclude_type in exclude_types:
if isinstance(exclude_type, type):
exclude_type = exclude_type.__name__
new_filters.append(
NodeDiffFilter(FilterProperty.TYPE, exclude_type, operator.ne)
)

return self._apply_filters(new_filters, inplace=inplace)

Expand Down
18 changes: 11 additions & 7 deletions packages/syft/src/syft/service/sync/resolve_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def create_accordion_css(

def build_accordion(
self,
accordion_body: widgets.Widget,
accordion_body: MainObjectDiffWidget,
show_sync_checkbox: bool = True,
show_share_private_checkbox: bool = True,
) -> VBox:
Expand Down Expand Up @@ -368,8 +368,12 @@ def build_accordion(
layout=Layout(flex="1"),
)

if isinstance(self.diff.non_empty_object, ActionObject):
share_data_description = "Share real data and approve"
else:
share_data_description = "Share real data"
share_private_data_checkbox = Checkbox(
description="Sync Real Data",
description=share_data_description,
layout=Layout(width="auto", margin="0 2px 0 0"),
)
sync_checkbox = Checkbox(
Expand Down Expand Up @@ -485,20 +489,20 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]:
return dependent_diff_widgets

@property
def dependent_batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]:
def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]:
dependencies = self.obj_diff_batch.get_dependencies(
include_roots=True, include_batch_root=False
)
other_roots = [
d for d in dependencies if d.object_id in self.obj_diff_batch.global_roots
]
dependent_root_diff_widgets = [
widgets = [
CollapsableObjectDiffWidget(
diff, direction=self.obj_diff_batch.sync_direction
)
for diff in other_roots
]
return dependent_root_diff_widgets
return widgets

@property
def main_object_diff_widget(self) -> MainObjectDiffWidget:
Expand Down Expand Up @@ -536,7 +540,7 @@ def build(self) -> VBox:
self.id2widget = {}

batch_diff_widgets = self.batch_diff_widgets
dependent_batch_diff_widgets = self.dependent_batch_diff_widgets
dependent_batch_diff_widgets = self.dependent_root_diff_widgets
main_object_diff_widget = self.main_object_diff_widget

self.id2widget[main_object_diff_widget.diff.object_id] = main_object_diff_widget
Expand Down Expand Up @@ -572,7 +576,7 @@ def build(self) -> VBox:

def sync_button(self) -> Button:
sync_button = Button(
description="Sync Selected Changes",
description="Apply Selected Changes",
style={
"text_color": "#464A91",
"button_color": "transparent",
Expand Down
Loading

0 comments on commit 724df08

Please sign in to comment.