Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

moving recursive serde into the decorator #5973

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/syft-pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ jobs:
if: startsWith(runner.os, 'macos')
run: |
brew install libomp
pip uninstall python-dp -y

- name: Run normal tests
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/syft-version_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ jobs:
if: startsWith(runner.os, 'macos')
run: |
brew install libomp
pip uninstall python-dp -y

- name: Run supported library tests
run: |
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/proto/core/common/recursive_serde.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ syntax = "proto3";
package syft.core.common;

message RecursiveSerde {
bytes data = 1;
string fully_qualified_name = 2;
repeated string fields_name = 1;
repeated bytes fields_data = 2;
string fully_qualified_name = 3;
}
6 changes: 3 additions & 3 deletions packages/syft/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ libs =
pandas
petlib
pillow>=8.3.1 # security-issues
python-dp
# python-dp
statsmodels
tenseal
xgboost>=1.4
Expand Down Expand Up @@ -161,7 +161,7 @@ ci-libs =
; petlib #install-custom-dependency
; pillow>=8.1.2,<=8.2.0 #install-custom-dependency
; pyarrow #install-custom-dependency
; python-dp #install-custom-dependency
# ; python-dp #install-custom-dependency
; statsmodels #install-custom-dependency
; tenseal #install-custom-dependency
; xgboost>=1.4 #install-custom-dependency
Expand All @@ -178,7 +178,7 @@ ci-grid =
flask_sqlalchemy
names
PyInquirer
python-dp
# python-dp
requests_toolbelt
scipy
sqlalchemy>=1.4
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/core/adp/idp_gaussian_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np

# relative
from ..common.serde.recursive import RecursiveSerde
from ..common.serde.serializable import serializable


# methods serialize/deserialize np.int64 number
Expand Down Expand Up @@ -61,7 +61,8 @@ def individual_RDP_gaussian(params: Dict, alpha: float) -> np.float64:


# Example of a specific mechanism that inherits the Mechanism class
class iDPGaussianMechanism(Mechanism, RecursiveSerde):
@serializable(recursive_serde=True)
class iDPGaussianMechanism(Mechanism):
__attr_allowlist__ = [
"name",
"params",
Expand Down
8 changes: 5 additions & 3 deletions packages/syft/src/syft/core/adp/vm_private_scalar_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import sympy as sp

# relative
from ..common.serde.recursive import RecursiveSerde
from ..common.serde.serializable import serializable
from .entity import Entity
from .scalar import GammaScalar


class PrimeFactory(RecursiveSerde):
@serializable(recursive_serde=True)
class PrimeFactory:

"""IMPORTANT: it's very important that two tensors be able to tell that
they are indeed referencing the EXACT same PrimeFactory. At present this is done
Expand All @@ -32,7 +33,8 @@ def next(self) -> int:
return self.prev_prime


class VirtualMachinePrivateScalarManager(RecursiveSerde):
@serializable(recursive_serde=True)
class VirtualMachinePrivateScalarManager:

__attr_allowlist__ = ["prime_factory", "prime2symbol"]

Expand Down
25 changes: 16 additions & 9 deletions packages/syft/src/syft/core/common/serde/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,31 @@ def _deserialize(
# There are serveral code paths that come through here and use different ways to
# match and overload protobuf -> deserialize type
obj_type = getattr(type(blob), "schema2type", None)
# relative
from .recursive import rs_get_protobuf_schema
from .recursive import rs_proto2object

if obj_type is None:
# TODO: This can probably be removed now we have lists of obj_types
obj_type = getattr(blob, "obj_type", None)
if isinstance(blob, rs_get_protobuf_schema()):
res = rs_proto2object(proto=blob)
if getattr(res, "temporary_box", False) and hasattr(res, "upcast"):
return res.upcast()
return res

if obj_type is None:
traceback_and_raise(deserialization_error)

obj_type = index_syft_by_module_name(fully_qualified_name=obj_type) # type: ignore
obj_type = getattr(obj_type, "_sy_serializable_wrapper_type", obj_type)
elif isinstance(obj_type, list):
# circular imports
# relative
from .recursive import RecursiveSerde

if RecursiveSerde in obj_type and isinstance(
blob, RecursiveSerde.get_protobuf_schema()
):
# this branch is for RecursiveSerde objects
obj_type = RecursiveSerde

if isinstance(blob, rs_get_protobuf_schema()):
res = rs_proto2object(proto=blob)
if getattr(res, "temporary_box", False) and hasattr(res, "upcast"):
return res.upcast()
return res
elif len(obj_type) == 1:
obj_type = obj_type[0]
else:
Expand Down
84 changes: 31 additions & 53 deletions packages/syft/src/syft/core/common/serde/recursive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# stdlib
from typing import Any
from typing import Dict as DictType
from typing import List

# third party
from google.protobuf.reflection import GeneratedProtocolMessageType
Expand All @@ -10,67 +8,47 @@
import syft as sy

# relative
from ....lib.python import Dict
from ....proto.core.common.recursive_serde_pb2 import (
RecursiveSerde as RecursiveSerde_PB,
)
from ....util import get_fully_qualified_name
from ....util import index_syft_by_module_name
from .serializable import serializable


@serializable()
class RecursiveSerde:
"""If you subclass from this object and put that subclass in the syft classpath somehow, then
you'll be able to serialize it without having to create a custom protobuf. Be careful with this
though, because it's going to include all attributes by default (including private data if
it's there)."""
def rs_object2proto(self: Any) -> RecursiveSerde_PB:
# if __attr_allowlist__ then only include attrs from that list
msg = RecursiveSerde_PB(fully_qualified_name=get_fully_qualified_name(self))

# put attr names here - set this to None to include all attrs (not recommended)
__attr_allowlist__: List[str] = []
__serde_overrides__: DictType[Any, Any] = {}
if self.__attr_allowlist__ is None:
attribute_dict = self.__dict__.keys()
else:
attribute_dict = self.__attr_allowlist__

def _object2proto(self) -> RecursiveSerde_PB:

# if __attr_allowlist__ then only include attrs from that list
if self.__attr_allowlist__ is not None:
attrs = {}
for attr_name in self.__attr_allowlist__:
if hasattr(self, attr_name):
if self.__serde_overrides__.get(attr_name, None) is None:
attrs[attr_name] = getattr(self, attr_name)
else:
attrs[attr_name] = self.__serde_overrides__[attr_name][0](
getattr(self, attr_name)
)
# else include all attrs
for attr_name in attribute_dict:
if hasattr(self, attr_name):
msg.fields_name.append(attr_name)
transforms = self.__serde_overrides__.get(attr_name, None)
if transforms is None:
field_obj = getattr(self, attr_name)
else:
field_obj = transforms[0](getattr(self, attr_name))
msg.fields_data.append(sy.serialize(field_obj, to_bytes=True))
return msg


def rs_proto2object(proto: RecursiveSerde_PB) -> Any:
class_type = index_syft_by_module_name(proto.fully_qualified_name)
obj = class_type.__new__(class_type) # type: ignore
for attr_name, attr_bytes in zip(proto.fields_name, proto.fields_data):
attr_value = sy.deserialize(attr_bytes, from_bytes=True)
transforms = obj.__serde_overrides__.get(attr_name, None)
if transforms is None:
setattr(obj, attr_name, attr_value)
else:
attrs = self.__dict__ # type: ignore

return RecursiveSerde_PB(
data=sy.serialize(Dict(attrs), to_bytes=True),
fully_qualified_name=get_fully_qualified_name(self),
)

@staticmethod
def _proto2object(proto: RecursiveSerde_PB) -> "RecursiveSerde":
setattr(obj, attr_name, transforms[1](attr_value))

attrs = dict(sy.deserialize(proto.data, from_bytes=True))

class_type = index_syft_by_module_name(proto.fully_qualified_name)

obj = class_type.__new__(class_type) # type: ignore

for attr_name, attr_value in attrs.items():
if obj.__serde_overrides__.get(attr_name, None) is None:
setattr(obj, attr_name, attr_value)
else:
setattr(
obj, attr_name, obj.__serde_overrides__[attr_name][1](attr_value)
)
return obj

return obj

@staticmethod
def get_protobuf_schema() -> GeneratedProtocolMessageType:
return RecursiveSerde_PB
def rs_get_protobuf_schema() -> GeneratedProtocolMessageType:
return RecursiveSerde_PB
35 changes: 32 additions & 3 deletions packages/syft/src/syft/core/common/serde/serializable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from typing import Any
from typing import Callable as CallableT
import warnings

# third party
from google.protobuf.reflection import GeneratedProtocolMessageType
Expand Down Expand Up @@ -87,7 +88,32 @@ def proto2object(proto: Any) -> Any:
)


def serializable(generate_wrapper: bool = False, protobuf_object: bool = False) -> Any:
def serializable(
generate_wrapper: bool = False,
protobuf_object: bool = False,
recursive_serde: bool = False,
) -> Any:
def rs_decorator(cls: Any) -> Any:
# relative
from .recursive import rs_get_protobuf_schema
from .recursive import rs_object2proto
from .recursive import rs_proto2object

if not hasattr(cls, "__attr_allowlist__"):
warnings.warn(
f"__attr_allowlist__ not defined for type {cls.__name__},"
" even if it uses recursive serde, defaulting on the empty list."
)
setattr(cls, "__attr_allowlist__", [])

if not hasattr(cls, "__serde_overrides__"):
setattr(cls, "__serde_overrides__", {})

setattr(cls, "_object2proto", rs_object2proto)
setattr(cls, "_proto2object", staticmethod(rs_proto2object))
setattr(cls, "get_protobuf_schema", staticmethod(rs_get_protobuf_schema))
return cls

def serializable_decorator(cls: Any) -> Any:
protobuf_schema = cls.get_protobuf_schema()
# overloading a protobuf by adding multiple classes and we will check the
Expand All @@ -103,5 +129,8 @@ def serializable_decorator(cls: Any) -> Any:

if generate_wrapper:
return GenerateWrapper
else:
return serializable_decorator

if recursive_serde:
return rs_decorator

return serializable_decorator
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from .....common.message import ImmediateSyftMessageWithReply
from .....common.message import ImmediateSyftMessageWithoutReply
from .....common.serde.deserialize import _deserialize
from .....common.serde.recursive import RecursiveSerde
from .....common.serde.serializable import serializable
from .....common.uid import UID
from .....io.address import Address
from ....abstract.node import AbstractNode


class NodeRunnableMessageWithReply(RecursiveSerde):
@serializable(recursive_serde=True)
class NodeRunnableMessageWithReply:

__attr_allowlist__ = ["stuff"]

Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/core/tensor/autodp/initial_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# relative
from ...adp.vm_private_scalar_manager import VirtualMachinePrivateScalarManager
from ...common.serde.recursive import RecursiveSerde
from ...common.serde.serializable import serializable
from ...common.uid import UID
from ..passthrough import PassthroughTensor # type: ignore
from ..smpc.share_tensor import ShareTensor
Expand All @@ -27,7 +27,8 @@ def list2numpy(l_shape: Any) -> np.ndarray:
return np.array(list_length).reshape(shape)


class InitialGammaTensor(IntermediateGammaTensor, RecursiveSerde, ADPTensor):
@serializable(recursive_serde=True)
class InitialGammaTensor(IntermediateGammaTensor, ADPTensor):

__attr_allowlist__ = [
"uid",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# relative
from ...adp.publish import publish
from ...adp.vm_private_scalar_manager import VirtualMachinePrivateScalarManager
from ...common.serde.recursive import RecursiveSerde
from ...common.serde.serializable import serializable
from ..passthrough import PassthroughTensor # type: ignore
from ..passthrough import is_acceptable_simple_type # type: ignore
from .adp_tensor import ADPTensor


class IntermediateGammaTensor(PassthroughTensor, RecursiveSerde, ADPTensor):
@serializable(recursive_serde=True)
class IntermediateGammaTensor(PassthroughTensor, ADPTensor):

__attr_allowlist__ = [
"term_tensor",
Expand Down
5 changes: 2 additions & 3 deletions packages/syft/src/syft/core/tensor/autodp/row_entity_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ...adp.vm_private_scalar_manager import (
VirtualMachinePrivateScalarManager as TypeScalarManager,
)
from ...common.serde.recursive import RecursiveSerde
from ...common.serde.serializable import serializable
from ..passthrough import PassthroughTensor # type: ignore
from ..passthrough import implements # type: ignore
Expand All @@ -25,8 +24,8 @@
from .initial_gamma import InitialGammaTensor # type: ignore


@serializable()
class RowEntityPhiTensor(PassthroughTensor, RecursiveSerde, ADPTensor):
@serializable(recursive_serde=True)
class RowEntityPhiTensor(PassthroughTensor, ADPTensor):

__attr_allowlist__ = ["child"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ...adp.entity import Entity
from ...adp.vm_private_scalar_manager import VirtualMachinePrivateScalarManager
from ...common.serde.deserialize import _deserialize as deserialize
from ...common.serde.recursive import RecursiveSerde
from ...common.serde.serializable import serializable
from ...common.serde.serialize import _serialize as serialize
from ...common.uid import UID
Expand Down Expand Up @@ -265,10 +264,8 @@ def get_protobuf_schema() -> GeneratedProtocolMessageType:
return TensorWrappedSingleEntityPhiTensorPointer_PB


@serializable()
class SingleEntityPhiTensor(
PassthroughTensor, AutogradTensorAncestor, RecursiveSerde, ADPTensor
):
@serializable(recursive_serde=True)
class SingleEntityPhiTensor(PassthroughTensor, AutogradTensorAncestor, ADPTensor):

PointerClassOverride = TensorWrappedSingleEntityPhiTensorPointer

Expand Down
Loading