diff --git a/notebooks/api/0.8/01-submit-code.ipynb b/notebooks/api/0.8/01-submit-code.ipynb index 472e26392c3..c3ac7e55a35 100644 --- a/notebooks/api/0.8/01-submit-code.ipynb +++ b/notebooks/api/0.8/01-submit-code.ipynb @@ -37,7 +37,7 @@ "source": [ "import syft as sy\n", "sy.requires(SYFT_VERSION)\n", - "from syft.client.api import NodeView\n", + "from syft.client.api import NodeIdentity\n", "from syft.service.request.request import RequestStatus\n", "import pandas as pd" ] @@ -358,10 +358,10 @@ "source": [ "# Tests\n", "assert len(sum_trade_value_mil.kwargs) == 1\n", - "node_view = NodeView.from_api(jane_client.api)\n", - "assert node_view in sum_trade_value_mil.kwargs\n", - "assert \"trade_data\" in sum_trade_value_mil.kwargs[node_view]\n", - "assert sum_trade_value_mil.input_policy_init_kwargs[node_view][\"trade_data\"] == asset.action_id" + "node_identity = NodeIdentity.from_api(jane_client.api)\n", + "assert node_identity in sum_trade_value_mil.kwargs\n", + "assert \"trade_data\" in sum_trade_value_mil.kwargs[node_identity]\n", + "assert sum_trade_value_mil.input_policy_init_kwargs[node_identity][\"trade_data\"] == asset.action_id" ] }, { @@ -550,7 +550,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.11" }, "toc": { "base_numbering": 1, diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 4f465648a2d..226bd337de0 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -17,7 +17,6 @@ # third party from nacl.exceptions import BadSignatureError -from pydantic import BaseModel from pydantic import EmailStr from result import OkErr from result import Result @@ -43,6 +42,7 @@ from ..service.service import UserServiceConfigRegistry from ..service.warnings import APIEndpointWarning from ..service.warnings import WarningContext +from ..types.identity import Identity from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SyftBaseObject from ..types.syft_object import SyftObject @@ -671,19 +671,14 @@ def monkey_patch_getdef(self, obj, oname="") -> Union[str, None]: @serializable() -class NodeView(BaseModel): - class Config: - arbitrary_types_allowed = True - +class NodeIdentity(Identity): node_name: str - node_id: UID - verify_key: SyftVerifyKey @staticmethod def from_api(api: SyftAPI): # stores the name root verify key of the domain node node_metadata = api.connection.get_node_metadata(api.signing_key) - return NodeView( + return NodeIdentity( node_name=node_metadata.name, node_id=api.node_uid, verify_key=SyftVerifyKey.from_string(node_metadata.verify_key), @@ -698,7 +693,7 @@ def from_change_context(cls, context: ChangeContext): ) def __eq__(self, other: Any) -> bool: - if not isinstance(other, NodeView): + if not isinstance(other, NodeIdentity): return False return ( self.node_name == other.node_name @@ -709,6 +704,9 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash((self.node_name, self.verify_key)) + def __repr__(self) -> str: + return f"NodeIdentity " + def validate_callable_args_and_kwargs(args, kwargs, signature: Signature): _valid_kwargs = {} diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 53d54319daf..c4464f134df 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -22,7 +22,7 @@ # relative from ...abstract_node import NodeType -from ...client.api import NodeView +from ...client.api import NodeIdentity from ...client.enclave_client import EnclaveMetadata from ...node.credentials import SyftVerifyKey from ...serde.deserialize import _deserialize @@ -104,7 +104,7 @@ def __hash__(self) -> int: # as status is in attr_searchable @serializable(attrs=["status_dict"]) class UserCodeStatusCollection(SyftHashableObject): - status_dict: Dict[NodeView, UserCodeStatus] = {} + status_dict: Dict[NodeIdentity, UserCodeStatus] = {} def __init__(self, status_dict: Dict): self.status_dict = status_dict @@ -121,9 +121,9 @@ def _repr_html_(self):

User Code Status

""" - for node_view, status in self.status_dict.items(): - node_name_str = f"{node_view.node_name}" - uid_str = f"{node_view.node_id}" + for node_identity, status in self.status_dict.items(): + node_name_str = f"{node_identity.node_name}" + uid_str = f"{node_identity.node_id}" status_str = f"{status.value}" string += f""" @@ -137,16 +137,16 @@ def _repr_html_(self): def __repr_syft_nested__(self): string = "" - for node_view, status in self.status_dict.items(): - string += f"{node_view.node_name}: {status}
" + for node_identity, status in self.status_dict.items(): + string += f"{node_identity.node_name}: {status}
" return string def get_status_message(self): if self.approved: return SyftSuccess(message=f"{type(self)} approved") string = "" - for node_view, status in self.status_dict.items(): - string += f"Code status on node '{node_view.node_name}' is '{status}'. " + for node_identity, status in self.status_dict.items(): + string += f"Code status on node '{node_identity.node_name}' is '{status}'. " if self.denied: return SyftError(message=f"{type(self)} Your code cannot be run: {string}") else: @@ -175,13 +175,13 @@ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus: return Exception(f"Invalid types in {keys} for Code Submission") elif context.node.node_type == NodeType.DOMAIN: - node_view = NodeView( + node_identity = NodeIdentity( node_name=context.node.name, node_id=context.node.id, verify_key=context.node.signing_key.verify_key, ) - if node_view in self.status_dict: - return self.status_dict[node_view] + if node_identity in self.status_dict: + return self.status_dict[node_identity] else: raise Exception( f"Code Object does not contain {context.node.name} Domain's data" @@ -194,12 +194,12 @@ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus: def mutate( self, value: UserCodeStatus, node_name: str, node_id, verify_key: SyftVerifyKey ) -> Union[SyftError, Self]: - node_view = NodeView( + node_identity = NodeIdentity( node_name=node_name, node_id=node_id, verify_key=verify_key ) status_dict = self.status_dict - if node_view in status_dict: - status_dict[node_view] = value + if node_identity in status_dict: + status_dict[node_identity] = value self.status_dict = status_dict return self else: @@ -300,7 +300,7 @@ def input_policy(self) -> Optional[InputPolicy]: # TODO: Tech Debt here node_view_workaround = False for k, _ in self.input_policy_init_kwargs.items(): - if isinstance(k, NodeView): + if isinstance(k, NodeIdentity): node_view_workaround = True if node_view_workaround: @@ -405,8 +405,8 @@ def assets(self) -> List[Asset]: inputs = ( uids - for node_view, uids in self.input_policy_init_kwargs.items() - if node_view.node_name == api.node_name + for node_identity, uids in self.input_policy_init_kwargs.items() + if node_identity.node_name == api.node_name ) all_assets = [] for uid in itertools.chain.from_iterable(x.values() for x in inputs): @@ -668,7 +668,7 @@ def new_check_code(context: TransformContext) -> TransformContext: input_kwargs = context.output["input_policy_init_kwargs"] node_view_workaround = False for k in input_kwargs.keys(): - if isinstance(k, NodeView): + if isinstance(k, NodeIdentity): node_view_workaround = True if not node_view_workaround: @@ -753,20 +753,20 @@ def check_output_policy(context: TransformContext) -> TransformContext: def add_custom_status(context: TransformContext) -> TransformContext: input_keys = list(context.output["input_policy_init_kwargs"].keys()) if context.node.node_type == NodeType.DOMAIN: - node_view = NodeView( + node_identity = NodeIdentity( node_name=context.node.name, node_id=context.node.id, verify_key=context.node.signing_key.verify_key, ) context.output["status"] = UserCodeStatusCollection( - status_dict={node_view: UserCodeStatus.SUBMITTED} + status_dict={node_identity: UserCodeStatus.SUBMITTED} ) - # if node_view in input_keys or len(input_keys) == 0: + # if node_identity in input_keys or len(input_keys) == 0: # context.output["status"] = UserCodeStatusContext( - # base_dict={node_view: UserCodeStatus.SUBMITTED} + # base_dict={node_identity: UserCodeStatus.SUBMITTED} # ) # else: - # raise ValueError(f"Invalid input keys: {input_keys} for {node_view}") + # raise ValueError(f"Invalid input keys: {input_keys} for {node_identity}") elif context.node.node_type == NodeType.ENCLAVE: status_dict = {key: UserCodeStatus.SUBMITTED for key in input_keys} context.output["status"] = UserCodeStatusCollection(status_dict=status_dict) diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index f0e0d836d2e..092d08b9cb3 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -25,7 +25,7 @@ # relative from ...abstract_node import NodeType -from ...client.api import NodeView +from ...client.api import NodeIdentity from ...node.credentials import SyftVerifyKey from ...serde.recursive_primitives import recursive_serde_register_type from ...serde.serializable import serializable @@ -131,7 +131,7 @@ class UserPolicyStatus(Enum): def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]: # relative from ...client.api import APIRegistry - from ...client.api import NodeView + from ...client.api import NodeIdentity from ...types.twin_object import TwinObject from ..action.action_object import ActionObject @@ -153,11 +153,11 @@ def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]: _obj_exists = False for api in api_list: if api.services.action.exists(uid): - node_view = NodeView.from_api(api) - if node_view not in output_kwargs: - output_kwargs[node_view] = {k: uid} + node_identity = NodeIdentity.from_api(api) + if node_identity not in output_kwargs: + output_kwargs[node_identity] = {k: uid} else: - output_kwargs[node_view].update({k: uid}) + output_kwargs[node_identity].update({k: uid}) _obj_exists = True break @@ -187,11 +187,11 @@ def filter_kwargs( raise NotImplementedError @property - def inputs(self) -> Dict[NodeView, Any]: + def inputs(self) -> Dict[NodeIdentity, Any]: return self.init_kwargs def _inputs_for_context(self, context: ChangeContext): - user_node_view = NodeView.from_change_context(context) + user_node_view = NodeIdentity.from_change_context(context) inputs = self.inputs[user_node_view] action_service = context.node.get_service("actionservice") @@ -251,12 +251,12 @@ def allowed_ids_only( context: AuthedServiceContext, ) -> Dict[str, UID]: if context.node.node_type == NodeType.DOMAIN: - node_view = NodeView( + node_identity = NodeIdentity( node_name=context.node.name, node_id=context.node.id, verify_key=context.node.signing_key.verify_key, ) - allowed_inputs = allowed_inputs[node_view] + allowed_inputs = allowed_inputs[node_identity] elif context.node.node_type == NodeType.ENCLAVE: base_dict = {} for key in allowed_inputs.values(): diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index c5e42665305..21a64f67b63 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -24,6 +24,7 @@ from typing_extensions import Self # relative +from ...client.api import NodeIdentity from ...client.client import SyftClient from ...client.client import SyftClientSessionCache from ...node.credentials import SyftSigningKey @@ -33,11 +34,13 @@ from ...service.metadata.node_metadata import NodeMetadata from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime +from ...types.identity import Identity +from ...types.identity import UserIdentity from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject from ...types.syft_object import short_qual_name from ...types.transforms import TransformContext -from ...types.transforms import keep +from ...types.transforms import rename from ...types.transforms import transform from ...types.uid import UID from ...util import options @@ -60,44 +63,9 @@ class EventAlreadyAddedException(SyftException): pass -class Identity(SyftObject): - __canonical_name__ = "Identity" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - verify_key: SyftVerifyKey - - __repr_attrs__ = ["id", "verify_key"] - - def __repr__(self) -> str: - verify_key_str = f"{self.verify_key}" - return f"<🔑 {verify_key_str[0:8]} @ 🟢 {self.id.short()}>" - - @classmethod - def from_client(cls, client: SyftClient) -> Identity: - return cls(id=client.id, verify_key=client.credentials.verify_key) - - -@serializable() -class NodeIdentity(Identity): - """This class is used to identify the node owner""" - - __canonical_name__ = "NodeIdentity" - __version__ = SYFT_OBJECT_VERSION_1 - - -# Used to Identity data scientist users of the node -@serializable() -class UserIdentity(Identity): - """This class is used to identify the data scientist users of the node""" - - __canonical_name__ = "UserIdentity" - __version__ = SYFT_OBJECT_VERSION_1 - - @transform(NodeMetadata, NodeIdentity) def metadata_to_node_identity() -> List[Callable]: - return [keep(["id", "verify_key"])] + return [rename("id", "node_id"), rename("name", "node_name")] class ProjectEvent(SyftObject): diff --git a/packages/syft/src/syft/types/identity.py b/packages/syft/src/syft/types/identity.py new file mode 100644 index 00000000000..a3359570ebb --- /dev/null +++ b/packages/syft/src/syft/types/identity.py @@ -0,0 +1,39 @@ +# future +from __future__ import annotations + +# stdlib +from typing import TYPE_CHECKING + +# third party +from typing_extensions import Self + +# relative +from ..node.credentials import SyftVerifyKey +from ..serde.serializable import serializable +from .base import SyftBaseModel +from .uid import UID + +if TYPE_CHECKING: + # relative + from ..client.client import SyftClient + + +class Identity(SyftBaseModel): + node_id: UID + verify_key: SyftVerifyKey + + __repr_attrs__ = ["id", "verify_key"] + + def __repr__(self) -> str: + return f"{self.__class__.__name__} " + + @classmethod + def from_client(cls, client: SyftClient) -> Self: + return cls(node_id=client.id, verify_key=client.credentials.verify_key) + + +@serializable() +class UserIdentity(Identity): + """This class is used to identify the data scientist users of the node""" + + pass