diff --git a/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb b/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb index 895f59e2457..e8a339389c3 100644 --- a/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb +++ b/notebooks/tutorials/data-scientist/03-working-with-private-datasets.ipynb @@ -543,7 +543,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.9.16" }, "toc": { "base_numbering": 1, diff --git a/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb b/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb index 0dfb05f035a..c5f4725257a 100644 --- a/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb +++ b/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb @@ -416,7 +416,7 @@ "metadata": {}, "outputs": [], "source": [ - "@sy.syft_function_single_use(canada_census_data=canada_census_data, italy_census_data=italy_census_data)\n", + "@sy.syft_function_single_use(canada_census_data=canada_census_data, italy_census_data=italy_census_data, share_results_with_owners=True)\n", "def compute_census_matches(canada_census_data, italy_census_data):\n", " import recordlinkage\n", " \n", @@ -543,10 +543,20 @@ "metadata": {}, "outputs": [], "source": [ - "for st in status.base_dict.values():\n", + "for st in status.status_dict.values():\n", " assert st == sy.service.request.request.UserCodeStatus.EXECUTE" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "43538640", + "metadata": {}, + "outputs": [], + "source": [ + "ds_enclave_proxy_client.code[-1].output_policy" + ] + }, { "cell_type": "code", "execution_count": null, @@ -601,10 +611,58 @@ "assert real_result == 813" ] }, + { + "cell_type": "markdown", + "id": "0c186d96", + "metadata": {}, + "source": [ + "# DO" + ] + }, + { + "cell_type": "markdown", + "id": "92a07f21", + "metadata": {}, + "source": [ + "## Can also get the result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a0cc302", + "metadata": {}, + "outputs": [], + "source": [ + "request = do_ca_client.requests[0]\n", + "request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc567390", + "metadata": {}, + "outputs": [], + "source": [ + "result_ptr = request.get_results()\n", + "result_ptr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3715aa1", + "metadata": {}, + "outputs": [], + "source": [ + "assert result_ptr.syft_action_data == 813" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "c7e8c775-400a-46ce-ba3c-58ca0563621e", + "id": "1beca4ac", "metadata": {}, "outputs": [], "source": [] @@ -626,7 +684,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.9.16" }, "toc": { "base_numbering": 1, diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index cb5ddf38005..ff23dbda1de 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -124,6 +124,8 @@ def data_subject_registry(self) -> Optional[APIModule]: @property def code(self) -> Optional[APIModule]: + # if self.api.refresh_api_callback is not None: + # self.api.refresh_api_callback() if self.api.has_service("code"): return self.api.services.code return None diff --git a/packages/syft/src/syft/client/enclave_client.py b/packages/syft/src/syft/client/enclave_client.py index 20c5609be66..5f3baae3592 100644 --- a/packages/syft/src/syft/client/enclave_client.py +++ b/packages/syft/src/syft/client/enclave_client.py @@ -122,7 +122,9 @@ def request_code_execution(self, code: SubmitUserCode): apis += [api] for api in apis: - api.services.code.request_code_execution(code=code) + res = api.services.code.request_code_execution(code=code) + if isinstance(res, SyftError): + return res # we are using the real method here, see the .code property getter _ = self.code diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 999283604d7..1819b4d0ce4 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -33,7 +33,9 @@ from .action_object import ActionType from .action_object import AnyActionObject from .action_object import TwinMode +from .action_permissions import ActionObjectPermission from .action_permissions import ActionObjectREAD +from .action_permissions import ActionPermission from .action_store import ActionStore from .action_types import action_type_for_type from .numpy import NumpyArrayObject @@ -211,8 +213,18 @@ def _user_code_execute( syft_object=result_action_object, has_result_read_permission=True, ) + if set_result.is_err(): return set_result.err() + + if len(code_item.output_policy.output_readers) > 0: + self.store.add_permissions( + [ + ActionObjectPermission(result_id, ActionPermission.READ, x) + for x in code_item.output_policy.output_readers + ] + ) + return Ok(result_action_object) def execute_plan( diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 8482e34b40d..c4464f134df 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -56,6 +56,8 @@ from ..policy.policy import load_policy_code from ..policy.policy_service import PolicyService from ..response import SyftError +from ..response import SyftNotReady +from ..response import SyftSuccess from .code_parse import GlobalsVisitor from .unparse import unparse @@ -100,15 +102,15 @@ def __hash__(self) -> int: # User Code status context for multiple approvals # To make nested dicts hashable for mongodb # as status is in attr_searchable -@serializable(attrs=["base_dict"]) -class UserCodeStatusContext(SyftHashableObject): - base_dict: Dict = {} +@serializable(attrs=["status_dict"]) +class UserCodeStatusCollection(SyftHashableObject): + status_dict: Dict[NodeIdentity, UserCodeStatus] = {} - def __init__(self, base_dict: Dict): - self.base_dict = base_dict + def __init__(self, status_dict: Dict): + self.status_dict = status_dict def __repr__(self): - return str(self.base_dict) + return str(self.status_dict) def _repr_html_(self): string = f""" @@ -119,7 +121,7 @@ def _repr_html_(self):
"""
- for node_identity, status in self.base_dict.items():
+ for node_identity, status in self.status_dict.items():
node_name_str = f"{node_identity.node_name}"
uid_str = f"{node_identity.node_id}"
status_str = f"{status.value}"
@@ -135,19 +137,34 @@ def _repr_html_(self):
def __repr_syft_nested__(self):
string = ""
- for node_identity, status in self.base_dict.items():
+ for node_identity, status in self.status_dict.items():
string += f"{node_identity.node_name}: {status}
"
return string
+ def get_status_message(self):
+ if self.approved:
+ return SyftSuccess(message=f"{type(self)} approved")
+ string = ""
+ for node_identity, status in self.status_dict.items():
+ string += f"Code status on node '{node_identity.node_name}' is '{status}'. "
+ if self.denied:
+ return SyftError(message=f"{type(self)} Your code cannot be run: {string}")
+ else:
+ return SyftNotReady(
+ message=f"{type(self)} Your code is waiting for approval. {string}"
+ )
+
@property
def approved(self) -> bool:
- # approved for this node only
- statuses = set(self.base_dict.values())
- return len(statuses) == 1 and UserCodeStatus.EXECUTE in statuses
+ return all([x == UserCodeStatus.EXECUTE for x in self.status_dict.values()])
+
+ @property
+ def denied(self) -> bool:
+ return UserCodeStatus.DENIED in self.status_dict.values()
- def for_context(self, context: AuthedServiceContext) -> UserCodeStatus:
+ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus:
if context.node.node_type == NodeType.ENCLAVE:
- keys = set(self.base_dict.values())
+ keys = set(self.status_dict.values())
if len(keys) == 1 and UserCodeStatus.EXECUTE in keys:
return UserCodeStatus.EXECUTE
elif UserCodeStatus.SUBMITTED in keys and UserCodeStatus.DENIED not in keys:
@@ -163,8 +180,8 @@ def for_context(self, context: AuthedServiceContext) -> UserCodeStatus:
node_id=context.node.id,
verify_key=context.node.signing_key.verify_key,
)
- if node_identity in self.base_dict:
- return self.base_dict[node_identity]
+ if node_identity in self.status_dict:
+ return self.status_dict[node_identity]
else:
raise Exception(
f"Code Object does not contain {context.node.name} Domain's data"
@@ -180,10 +197,10 @@ def mutate(
node_identity = NodeIdentity(
node_name=node_name, node_id=node_id, verify_key=verify_key
)
- base_dict = self.base_dict
- if node_identity in base_dict:
- base_dict[node_identity] = value
- self.base_dict = base_dict
+ status_dict = self.status_dict
+ if node_identity in status_dict:
+ status_dict[node_identity] = value
+ self.status_dict = status_dict
return self
else:
return SyftError(
@@ -213,13 +230,13 @@ class UserCode(SyftObject):
user_unique_func_name: str
code_hash: str
signature: inspect.Signature
- status: UserCodeStatusContext
+ status: UserCodeStatusCollection
input_kwargs: List[str]
enclave_metadata: Optional[EnclaveMetadata] = None
__attr_searchable__ = ["user_verify_key", "status", "service_func_name"]
__attr_unique__ = ["code_hash", "user_unique_func_name"]
- __repr_attrs__ = ["status.approved", "service_func_name", "shareholders"]
+ __repr_attrs__ = ["status.approved", "service_func_name", "input_owners"]
def __setattr__(self, key: str, value: Any) -> None:
attr = getattr(type(self), key, None)
@@ -229,7 +246,7 @@ def __setattr__(self, key: str, value: Any) -> None:
return super().__setattr__(key, value)
def _coll_repr_(self) -> Dict[str, Any]:
- status = list(self.status.base_dict.values())[0].value
+ status = list(self.status.status_dict.values())[0].value
if status == UserCodeStatus.SUBMITTED.value:
badge_color = "badge-purple"
elif status == UserCodeStatus.EXECUTE.value:
@@ -249,12 +266,26 @@ def _coll_repr_(self) -> Dict[str, Any]:
}
@property
- def shareholders(self) -> List[str]:
- node_names_list = []
- nodes = self.input_policy_init_kwargs.keys()
- for node_identity in nodes:
- node_names_list.append(str(node_identity.node_name))
- return node_names_list
+ def is_enclave_code(self) -> bool:
+ return self.enclave_metadata is not None
+
+ @property
+ def input_owners(self) -> List[str]:
+ return [str(x.node_name) for x in self.input_policy_init_kwargs.keys()]
+
+ @property
+ def input_owner_verify_keys(self) -> List[SyftVerifyKey]:
+ return [x.verify_key for x in self.input_policy_init_kwargs.keys()]
+
+ @property
+ def output_reader_names(self) -> List[SyftVerifyKey]:
+ keys = self.output_policy_init_kwargs.get("output_readers", [])
+ inpkey2name = {x.verify_key: x.node_name for x in self.input_policy_init_kwargs}
+ return [inpkey2name[k] for k in keys if k in inpkey2name]
+
+ @property
+ def output_readers(self) -> List[SyftVerifyKey]:
+ return self.output_policy_init_kwargs.get("output_readers", [])
@property
def input_policy(self) -> Optional[InputPolicy]:
@@ -356,6 +387,13 @@ def output_policy(self, value: Any) -> None:
def byte_code(self) -> Optional[PyCodeObject]:
return compile_byte_code(self.parsed_code)
+ def get_results(self) -> Any:
+ # relative
+ from ...client.api import APIRegistry
+
+ api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key)
+ return api.services.code.get_results(self)
+
@property
def assets(self) -> List[Asset]:
# relative
@@ -418,11 +456,20 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable:
return wrapper
def _repr_markdown_(self):
+ shared_with_line = ""
+ if len(self.output_readers) > 0:
+ owners_string = " and ".join([f"*{x}*" for x in self.output_reader_names])
+ shared_with_line += (
+ f"Custom Policy: "
+ f"outputs are *shared* with the owners of {owners_string} once computed"
+ )
+
md = f"""class UserCode
id: UID = {self.id}
status.approved: bool = {self.status.approved}
service_func_name: str = {self.service_func_name}
- shareholders: list = {self.shareholders}
+ shareholders: list = {self.input_owners}
+ {shared_with_line}
code:
{self.raw_code}"""
@@ -487,6 +534,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
else:
raise NotImplementedError
+ @property
+ def input_owner_verify_keys(self) -> List[str]:
+ return [x.verify_key for x in self.input_policy_init_kwargs.keys()]
+
class ArgumentType(Enum):
REAL = 1
@@ -507,16 +558,20 @@ def debox_asset(arg: Any) -> Any:
return deboxed_arg, ArgumentType.REAL
-def syft_function_single_use(*args: Any, **kwargs: Any):
+def syft_function_single_use(
+ *args: Any, share_results_with_owners=False, **kwargs: Any
+):
return syft_function(
input_policy=ExactMatch(*args, **kwargs),
output_policy=SingleExecutionExactOutput(),
+ share_results_with_owners=share_results_with_owners,
)
def syft_function(
input_policy: Union[InputPolicy, UID],
output_policy: Optional[Union[OutputPolicy, UID]] = None,
+ share_results_with_owners=False,
) -> SubmitUserCode:
if isinstance(input_policy, CustomInputPolicy):
input_policy_type = SubmitUserPolicy.from_obj(input_policy)
@@ -537,7 +592,7 @@ def decorator(f):
f"To add a code request, please create a project using `project = syft.Project(...)`, "
f"then use command `project.create_code_request`."
)
- return SubmitUserCode(
+ res = SubmitUserCode(
code=inspect.getsource(f),
func_name=f.__name__,
signature=inspect.signature(f),
@@ -549,6 +604,12 @@ def decorator(f):
input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount],
)
+ if share_results_with_owners:
+ res.output_policy_init_kwargs[
+ "output_readers"
+ ] = res.input_owner_verify_keys
+ return res
+
return decorator
@@ -697,8 +758,8 @@ def add_custom_status(context: TransformContext) -> TransformContext:
node_id=context.node.id,
verify_key=context.node.signing_key.verify_key,
)
- context.output["status"] = UserCodeStatusContext(
- base_dict={node_identity: UserCodeStatus.SUBMITTED}
+ context.output["status"] = UserCodeStatusCollection(
+ status_dict={node_identity: UserCodeStatus.SUBMITTED}
)
# if node_identity in input_keys or len(input_keys) == 0:
# context.output["status"] = UserCodeStatusContext(
@@ -707,8 +768,8 @@ def add_custom_status(context: TransformContext) -> TransformContext:
# else:
# raise ValueError(f"Invalid input keys: {input_keys} for {node_identity}")
elif context.node.node_type == NodeType.ENCLAVE:
- base_dict = {key: UserCodeStatus.SUBMITTED for key in input_keys}
- context.output["status"] = UserCodeStatusContext(base_dict=base_dict)
+ status_dict = {key: UserCodeStatus.SUBMITTED for key in input_keys}
+ context.output["status"] = UserCodeStatusCollection(status_dict=status_dict)
else:
raise NotImplementedError(
f"Invalid node type:{context.node.node_type} for code submission"
diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py
index 3a171470f01..def326b3973 100644
--- a/packages/syft/src/syft/service/code/user_code_service.py
+++ b/packages/syft/src/syft/service/code/user_code_service.py
@@ -10,6 +10,8 @@
from result import Result
# relative
+from ...abstract_node import NodeType
+from ...client.enclave_client import EnclaveClient
from ...serde.serializable import serializable
from ...store.document_store import DocumentStore
from ...store.linked_obj import LinkedObject
@@ -17,13 +19,14 @@
from ...types.uid import UID
from ...util.telemetry import instrument
from ..action.action_object import ActionObject
+from ..action.action_permissions import ActionObjectPermission
+from ..action.action_permissions import ActionPermission
from ..context import AuthedServiceContext
-from ..policy.policy import OutputHistory
+from ..network.routes import route_to_connection
from ..request.request import SubmitRequest
from ..request.request import UserCodeStatusChange
from ..request.request_service import RequestService
from ..response import SyftError
-from ..response import SyftNotReady
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import SERVICE_TO_TYPES
@@ -63,11 +66,23 @@ def _request_code_execution(
code: SubmitUserCode,
reason: Optional[str] = "",
):
- user_code = code.to(UserCode, context=context)
+ user_code: UserCode = code.to(UserCode, context=context)
+ if not all(
+ [x in user_code.input_owner_verify_keys for x in user_code.output_readers]
+ ):
+ raise ValueError("outputs can only be distributed to input owners")
result = self.stash.set(context.credentials, user_code)
if result.is_err():
return SyftError(message=str(result.err()))
+ # Users that have access to the output also have access to the code item
+ self.stash.add_permissions(
+ [
+ ActionObjectPermission(user_code.id, ActionPermission.READ, x)
+ for x in user_code.output_readers
+ ]
+ )
+
linked_obj = LinkedObject.from_obj(user_code, node_uid=context.node.id)
CODE_EXECUTE = UserCodeStatusChange(
@@ -145,77 +160,99 @@ def load_user_code(self, context: AuthedServiceContext) -> None:
user_code_items = result.ok()
load_approved_policy_code(user_code_items=user_code_items)
+ @service_method(path="code.get_results", name="get_results", roles=GUEST_ROLE_LEVEL)
+ def get_results(
+ self, context: AuthedServiceContext, inp: Union[UID, UserCode]
+ ) -> Union[List[UserCode], SyftError]:
+ uid = inp.id if isinstance(inp, UserCode) else inp
+ code_result = self.stash.get_by_uid(context.credentials, uid=uid)
+
+ if code_result.is_err():
+ return SyftError(message=code_result.err())
+ code = code_result.ok()
+
+ if code.is_enclave_code:
+ # if the current node is not the enclave
+ if not context.node.node_type == NodeType.ENCLAVE:
+ connection = route_to_connection(code.enclave_metadata.route)
+ enclave_client = EnclaveClient(
+ connection=connection,
+ credentials=context.node.signing_key,
+ )
+ return enclave_client.code.get_results(code.id)
+
+ # if the current node is the enclave
+ else:
+ if not code.status.approved:
+ return code.status.get_status_message()
+
+ if (output_policy := code.output_policy) is None:
+ return SyftError(message=f"Output policy not approved {code}")
+
+ if len(output_policy.output_history) > 0:
+ return resolve_outputs(
+ context=context, output_ids=output_policy.last_output_ids
+ )
+ else:
+ return SyftError(message="No results available")
+ else:
+ return SyftError(message="Endpoint only supported for enclave code")
+
@service_method(path="code.call", name="call", roles=GUEST_ROLE_LEVEL)
def call(
self, context: AuthedServiceContext, uid: UID, **kwargs: Any
) -> Union[SyftSuccess, SyftError]:
"""Call a User Code Function"""
try:
- filtered_kwargs = filter_kwargs(kwargs)
- result = self.stash.get_by_uid(context.credentials, uid=uid)
- if not result.is_ok():
- return SyftError(message=result.err())
-
# Unroll variables
- code_item = result.ok()
- code_status = code_item.status
+ kwarg2id = map_kwargs_to_id(kwargs)
- # Check if the user has permission to execute the code
- # They can execute if they are root user or if they are the user who submitted the code
- if not (
- context.credentials == context.node.verify_key
- or context.credentials == code_item.user_verify_key
- ):
- return SyftError(
- message=f"Code Execution Permission: {context.credentials} denied"
- )
+ # get code item
+ code_result = self.stash.get_by_uid(context.credentials, uid=uid)
+ if code_result.is_err():
+ return SyftError(message=code_result.err())
+ code: UserCode = code_result.ok()
- # Check if the code is approved
- if code_status.for_context(context) != UserCodeStatus.EXECUTE:
- if code_status.for_context(context) == UserCodeStatus.SUBMITTED:
- string = ""
- for node_identity, status in code_status.base_dict.items():
- string += f"Code status on node '{node_identity.node_name}' is '{status.value}'. "
- return SyftNotReady(
- message=f"{type(code_item)} Your code is waiting for approval. {string}"
- )
- return SyftError(
- message=f"{type(code_item)} Your code cannot be run: {code_status.for_context(context)}"
- )
+ if not code.status.approved:
+ return code.status.get_status_message()
- output_policy = code_item.output_policy
- if output_policy is None:
- raise Exception("Output policy not approved", code_item)
+ # Check if the user has permission to execute the code.
+ if not (has_code_permission := self.has_code_permission(code, context)):
+ return has_code_permission
- # Check if the OutputPolicy is valid
- is_valid = output_policy.valid
+ if (output_policy := code.output_policy) is None:
+ return SyftError("Output policy not approved", code)
- if not is_valid:
+ # Check if the OutputPolicy is valid
+ if not (is_valid := output_policy.valid):
if len(output_policy.output_history) > 0:
- result = get_outputs(
- context=context,
- output_history=output_policy.output_history[-1],
+ result = resolve_outputs(
+ context=context, output_ids=output_policy.last_output_ids
)
return result.as_empty()
return is_valid
# Execute the code item
action_service = context.node.get_service("actionservice")
- result: Result = action_service._user_code_execute(
- context, code_item, filtered_kwargs
- )
- if isinstance(result, str):
- return SyftError(message=result)
+
+ output_result: Result[
+ Union[ActionObject, TwinObject], str
+ ] = action_service._user_code_execute(context, code, kwarg2id)
+
+ if output_result.is_err():
+ return SyftError(message=output_result.err())
+ result = output_result.ok()
# Apply Output Policy to the results and update the OutputPolicyState
- result: Union[ActionObject, TwinObject] = result.ok()
output_policy.apply_output(context=context, outputs=result)
- code_item.output_policy = output_policy
- update_success = self.update_code_state(
- context=context, code_item=code_item
- )
- if not update_success:
+ code.output_policy = output_policy
+ if not (
+ update_success := self.update_code_state(
+ context=context, code_item=code
+ )
+ ):
return update_success
+
if isinstance(result, TwinObject):
return result.mock
else:
@@ -223,16 +260,29 @@ def call(
except Exception as e:
return SyftError(message=f"Failed to run. {e}")
+ def has_code_permission(self, code_item, context):
+ if not (
+ context.credentials == context.node.verify_key
+ or context.credentials == code_item.user_verify_key
+ ):
+ return SyftError(
+ message=f"Code Execution Permission: {context.credentials} denied"
+ )
+ return SyftSuccess(message="you have permission")
+
-def get_outputs(context: AuthedServiceContext, output_history: OutputHistory) -> Any:
+def resolve_outputs(
+ context: AuthedServiceContext,
+ output_ids: Optional[Union[List[UID], Dict[str, UID]]],
+) -> Any:
# relative
from ...service.action.action_object import TwinMode
- if isinstance(output_history.outputs, list):
- if len(output_history.outputs) == 0:
+ if isinstance(output_ids, list):
+ if len(output_ids) == 0:
return None
outputs = []
- for output_id in output_history.outputs:
+ for output_id in output_ids:
action_service = context.node.get_service("actionservice")
result = action_service.get(
context, uid=output_id, twin_mode=TwinMode.PRIVATE
@@ -247,7 +297,7 @@ def get_outputs(context: AuthedServiceContext, output_history: OutputHistory) ->
raise NotImplementedError
-def filter_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+def map_kwargs_to_id(kwargs: Dict[str, Any]) -> Dict[str, Any]:
# relative
from ...types.twin_object import TwinObject
from ..action.action_object import ActionObject
diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py
index 39d371dbe53..092d08b9cb3 100644
--- a/packages/syft/src/syft/service/policy/policy.py
+++ b/packages/syft/src/syft/service/policy/policy.py
@@ -319,6 +319,7 @@ class OutputPolicy(Policy):
output_history: List[OutputHistory] = []
output_kwargs: List[str] = []
node_uid: Optional[UID]
+ output_readers: List[SyftVerifyKey] = []
def apply_output(
self,
@@ -340,6 +341,10 @@ def apply_output(
def outputs(self) -> List[str]:
return self.output_kwargs
+ @property
+ def last_output_ids(self) -> List[str]:
+ return self.output_history[-1].outputs
+
@serializable()
class OutputPolicyExecuteCount(OutputPolicy):
diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py
index 3bf4cfc10c0..2337351d7da 100644
--- a/packages/syft/src/syft/service/request/request.py
+++ b/packages/syft/src/syft/service/request/request.py
@@ -191,6 +191,17 @@ def _repr_html_(self) -> Any:
self.node_uid,
self.syft_client_verify_key,
)
+ shared_with_line = ""
+ if self.code and len(self.code.output_readers) > 0:
+ # owner_names = ["canada", "US"]
+ owners_string = " and ".join(
+ [f"{x}" for x in self.code.output_reader_names]
+ )
+ shared_with_line += (
+ f"
Custom Policy: " + f"outputs are shared with the owners of {owners_string} once computed" + ) + metadata = api.services.metadata.get_metadata() admin_email = metadata.admin_email node_name = api.node_name.capitalize() if api.node_name is not None else "" @@ -203,6 +214,7 @@ def _repr_html_(self) -> Any:
Id: {self.id}
Request time: {self.request_time}
{updated_at_line} + {shared_with_line}Changes: {str_changes}
Status: {self.status}
Requested on: {node_name} of type \ @@ -240,6 +252,9 @@ def code(self) -> Any: message="This type of request does not have code associated with it." ) + def get_results(self) -> Any: + return self.code.get_results() + @property def current_change_state(self) -> Dict[UID, bool]: change_applied_map = {} @@ -742,8 +757,11 @@ def mutate(self, obj: UserCode, context: ChangeContext, undo: bool) -> Any: return obj return res - def is_enclave_request(self, req_enclave_metadata): - return req_enclave_metadata is not None and self.value == UserCodeStatus.EXECUTE + def is_enclave_request(self, user_code: UserCode): + return ( + user_code.is_enclave_code is not None + and self.value == UserCodeStatus.EXECUTE + ) def _run( self, context: ChangeContext, apply: bool @@ -766,7 +784,7 @@ def _run( from ..enclave.enclave_service import propagate_inputs_to_enclave user_code = res - if self.is_enclave_request(user_code.enclave_metadata): + if self.is_enclave_request(user_code): enclave_res = propagate_inputs_to_enclave( user_code=res, context=context ) diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py index 25b1269619a..82c3544f752 100644 --- a/packages/syft/src/syft/service/response.py +++ b/packages/syft/src/syft/service/response.py @@ -56,6 +56,8 @@ def _repr_html_class_(self) -> str: @serializable() class SyftNotReady(SyftResponseMessage): + _bool: bool = False + @property def _repr_html_class_(self) -> str: return "alert-info" diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 087830a4fc2..428193f8085 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -673,6 +673,9 @@ def get_by_uid( qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) return self.query_one(credentials=credentials, qks=qks) + def add_permissions(self, permissions: List[ActionObjectPermission]) -> None: + self.partition.add_permissions(permissions) + def set( self, credentials: SyftVerifyKey,