Skip to content

Commit

Permalink
Merge pull request #9021 from OpenMined/fix-blob-permission-small-data
Browse files (browse the repository at this point in the history)
Skip blob permission check if small data
  • Loading branch information
eelcovdw committed Jul 5, 2024
2 parents 4c97cb4 + 91ea0bf commit e2bd35e
Show file tree
Hide file tree
Showing 11 changed files with 347 additions and 81 deletions.
7 changes: 1 addition & 6 deletions packages/syft/src/syft/client/domain_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,8 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:

if isinstance(res, SyftWarning):
logger.debug(res.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False
response = self.api.services.action.set(
twin,
ignore_detached_objs=contains_empty,
skip_save_to_blob_store=skip_save_to_blob_store,
twin, ignore_detached_objs=contains_empty
)
if isinstance(response, SyftError):
tqdm.write(f"Failed to upload asset: {asset.name}")
Expand Down
56 changes: 56 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,62 @@
"hash": "3117e16cbe4dbc344ab90fbbd36ba90dfb518e66f0fb07644bbe7864dcdeb309",
"action": "add"
}
},
"ActionObject": {
"4": {
"version": 4,
"hash": "a4dd2949af0f516d0f640d28e0fdfa026ba8d55bb29eaa7844c926e467606892",
"action": "add"
}
},
"AnyActionObject": {
"4": {
"version": 4,
"hash": "809bd7ffab211133a9be87e058facecf870a79cb2d4027616f5244323de27091",
"action": "add"
}
},
"BlobFileOBject": {
"3": {
"version": 3,
"hash": "27901fcd545ad0607dbfcbfa0141ee03b0f0f4bee8d23f2d661a4b22011bfd37",
"action": "add"
}
},
"NumpyArrayObject": {
"4": {
"version": 4,
"hash": "19e2ff3da78038d2164f86d1f9b0d1facc6008483be60d2852458e90202bb96b",
"action": "add"
}
},
"NumpyScalarObject": {
"4": {
"version": 4,
"hash": "5101d00dd92ac4391cae77629eb48aa25401cc8c5ebb28a8a969cd5eba35fb67",
"action": "add"
}
},
"NumpyBoolObject": {
"4": {
"version": 4,
"hash": "764cd93792c4dfe27b8952fde853626592fe58e1a341b5350b23f38ce474583f",
"action": "add"
}
},
"PandasDataframeObject": {
"4": {
"version": 4,
"hash": "b70f4bb32ba9f3f5ea89552649bf882d927cf9085fb573cc6d4841b32d653f84",
"action": "add"
}
},
"PandasSeriesObject": {
"4": {
"version": 4,
"hash": "6b0eb1f4dd80b729b713953bacaf9c0ea436a4d4eeb2dc0efbd8bff654d91f86",
"action": "add"
}
}
}
}
Expand Down
96 changes: 83 additions & 13 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@
from ...store.linked_obj import LinkedObject
from ...types.base import SyftBaseModel
from ...types.datetime import DateTime
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_3
from ...types.syft_object import SYFT_OBJECT_VERSION_4
from ...types.syft_object import SyftBaseObject
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
from ...types.transforms import drop
from ...types.transforms import make_set_default
from ...types.uid import LineageID
from ...types.uid import UID
from ...util.util import prompt_warning_message
Expand Down Expand Up @@ -527,13 +531,7 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any:
print(r.message)
if isinstance(r, SyftWarning):
logger.debug(r.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False
arg = api.services.action.set(
arg,
skip_save_to_blob_store=skip_save_to_blob_store,
)
arg = api.services.action.set(arg)
return arg

arg_list = [process_arg(arg) for arg in args] if args else []
Expand Down Expand Up @@ -675,7 +673,7 @@ def truncate_str(string: str, length: int = 100) -> str:


@serializable(without=["syft_pre_hooks__", "syft_post_hooks__"])
class ActionObject(SyncableSyftObject):
class ActionObjectV3(SyncableSyftObject):
"""Action object for remote execution."""

__canonical_name__ = "ActionObject"
Expand Down Expand Up @@ -710,6 +708,45 @@ class ActionObject(SyncableSyftObject):
syft_created_at: DateTime | None = None
syft_resolved: bool = True
syft_action_data_node_id: UID | None = None


@serializable(without=["syft_pre_hooks__", "syft_post_hooks__"])
class ActionObject(SyncableSyftObject):
"""Action object for remote execution."""

__canonical_name__ = "ActionObject"
__version__ = SYFT_OBJECT_VERSION_4
__private_sync_attr_mocks__: ClassVar[dict[str, Any]] = {
"syft_action_data_cache": None,
"syft_blob_storage_entry_id": None,
}

__attr_searchable__: list[str] = [] # type: ignore[misc]
syft_action_data_cache: Any | None = None
syft_blob_storage_entry_id: UID | None = None
syft_pointer_type: ClassVar[type[ActionObjectPointer]]

# Help with calculating history hash for code verification
syft_parent_hashes: int | list[int] | None = None
syft_parent_op: str | None = None
syft_parent_args: Any | None = None
syft_parent_kwargs: Any | None = None
syft_history_hash: int | None = None
syft_internal_type: ClassVar[type[Any]]
syft_node_uid: UID | None = None
syft_pre_hooks__: dict[str, list] = {}
syft_post_hooks__: dict[str, list] = {}
syft_twin_type: TwinMode = TwinMode.NONE
syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS
syft_action_data_type: type | None = None
syft_action_data_repr_: str | None = None
syft_action_data_str_: str | None = None
syft_has_bool_attr: bool | None = None
syft_resolve_data: bool | None = None
syft_created_at: DateTime | None = None
syft_resolved: bool = True
syft_action_data_node_id: UID | None = None
syft_action_saved_to_blob_store: bool = True
# syft_dont_wrap_attrs = ["shape"]

def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]:
Expand Down Expand Up @@ -814,6 +851,7 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None:
if get_metadata is not None and not can_upload_to_blob_storage(
data, get_metadata()
):
self.syft_action_saved_to_blob_store = False
return SyftWarning(
message=f"The action object {self.id} was not saved to "
f"the blob store but to memory cache since it is small."
Expand Down Expand Up @@ -1247,13 +1285,9 @@ def _send(

if isinstance(blob_storage_res, SyftWarning):
logger.debug(blob_storage_res.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False
res = api.services.action.set(
self,
add_storage_permission=add_storage_permission,
skip_save_to_blob_store=skip_save_to_blob_store,
)
if isinstance(res, ActionObject):
self.syft_created_at = res.syft_created_at
Expand Down Expand Up @@ -2189,7 +2223,7 @@ def __rrshift__(self, other: Any) -> Any:


@serializable()
class AnyActionObject(ActionObject):
class AnyActionObjectV3(ActionObjectV3):
"""
This is a catch-all class for all objects that are not
defined in the `action_types` dictionary.
Expand All @@ -2203,6 +2237,22 @@ class AnyActionObject(ActionObject):
syft_dont_wrap_attrs: list[str] = ["__str__", "__repr__", "syft_action_data_str_"]
syft_action_data_str_: str = ""


@serializable()
class AnyActionObject(ActionObject):
"""
This is a catch-all class for all objects that are not
defined in the `action_types` dictionary.
"""

__canonical_name__ = "AnyActionObject"
__version__ = SYFT_OBJECT_VERSION_4

syft_internal_type: ClassVar[type[Any]] = NoneType # type: ignore
# syft_passthrough_attrs: List[str] = []
syft_dont_wrap_attrs: list[str] = ["__str__", "__repr__", "syft_action_data_str_"]
syft_action_data_str_: str = ""

def __float__(self) -> float:
return float(self.syft_action_data)

Expand Down Expand Up @@ -2238,3 +2288,23 @@ def has_action_data_empty(args: Any, kwargs: Any) -> bool:
if is_action_data_empty(a):
return True
return False


@migrate(ActionObjectV3, ActionObject)
def upgrade_action_object() -> list[Callable]:
    """Return the transform steps for the ActionObjectV3 -> ActionObject migration.

    The single step is built with ``make_set_default`` for the field
    ``syft_action_saved_to_blob_store`` and the value ``True``.
    """
    # NOTE(review): presumably fills the field in when it is missing on the
    # older object -- confirm against make_set_default's implementation.
    set_saved_flag = make_set_default("syft_action_saved_to_blob_store", True)
    return [set_saved_flag]


@migrate(ActionObject, ActionObjectV3)
def downgrade_action_object() -> list[Callable]:
    """Return the transform steps for the ActionObject -> ActionObjectV3 migration.

    The single step is built with ``drop`` for the field
    ``syft_action_saved_to_blob_store``.
    """
    # NOTE(review): presumably removes the field so the object fits the older
    # schema -- confirm against drop's implementation.
    drop_saved_flag = drop("syft_action_saved_to_blob_store")
    return [drop_saved_flag]


@migrate(AnyActionObjectV3, AnyActionObject)
def upgrade_anyaction_object() -> list[Callable]:
    """Return the transform steps for the AnyActionObjectV3 -> AnyActionObject migration.

    The single step is built with ``make_set_default`` for the field
    ``syft_action_saved_to_blob_store`` and the value ``True``.
    """
    # NOTE(review): presumably fills the field in when it is missing on the
    # older object -- confirm against make_set_default's implementation.
    set_saved_flag = make_set_default("syft_action_saved_to_blob_store", True)
    return [set_saved_flag]


@migrate(AnyActionObject, AnyActionObjectV3)
def downgrade_anyaction_object() -> list[Callable]:
    """Return the transform steps for the AnyActionObject -> AnyActionObjectV3 migration.

    The single step is built with ``drop`` for the field
    ``syft_action_saved_to_blob_store``.
    """
    # NOTE(review): presumably removes the field so the object fits the older
    # schema -- confirm against drop's implementation.
    drop_saved_flag = drop("syft_action_saved_to_blob_store")
    return [drop_saved_flag]
64 changes: 33 additions & 31 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,8 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any:
return blob_store_result
if isinstance(blob_store_result, SyftWarning):
logger.debug(blob_store_result.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False

np_pointer = self._set(
context,
np_obj,
skip_save_to_blob_store=skip_save_to_blob_store,
)
np_pointer = self._set(context, np_obj)
return np_pointer

@service_method(
Expand All @@ -97,15 +90,13 @@ def set(
action_object: ActionObject | TwinObject,
add_storage_permission: bool = True,
ignore_detached_objs: bool = False,
skip_save_to_blob_store: bool = False,
) -> ActionObject | SyftError:
res = self._set(
context,
action_object,
has_result_read_permission=True,
add_storage_permission=add_storage_permission,
ignore_detached_objs=ignore_detached_objs,
skip_save_to_blob_store=skip_save_to_blob_store,
)
if res.is_err():
return SyftError(message=res.value)
Expand All @@ -123,14 +114,22 @@ def is_detached_obj(
if (
isinstance(action_object, TwinObject)
and (
action_object.mock_obj.syft_blob_storage_entry_id is None
or action_object.private_obj.syft_blob_storage_entry_id is None
(
action_object.mock_obj.syft_action_saved_to_blob_store
and action_object.mock_obj.syft_blob_storage_entry_id is None
)
or (
action_object.private_obj.syft_action_saved_to_blob_store
and action_object.private_obj.syft_blob_storage_entry_id is None
)
)
and not ignore_detached_obj
):
return True
if isinstance(action_object, ActionObject) and (
action_object.syft_blob_storage_entry_id is None and not ignore_detached_obj
action_object.syft_action_saved_to_blob_store
and action_object.syft_blob_storage_entry_id is None
and not ignore_detached_obj
):
return True
return False
Expand All @@ -142,12 +141,8 @@ def _set(
has_result_read_permission: bool = False,
add_storage_permission: bool = True,
ignore_detached_objs: bool = False,
skip_save_to_blob_store: bool = False,
) -> Result[ActionObject, str]:
if (
self.is_detached_obj(action_object, ignore_detached_objs)
and not skip_save_to_blob_store
):
if self.is_detached_obj(action_object, ignore_detached_objs):
return Err(
"You uploaded an ActionObject that is not yet in the blob storage"
)
Expand All @@ -156,14 +151,26 @@ def _set(

if isinstance(action_object, ActionObject):
action_object.syft_created_at = DateTime.now()
if not skip_save_to_blob_store:
(
action_object._clear_cache()
if action_object.syft_action_saved_to_blob_store
else None
)
else: # TwinObject
action_object.private_obj.syft_created_at = DateTime.now() # type: ignore[unreachable]
action_object.mock_obj.syft_created_at = DateTime.now()
if not skip_save_to_blob_store:

# Clear cache if data is saved to blob storage
(
action_object.private_obj._clear_cache()
if action_object.private_obj.syft_action_saved_to_blob_store
else None
)
(
action_object.mock_obj._clear_cache()
if action_object.mock_obj.syft_action_saved_to_blob_store
else None
)

# If either context or argument is True, has_result_read_permission is True
has_result_read_permission = (
Expand All @@ -186,7 +193,9 @@ def _set(
blob_storage_service: AbstractService = context.node.get_service(
BlobStorageService
)
blob_storage_service.stash.add_permission(permission)
# if mock is saved to blob store, then add READ permission
if action_object.mock_obj.syft_action_saved_to_blob_store:
blob_storage_service.stash.add_permission(permission)
if has_result_read_permission:
action_object = action_object.private
else:
Expand Down Expand Up @@ -527,9 +536,6 @@ def set_result_to_store(
return Err(blob_store_result.message)
if isinstance(blob_store_result, SyftWarning):
logger.debug(blob_store_result.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False

# IMPORTANT: DO THIS ONLY AFTER ._save_to_blob_storage
if isinstance(result_action_object, TwinObject):
Expand All @@ -545,7 +551,6 @@ def set_result_to_store(
context,
result_action_object,
has_result_read_permission=True,
skip_save_to_blob_store=skip_save_to_blob_store,
)

if set_result.is_err():
Expand All @@ -569,8 +574,9 @@ def blob_permission(
store_permissions = [store_permission(x) for x in output_readers]
self.store.add_permissions(store_permissions)

blob_permissions = [blob_permission(x) for x in output_readers]
blob_storage_service.stash.add_permissions(blob_permissions)
if result_blob_id is not None:
blob_permissions = [blob_permission(x) for x in output_readers]
blob_storage_service.stash.add_permissions(blob_permissions)

return set_result

Expand Down Expand Up @@ -814,13 +820,9 @@ def execute(
}
if isinstance(blob_store_result, SyftWarning):
logger.debug(blob_store_result.message)
skip_save_to_blob_store = True
else:
skip_save_to_blob_store = False
set_result = self._set(
context,
result_action_object,
skip_save_to_blob_store=skip_save_to_blob_store,
)
if set_result.is_err():
return Err(
Expand Down
Loading

0 comments on commit e2bd35e

Please sign in to comment.