Skip to content

Commit

Permalink
merge dev
Browse files Browse the repository at this point in the history
  • Loading branch information
koenvanderveen committed Jul 20, 2023
2 parents e2a43de + c2b73d7 commit 3e5e606
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 86 deletions.
12 changes: 6 additions & 6 deletions notebooks/api/0.8/01-submit-code.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"source": [
"import syft as sy\n",
"sy.requires(SYFT_VERSION)\n",
"from syft.client.api import NodeView\n",
"from syft.client.api import NodeIdentity\n",
"from syft.service.request.request import RequestStatus\n",
"import pandas as pd"
]
Expand Down Expand Up @@ -358,10 +358,10 @@
"source": [
"# Tests\n",
"assert len(sum_trade_value_mil.kwargs) == 1\n",
"node_view = NodeView.from_api(jane_client.api)\n",
"assert node_view in sum_trade_value_mil.kwargs\n",
"assert \"trade_data\" in sum_trade_value_mil.kwargs[node_view]\n",
"assert sum_trade_value_mil.input_policy_init_kwargs[node_view][\"trade_data\"] == asset.action_id"
"node_identity = NodeIdentity.from_api(jane_client.api)\n",
"assert node_identity in sum_trade_value_mil.kwargs\n",
"assert \"trade_data\" in sum_trade_value_mil.kwargs[node_identity]\n",
"assert sum_trade_value_mil.input_policy_init_kwargs[node_identity][\"trade_data\"] == asset.action_id"
]
},
{
Expand Down Expand Up @@ -550,7 +550,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.10.11"
},
"toc": {
"base_numbering": 1,
Expand Down
16 changes: 7 additions & 9 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

# third party
from nacl.exceptions import BadSignatureError
from pydantic import BaseModel
from pydantic import EmailStr
from result import OkErr
from result import Result
Expand All @@ -43,6 +42,7 @@
from ..service.service import UserServiceConfigRegistry
from ..service.warnings import APIEndpointWarning
from ..service.warnings import WarningContext
from ..types.identity import Identity
from ..types.syft_object import SYFT_OBJECT_VERSION_1
from ..types.syft_object import SyftBaseObject
from ..types.syft_object import SyftObject
Expand Down Expand Up @@ -671,19 +671,14 @@ def monkey_patch_getdef(self, obj, oname="") -> Union[str, None]:


@serializable()
class NodeView(BaseModel):
class Config:
arbitrary_types_allowed = True

class NodeIdentity(Identity):
node_name: str
node_id: UID
verify_key: SyftVerifyKey

@staticmethod
def from_api(api: SyftAPI):
# stores the name root verify key of the domain node
node_metadata = api.connection.get_node_metadata(api.signing_key)
return NodeView(
return NodeIdentity(
node_name=node_metadata.name,
node_id=api.node_uid,
verify_key=SyftVerifyKey.from_string(node_metadata.verify_key),
Expand All @@ -698,7 +693,7 @@ def from_change_context(cls, context: ChangeContext):
)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, NodeView):
if not isinstance(other, NodeIdentity):
return False
return (
self.node_name == other.node_name
Expand All @@ -709,6 +704,9 @@ def __eq__(self, other: Any) -> bool:
def __hash__(self) -> int:
return hash((self.node_name, self.verify_key))

def __repr__(self) -> str:
return f"NodeIdentity <name={self.node_name}, id={self.node_id.short()}, 🔑={str(self.verify_key)[0:8]}>"


def validate_callable_args_and_kwargs(args, kwargs, signature: Signature):
_valid_kwargs = {}
Expand Down
48 changes: 24 additions & 24 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# relative
from ...abstract_node import NodeType
from ...client.api import NodeView
from ...client.api import NodeIdentity
from ...client.enclave_client import EnclaveMetadata
from ...node.credentials import SyftVerifyKey
from ...serde.deserialize import _deserialize
Expand Down Expand Up @@ -104,7 +104,7 @@ def __hash__(self) -> int:
# as status is in attr_searchable
@serializable(attrs=["status_dict"])
class UserCodeStatusCollection(SyftHashableObject):
status_dict: Dict[NodeView, UserCodeStatus] = {}
status_dict: Dict[NodeIdentity, UserCodeStatus] = {}

def __init__(self, status_dict: Dict):
self.status_dict = status_dict
Expand All @@ -121,9 +121,9 @@ def _repr_html_(self):
<h3 style="line-height: 25%; margin-top: 25px;">User Code Status</h3>
<p style="margin-left: 3px;">
"""
for node_view, status in self.status_dict.items():
node_name_str = f"{node_view.node_name}"
uid_str = f"{node_view.node_id}"
for node_identity, status in self.status_dict.items():
node_name_str = f"{node_identity.node_name}"
uid_str = f"{node_identity.node_id}"
status_str = f"{status.value}"

string += f"""
Expand All @@ -137,16 +137,16 @@ def _repr_html_(self):

def __repr_syft_nested__(self):
string = ""
for node_view, status in self.status_dict.items():
string += f"{node_view.node_name}: {status}<br>"
for node_identity, status in self.status_dict.items():
string += f"{node_identity.node_name}: {status}<br>"
return string

def get_status_message(self):
if self.approved:
return SyftSuccess(message=f"{type(self)} approved")
string = ""
for node_view, status in self.status_dict.items():
string += f"Code status on node '{node_view.node_name}' is '{status}'. "
for node_identity, status in self.status_dict.items():
string += f"Code status on node '{node_identity.node_name}' is '{status}'. "
if self.denied:
return SyftError(message=f"{type(self)} Your code cannot be run: {string}")
else:
Expand Down Expand Up @@ -175,13 +175,13 @@ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus:
return Exception(f"Invalid types in {keys} for Code Submission")

elif context.node.node_type == NodeType.DOMAIN:
node_view = NodeView(
node_identity = NodeIdentity(
node_name=context.node.name,
node_id=context.node.id,
verify_key=context.node.signing_key.verify_key,
)
if node_view in self.status_dict:
return self.status_dict[node_view]
if node_identity in self.status_dict:
return self.status_dict[node_identity]
else:
raise Exception(
f"Code Object does not contain {context.node.name} Domain's data"
Expand All @@ -194,12 +194,12 @@ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus:
def mutate(
self, value: UserCodeStatus, node_name: str, node_id, verify_key: SyftVerifyKey
) -> Union[SyftError, Self]:
node_view = NodeView(
node_identity = NodeIdentity(
node_name=node_name, node_id=node_id, verify_key=verify_key
)
status_dict = self.status_dict
if node_view in status_dict:
status_dict[node_view] = value
if node_identity in status_dict:
status_dict[node_identity] = value
self.status_dict = status_dict
return self
else:
Expand Down Expand Up @@ -300,7 +300,7 @@ def input_policy(self) -> Optional[InputPolicy]:
# TODO: Tech Debt here
node_view_workaround = False
for k, _ in self.input_policy_init_kwargs.items():
if isinstance(k, NodeView):
if isinstance(k, NodeIdentity):
node_view_workaround = True

if node_view_workaround:
Expand Down Expand Up @@ -405,8 +405,8 @@ def assets(self) -> List[Asset]:

inputs = (
uids
for node_view, uids in self.input_policy_init_kwargs.items()
if node_view.node_name == api.node_name
for node_identity, uids in self.input_policy_init_kwargs.items()
if node_identity.node_name == api.node_name
)
all_assets = []
for uid in itertools.chain.from_iterable(x.values() for x in inputs):
Expand Down Expand Up @@ -668,7 +668,7 @@ def new_check_code(context: TransformContext) -> TransformContext:
input_kwargs = context.output["input_policy_init_kwargs"]
node_view_workaround = False
for k in input_kwargs.keys():
if isinstance(k, NodeView):
if isinstance(k, NodeIdentity):
node_view_workaround = True

if not node_view_workaround:
Expand Down Expand Up @@ -753,20 +753,20 @@ def check_output_policy(context: TransformContext) -> TransformContext:
def add_custom_status(context: TransformContext) -> TransformContext:
input_keys = list(context.output["input_policy_init_kwargs"].keys())
if context.node.node_type == NodeType.DOMAIN:
node_view = NodeView(
node_identity = NodeIdentity(
node_name=context.node.name,
node_id=context.node.id,
verify_key=context.node.signing_key.verify_key,
)
context.output["status"] = UserCodeStatusCollection(
status_dict={node_view: UserCodeStatus.SUBMITTED}
status_dict={node_identity: UserCodeStatus.SUBMITTED}
)
# if node_view in input_keys or len(input_keys) == 0:
# if node_identity in input_keys or len(input_keys) == 0:
# context.output["status"] = UserCodeStatusContext(
# base_dict={node_view: UserCodeStatus.SUBMITTED}
# base_dict={node_identity: UserCodeStatus.SUBMITTED}
# )
# else:
# raise ValueError(f"Invalid input keys: {input_keys} for {node_view}")
# raise ValueError(f"Invalid input keys: {input_keys} for {node_identity}")
elif context.node.node_type == NodeType.ENCLAVE:
status_dict = {key: UserCodeStatus.SUBMITTED for key in input_keys}
context.output["status"] = UserCodeStatusCollection(status_dict=status_dict)
Expand Down
20 changes: 10 additions & 10 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# relative
from ...abstract_node import NodeType
from ...client.api import NodeView
from ...client.api import NodeIdentity
from ...node.credentials import SyftVerifyKey
from ...serde.recursive_primitives import recursive_serde_register_type
from ...serde.serializable import serializable
Expand Down Expand Up @@ -131,7 +131,7 @@ class UserPolicyStatus(Enum):
def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]:
# relative
from ...client.api import APIRegistry
from ...client.api import NodeView
from ...client.api import NodeIdentity
from ...types.twin_object import TwinObject
from ..action.action_object import ActionObject

Expand All @@ -153,11 +153,11 @@ def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]:
_obj_exists = False
for api in api_list:
if api.services.action.exists(uid):
node_view = NodeView.from_api(api)
if node_view not in output_kwargs:
output_kwargs[node_view] = {k: uid}
node_identity = NodeIdentity.from_api(api)
if node_identity not in output_kwargs:
output_kwargs[node_identity] = {k: uid}
else:
output_kwargs[node_view].update({k: uid})
output_kwargs[node_identity].update({k: uid})

_obj_exists = True
break
Expand Down Expand Up @@ -187,11 +187,11 @@ def filter_kwargs(
raise NotImplementedError

@property
def inputs(self) -> Dict[NodeView, Any]:
def inputs(self) -> Dict[NodeIdentity, Any]:
return self.init_kwargs

def _inputs_for_context(self, context: ChangeContext):
user_node_view = NodeView.from_change_context(context)
user_node_view = NodeIdentity.from_change_context(context)
inputs = self.inputs[user_node_view]

action_service = context.node.get_service("actionservice")
Expand Down Expand Up @@ -251,12 +251,12 @@ def allowed_ids_only(
context: AuthedServiceContext,
) -> Dict[str, UID]:
if context.node.node_type == NodeType.DOMAIN:
node_view = NodeView(
node_identity = NodeIdentity(
node_name=context.node.name,
node_id=context.node.id,
verify_key=context.node.signing_key.verify_key,
)
allowed_inputs = allowed_inputs[node_view]
allowed_inputs = allowed_inputs[node_identity]
elif context.node.node_type == NodeType.ENCLAVE:
base_dict = {}
for key in allowed_inputs.values():
Expand Down
42 changes: 5 additions & 37 deletions packages/syft/src/syft/service/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing_extensions import Self

# relative
from ...client.api import NodeIdentity
from ...client.client import SyftClient
from ...client.client import SyftClientSessionCache
from ...node.credentials import SyftSigningKey
Expand All @@ -33,11 +34,13 @@
from ...service.metadata.node_metadata import NodeMetadata
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
from ...types.identity import Identity
from ...types.identity import UserIdentity
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
from ...types.syft_object import short_qual_name
from ...types.transforms import TransformContext
from ...types.transforms import keep
from ...types.transforms import rename
from ...types.transforms import transform
from ...types.uid import UID
from ...util import options
Expand All @@ -60,44 +63,9 @@ class EventAlreadyAddedException(SyftException):
pass


class Identity(SyftObject):
__canonical_name__ = "Identity"
__version__ = SYFT_OBJECT_VERSION_1

id: UID
verify_key: SyftVerifyKey

__repr_attrs__ = ["id", "verify_key"]

def __repr__(self) -> str:
verify_key_str = f"{self.verify_key}"
return f"<🔑 {verify_key_str[0:8]} @ 🟢 {self.id.short()}>"

@classmethod
def from_client(cls, client: SyftClient) -> Identity:
return cls(id=client.id, verify_key=client.credentials.verify_key)


@serializable()
class NodeIdentity(Identity):
"""This class is used to identify the node owner"""

__canonical_name__ = "NodeIdentity"
__version__ = SYFT_OBJECT_VERSION_1


# Used to Identity data scientist users of the node
@serializable()
class UserIdentity(Identity):
"""This class is used to identify the data scientist users of the node"""

__canonical_name__ = "UserIdentity"
__version__ = SYFT_OBJECT_VERSION_1


@transform(NodeMetadata, NodeIdentity)
def metadata_to_node_identity() -> List[Callable]:
return [keep(["id", "verify_key"])]
return [rename("id", "node_id"), rename("name", "node_name")]


class ProjectEvent(SyftObject):
Expand Down
Loading

0 comments on commit 3e5e606

Please sign in to comment.