diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 9273ae15e5e..4d370783587 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -148,13 +148,8 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: if isinstance(res, SyftWarning): logger.debug(res.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False response = self.api.services.action.set( - twin, - ignore_detached_objs=contains_empty, - skip_save_to_blob_store=skip_save_to_blob_store, + twin, ignore_detached_objs=contains_empty ) if isinstance(response, SyftError): tqdm.write(f"Failed to upload asset: {asset.name}") diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 44b028c30ab..d2ab6739c4c 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -430,6 +430,62 @@ "hash": "3117e16cbe4dbc344ab90fbbd36ba90dfb518e66f0fb07644bbe7864dcdeb309", "action": "add" } + }, + "ActionObject": { + "4": { + "version": 4, + "hash": "a4dd2949af0f516d0f640d28e0fdfa026ba8d55bb29eaa7844c926e467606892", + "action": "add" + } + }, + "AnyActionObject": { + "4": { + "version": 4, + "hash": "809bd7ffab211133a9be87e058facecf870a79cb2d4027616f5244323de27091", + "action": "add" + } + }, + "BlobFileOBject": { + "3": { + "version": 3, + "hash": "27901fcd545ad0607dbfcbfa0141ee03b0f0f4bee8d23f2d661a4b22011bfd37", + "action": "add" + } + }, + "NumpyArrayObject": { + "4": { + "version": 4, + "hash": "19e2ff3da78038d2164f86d1f9b0d1facc6008483be60d2852458e90202bb96b", + "action": "add" + } + }, + "NumpyScalarObject": { + "4": { + "version": 4, + "hash": "5101d00dd92ac4391cae77629eb48aa25401cc8c5ebb28a8a969cd5eba35fb67", + "action": "add" + } + }, + "NumpyBoolObject": { + "4": { + "version": 4, + "hash": "764cd93792c4dfe27b8952fde853626592fe58e1a341b5350b23f38ce474583f", + "action": "add" + } + }, + "PandasDataframeObject": { + "4": { + "version": 4, + "hash": "b70f4bb32ba9f3f5ea89552649bf882d927cf9085fb573cc6d4841b32d653f84", + "action": "add" + } + }, + "PandasSeriesObject": { + "4": { + "version": 4, + "hash": "6b0eb1f4dd80b729b713953bacaf9c0ea436a4d4eeb2dc0efbd8bff654d91f86", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index b9ffd16ebf6..e878eb1bad2 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -43,11 +43,15 @@ from ...store.linked_obj import LinkedObject from ...types.base import SyftBaseModel from ...types.datetime import DateTime +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syft_object import SyftBaseObject from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default from ...types.uid import LineageID from ...types.uid import UID from ...util.util import prompt_warning_message @@ -527,13 +531,7 @@ def process_arg(arg: ActionObject | Asset | UID | Any) -> Any: print(r.message) if isinstance(r, SyftWarning): logger.debug(r.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False - arg = api.services.action.set( - arg, - skip_save_to_blob_store=skip_save_to_blob_store, - ) + arg = api.services.action.set(arg) return arg arg_list = [process_arg(arg) for arg in args] if args else [] @@ -675,7 +673,7 @@ def truncate_str(string: str, length: int = 100) -> str: @serializable(without=["syft_pre_hooks__", "syft_post_hooks__"]) -class ActionObject(SyncableSyftObject): +class ActionObjectV3(SyncableSyftObject): """Action object for remote execution.""" __canonical_name__ = "ActionObject" @@ -710,6 +708,45 @@ class ActionObject(SyncableSyftObject): syft_created_at: DateTime | None = None syft_resolved: bool = True syft_action_data_node_id: UID | None = None + + +@serializable(without=["syft_pre_hooks__", "syft_post_hooks__"]) +class ActionObject(SyncableSyftObject): + """Action object for remote execution.""" + + __canonical_name__ = "ActionObject" + __version__ = SYFT_OBJECT_VERSION_4 + __private_sync_attr_mocks__: ClassVar[dict[str, Any]] = { + "syft_action_data_cache": None, + "syft_blob_storage_entry_id": None, + } + + __attr_searchable__: list[str] = [] # type: ignore[misc] + syft_action_data_cache: Any | None = None + syft_blob_storage_entry_id: UID | None = None + syft_pointer_type: ClassVar[type[ActionObjectPointer]] + + # Help with calculating history hash for code verification + syft_parent_hashes: int | list[int] | None = None + syft_parent_op: str | None = None + syft_parent_args: Any | None = None + syft_parent_kwargs: Any | None = None + syft_history_hash: int | None = None + syft_internal_type: ClassVar[type[Any]] + syft_node_uid: UID | None = None + syft_pre_hooks__: dict[str, list] = {} + syft_post_hooks__: dict[str, list] = {} + syft_twin_type: TwinMode = TwinMode.NONE + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_action_data_type: type | None = None + syft_action_data_repr_: str | None = None + syft_action_data_str_: str | None = None + syft_has_bool_attr: bool | None = None + syft_resolve_data: bool | None = None + syft_created_at: DateTime | None = None + syft_resolved: bool = True + syft_action_data_node_id: UID | None = None + syft_action_saved_to_blob_store: bool = True # syft_dont_wrap_attrs = ["shape"] def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: @@ -814,6 +851,7 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | SyftWarning | None: if get_metadata is not None and not can_upload_to_blob_storage( data, get_metadata() ): + self.syft_action_saved_to_blob_store = False return SyftWarning( message=f"The action object {self.id} was not saved to " f"the blob store but to memory cache since it is small." @@ -1247,13 +1285,9 @@ def _send( if isinstance(blob_storage_res, SyftWarning): logger.debug(blob_storage_res.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False res = api.services.action.set( self, add_storage_permission=add_storage_permission, - skip_save_to_blob_store=skip_save_to_blob_store, ) if isinstance(res, ActionObject): self.syft_created_at = res.syft_created_at @@ -2189,7 +2223,7 @@ def __rrshift__(self, other: Any) -> Any: @serializable() -class AnyActionObject(ActionObject): +class AnyActionObjectV3(ActionObjectV3): """ This is a catch-all class for all objects that are not defined in the `action_types` dictionary. @@ -2203,6 +2237,22 @@ class AnyActionObject(ActionObject): syft_dont_wrap_attrs: list[str] = ["__str__", "__repr__", "syft_action_data_str_"] syft_action_data_str_: str = "" + +@serializable() +class AnyActionObject(ActionObject): + """ + This is a catch-all class for all objects that are not + defined in the `action_types` dictionary. + """ + + __canonical_name__ = "AnyActionObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type[Any]] = NoneType # type: ignore + # syft_passthrough_attrs: List[str] = [] + syft_dont_wrap_attrs: list[str] = ["__str__", "__repr__", "syft_action_data_str_"] + syft_action_data_str_: str = "" + def __float__(self) -> float: return float(self.syft_action_data) @@ -2238,3 +2288,23 @@ def has_action_data_empty(args: Any, kwargs: Any) -> bool: if is_action_data_empty(a): return True return False + + +@migrate(ActionObjectV3, ActionObject) +def upgrade_action_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(ActionObject, ActionObjectV3) +def downgrade_action_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(AnyActionObjectV3, AnyActionObject) +def upgrade_anyaction_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(AnyActionObject, AnyActionObjectV3) +def downgrade_anyaction_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index da273e2c12b..91a527569bb 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -75,15 +75,8 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any: return blob_store_result if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False - np_pointer = self._set( - context, - np_obj, - skip_save_to_blob_store=skip_save_to_blob_store, - ) + np_pointer = self._set(context, np_obj) return np_pointer @service_method( @@ -97,7 +90,6 @@ def set( action_object: ActionObject | TwinObject, add_storage_permission: bool = True, ignore_detached_objs: bool = False, - skip_save_to_blob_store: bool = False, ) -> ActionObject | SyftError: res = self._set( context, @@ -105,7 +97,6 @@ def set( has_result_read_permission=True, add_storage_permission=add_storage_permission, ignore_detached_objs=ignore_detached_objs, - skip_save_to_blob_store=skip_save_to_blob_store, ) if res.is_err(): return SyftError(message=res.value) @@ -123,14 +114,22 @@ def is_detached_obj( if ( isinstance(action_object, TwinObject) and ( - action_object.mock_obj.syft_blob_storage_entry_id is None - or action_object.private_obj.syft_blob_storage_entry_id is None + ( + action_object.mock_obj.syft_action_saved_to_blob_store + and action_object.mock_obj.syft_blob_storage_entry_id is None + ) + or ( + action_object.private_obj.syft_action_saved_to_blob_store + and action_object.private_obj.syft_blob_storage_entry_id is None + ) ) and not ignore_detached_obj ): return True if isinstance(action_object, ActionObject) and ( - action_object.syft_blob_storage_entry_id is None and not ignore_detached_obj + action_object.syft_action_saved_to_blob_store + and action_object.syft_blob_storage_entry_id is None + and not ignore_detached_obj ): return True return False @@ -142,12 +141,8 @@ def _set( has_result_read_permission: bool = False, add_storage_permission: bool = True, ignore_detached_objs: bool = False, - skip_save_to_blob_store: bool = False, ) -> Result[ActionObject, str]: - if ( - self.is_detached_obj(action_object, ignore_detached_objs) - and not skip_save_to_blob_store - ): + if self.is_detached_obj(action_object, ignore_detached_objs): return Err( "You uploaded an ActionObject that is not yet in the blob storage" ) @@ -156,14 +151,26 @@ def _set( if isinstance(action_object, ActionObject): action_object.syft_created_at = DateTime.now() - if not skip_save_to_blob_store: + ( action_object._clear_cache() + if action_object.syft_action_saved_to_blob_store + else None + ) else: # TwinObject action_object.private_obj.syft_created_at = DateTime.now() # type: ignore[unreachable] action_object.mock_obj.syft_created_at = DateTime.now() - if not skip_save_to_blob_store: + + # Clear cache if data is saved to blob storage + ( action_object.private_obj._clear_cache() + if action_object.private_obj.syft_action_saved_to_blob_store + else None + ) + ( action_object.mock_obj._clear_cache() + if action_object.mock_obj.syft_action_saved_to_blob_store + else None + ) # If either context or argument is True, has_result_read_permission is True has_result_read_permission = ( @@ -186,7 +193,9 @@ def _set( blob_storage_service: AbstractService = context.node.get_service( BlobStorageService ) - blob_storage_service.stash.add_permission(permission) + # if mock is saved to blob store, then add READ permission + if action_object.mock_obj.syft_action_saved_to_blob_store: + blob_storage_service.stash.add_permission(permission) if has_result_read_permission: action_object = action_object.private else: @@ -527,9 +536,6 @@ def set_result_to_store( return Err(blob_store_result.message) if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False # IMPORTANT: DO THIS ONLY AFTER ._save_to_blob_storage if isinstance(result_action_object, TwinObject): @@ -545,7 +551,6 @@ def set_result_to_store( context, result_action_object, has_result_read_permission=True, - skip_save_to_blob_store=skip_save_to_blob_store, ) if set_result.is_err(): @@ -569,8 +574,9 @@ def blob_permission( store_permissions = [store_permission(x) for x in output_readers] self.store.add_permissions(store_permissions) - blob_permissions = [blob_permission(x) for x in output_readers] - blob_storage_service.stash.add_permissions(blob_permissions) + if result_blob_id is not None: + blob_permissions = [blob_permission(x) for x in output_readers] + blob_storage_service.stash.add_permissions(blob_permissions) return set_result @@ -814,13 +820,9 @@ def execute( } if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False set_result = self._set( context, result_action_object, - skip_save_to_blob_store=skip_save_to_blob_store, ) if set_result.is_err(): return Err( diff --git a/packages/syft/src/syft/service/action/numpy.py b/packages/syft/src/syft/service/action/numpy.py index da8c8aecc05..f73f65f2cd7 100644 --- a/packages/syft/src/syft/service/action/numpy.py +++ b/packages/syft/src/syft/service/action/numpy.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from typing import Any from typing import ClassVar @@ -8,9 +9,14 @@ # relative from ...serde.serializable import serializable +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.transforms import drop +from ...types.transforms import make_set_default from .action_object import ActionObject from .action_object import ActionObjectPointer +from .action_object import ActionObjectV3 from .action_object import BASE_PASSTHROUGH_ATTRS from .action_types import action_types @@ -41,7 +47,7 @@ def numpy_like_eq(left: Any, right: Any) -> bool: # 🔵 TODO 7: Map TPActionObjects and their 3rd Party types like numpy type to these # classes for bi-directional lookup. @serializable() -class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): +class NumpyArrayObjectV3(ActionObjectV3, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyArrayObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -50,6 +56,17 @@ class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + +@serializable() +class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): + __canonical_name__ = "NumpyArrayObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type[Any]] = np.ndarray + syft_pointer_type: ClassVar[type[ActionObjectPointer]] = NumpyArrayObjectPointer + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + # def __eq__(self, other: Any) -> bool: # # 🟡 TODO 8: move __eq__ to a Data / Serdeable type interface on ActionObject # if isinstance(other, NumpyArrayObject): @@ -84,7 +101,7 @@ def __array_ufunc__( @serializable() -class NumpyScalarObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): +class NumpyScalarObjectV3(ActionObjectV3, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyScalarObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -92,12 +109,22 @@ class NumpyScalarObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + +@serializable() +class NumpyScalarObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): + __canonical_name__ = "NumpyScalarObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type] = np.number + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + def __float__(self) -> float: return float(self.syft_action_data) @serializable() -class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): +class NumpyBoolObjectV3(ActionObjectV3, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyBoolObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -106,6 +133,16 @@ class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] +@serializable() +class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): + __canonical_name__ = "NumpyBoolObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type] = np.bool_ + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: list[str] = ["dtype", "shape"] + + np_array = np.array([1, 2, 3]) action_types[type(np_array)] = NumpyArrayObject @@ -135,3 +172,33 @@ class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): for scalar_type in SUPPORTED_INT_TYPES + SUPPORTED_FLOAT_TYPES: # type: ignore action_types[scalar_type] = NumpyScalarObject + + +@migrate(NumpyArrayObjectV3, NumpyArrayObject) +def upgrade_numpyarray_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(NumpyArrayObject, NumpyArrayObjectV3) +def downgrade_numpyarray_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(NumpyBoolObjectV3, NumpyBoolObject) +def upgrade_numpybool_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(NumpyBoolObject, NumpyBoolObjectV3) +def downgrade_numpybool_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(NumpyScalarObjectV3, NumpyScalarObject) +def upgrade_numpyscalar_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(NumpyScalarObject, NumpyScalarObjectV3) +def downgrade_numpyscalar_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/service/action/pandas.py b/packages/syft/src/syft/service/action/pandas.py index d16dec119b0..3238b4f53d6 100644 --- a/packages/syft/src/syft/service/action/pandas.py +++ b/packages/syft/src/syft/service/action/pandas.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from typing import Any from typing import ClassVar @@ -8,14 +9,19 @@ # relative from ...serde.serializable import serializable +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.transforms import drop +from ...types.transforms import make_set_default from .action_object import ActionObject +from .action_object import ActionObjectV3 from .action_object import BASE_PASSTHROUGH_ATTRS from .action_types import action_types @serializable() -class PandasDataFrameObject(ActionObject): +class PandasDataFrameObjectV3(ActionObjectV3): __canonical_name__ = "PandasDataframeObject" __version__ = SYFT_OBJECT_VERSION_3 @@ -24,6 +30,17 @@ class PandasDataFrameObject(ActionObject): # this is added for instance checks for dataframes # syft_dont_wrap_attrs = ["shape"] + +@serializable() +class PandasDataFrameObject(ActionObject): + __canonical_name__ = "PandasDataframeObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type: ClassVar[type] = DataFrame + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + # this is added for instance checks for dataframes + # syft_dont_wrap_attrs = ["shape"] + def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: return self.__dataframe__(*args, **kwargs) @@ -46,13 +63,22 @@ def __bool__(self) -> bool: @serializable() -class PandasSeriesObject(ActionObject): +class PandasSeriesObjectV3(ActionObjectV3): __canonical_name__ = "PandasSeriesObject" __version__ = SYFT_OBJECT_VERSION_3 syft_internal_type = Series syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + +@serializable() +class PandasSeriesObject(ActionObject): + __canonical_name__ = "PandasSeriesObject" + __version__ = SYFT_OBJECT_VERSION_4 + + syft_internal_type = Series + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + # name: Optional[str] = None # syft_dont_wrap_attrs = ["shape"] @@ -73,3 +99,23 @@ def syft_is_property(self, obj: Any, method: str) -> bool: action_types[DataFrame] = PandasDataFrameObject action_types[Series] = PandasSeriesObject + + +@migrate(PandasSeriesObjectV3, PandasSeriesObject) +def upgrade_pandasseries_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(PandasSeriesObject, PandasSeriesObjectV3) +def downgrade_pandasseries_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] + + +@migrate(PandasDataFrameObjectV3, PandasDataFrameObject) +def upgrade_pandasdataframe_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(PandasDataFrameObject, PandasDataFrameObjectV3) +def downgrade_pandasdataframe_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index c6d49799c6e..7a24a0175d7 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -817,9 +817,6 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: raise ValueError(res.message) if isinstance(res, SyftWarning): logger.debug(res.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False # TODO, upload to blob storage here if context.node is None: raise ValueError( @@ -829,7 +826,6 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: result = action_service._set( context=context.to_node_context(), action_object=twin, - skip_save_to_blob_store=skip_save_to_blob_store, ) if result.is_err(): raise RuntimeError(f"Failed to create and store twin. Error: {result}") diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 8c5687ac55e..bb01d0e8a4c 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -158,24 +158,35 @@ def _run( uid_blob = action_obj.private.syft_blob_storage_entry_id else: uid_blob = action_obj.syft_blob_storage_entry_id - requesting_permission_blob_obj = ActionObjectPermission( - uid=uid_blob, - credentials=context.requesting_user_credentials, - permission=self.apply_permission_type, + requesting_permission_blob_obj = ( + ActionObjectPermission( + uid=uid_blob, + credentials=context.requesting_user_credentials, + permission=self.apply_permission_type, + ) + if uid_blob + else None ) if apply: logger.debug( "ADDING PERMISSION", requesting_permission_action_obj, id_action ) action_store.add_permission(requesting_permission_action_obj) - blob_storage_service.stash.add_permission( - requesting_permission_blob_obj + ( + blob_storage_service.stash.add_permission( + requesting_permission_blob_obj + ) + if requesting_permission_blob_obj + else None ) else: if action_store.has_permission(requesting_permission_action_obj): action_store.remove_permission(requesting_permission_action_obj) - if blob_storage_service.stash.has_permission( + if ( requesting_permission_blob_obj + and blob_storage_service.stash.has_permission( + requesting_permission_blob_obj + ) ): blob_storage_service.stash.remove_permission( requesting_permission_blob_obj diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 62885742c5b..d69889d9381 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -64,21 +64,24 @@ def add_actionobject_read_permissions( action_object: ActionObject, new_permissions: list[ActionObjectPermission], ) -> None: - blob_id = action_object.syft_blob_storage_entry_id - store_to = context.node.get_service("actionservice").store # type: ignore - store_to_blob = context.node.get_service("blobstorageservice").stash.partition # type: ignore - for permission in new_permissions: if permission.permission == ActionPermission.READ: store_to.add_permission(permission) - permission_blob = ActionObjectPermission( - uid=blob_id, - permission=permission.permission, - credentials=permission.credentials, - ) - store_to_blob.add_permission(permission_blob) + blob_id = action_object.syft_blob_storage_entry_id + if blob_id: + store_to_blob = context.node.get_service( + "blobstorageservice" + ).stash.partition # type: ignore + for permission in new_permissions: + if permission.permission == ActionPermission.READ: + permission_blob = ActionObjectPermission( + uid=blob_id, + permission=permission.permission, + credentials=permission.credentials, + ) + store_to_blob.add_permission(permission_blob) def set_obj_ids(self, context: AuthedServiceContext, x: Any) -> None: if hasattr(x, "__dict__") and isinstance(x, SyftObject): diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index a92134f26f7..016931b0901 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -27,15 +27,19 @@ from ..serde.serializable import serializable from ..service.action.action_object import ActionObject from ..service.action.action_object import ActionObjectPointer +from ..service.action.action_object import ActionObjectV3 from ..service.action.action_object import BASE_PASSTHROUGH_ATTRS from ..service.action.action_types import action_types from ..service.response import SyftError from ..service.response import SyftException from ..service.service import from_api_or_context from ..types.grid_url import GridURL +from ..types.transforms import drop from ..types.transforms import keep +from ..types.transforms import make_set_default from ..types.transforms import transform from .datetime import DateTime +from .syft_migration import migrate from .syft_object import SYFT_OBJECT_VERSION_2 from .syft_object import SYFT_OBJECT_VERSION_3 from .syft_object import SYFT_OBJECT_VERSION_4 @@ -192,7 +196,7 @@ class BlobFileObjectPointer(ActionObjectPointer): @serializable() -class BlobFileObject(ActionObject): +class BlobFileObjectV3(ActionObjectV3): __canonical_name__ = "BlobFileOBject" __version__ = SYFT_OBJECT_VERSION_2 @@ -201,6 +205,16 @@ class BlobFileObject(ActionObject): syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS +@serializable() +class BlobFileObject(ActionObject): + __canonical_name__ = "BlobFileOBject" + __version__ = SYFT_OBJECT_VERSION_3 + + syft_internal_type: ClassVar[type[Any]] = BlobFile + syft_pointer_type: ClassVar[type[ActionObjectPointer]] = BlobFileObjectPointer + syft_passthrough_attrs: list[str] = BASE_PASSTHROUGH_ATTRS + + @serializable() class SecureFilePathLocation(SyftObject): __canonical_name__ = "SecureFilePathLocation" @@ -370,3 +384,13 @@ def storage_entry_to_metadata() -> list[Callable]: action_types[BlobFile] = BlobFileObject + + +@migrate(BlobFileObjectV3, BlobFileObject) +def upgrade_blobfile_object() -> list[Callable]: + return [make_set_default("syft_action_saved_to_blob_store", True)] + + +@migrate(BlobFileObject, BlobFileObjectV3) +def downgrade_blobfile_object() -> list[Callable]: + return [drop("syft_action_saved_to_blob_store")] diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py index eae86e9cb5b..f2d2f020ef9 100644 --- a/packages/syft/src/syft/types/twin_object.py +++ b/packages/syft/src/syft/types/twin_object.py @@ -109,12 +109,8 @@ def send(self, client: SyftClient, add_storage_permission: bool = True) -> Any: blob_store_result = self._save_to_blob_storage() if isinstance(blob_store_result, SyftWarning): logger.debug(blob_store_result.message) - skip_save_to_blob_store = True - else: - skip_save_to_blob_store = False res = client.api.services.action.set( self, add_storage_permission=add_storage_permission, - skip_save_to_blob_store=skip_save_to_blob_store, ) return res