Skip to content

Commit

Permalink
Merge pull request #9039 from OpenMined/eelco/register-external-types
Browse files Browse the repository at this point in the history
migrations: add cannonical_name and version to all types in serde register
  • Loading branch information
koenvanderveen committed Jul 16, 2024
2 parents 7e831e8 + b5633fa commit d0e0ea4
Show file tree
Hide file tree
Showing 118 changed files with 730 additions and 291 deletions.
4 changes: 2 additions & 2 deletions packages/syft/src/syft/abstract_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .service.service import AbstractService


@serializable()
@serializable(canonical_name="ServerType", version=1)
class ServerType(str, Enum):
DATASITE = "datasite"
NETWORK = "network"
Expand All @@ -24,7 +24,7 @@ def __str__(self) -> str:
return self.value


@serializable()
@serializable(canonical_name="ServerSideType", version=1)
class ServerSideType(str, Enum):
LOW_SIDE = "low"
HIGH_SIDE = "high"
Expand Down
10 changes: 5 additions & 5 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def prepare_args_and_kwargs(
to_protocol=self.communication_protocol, args=args, kwargs=kwargs
)

return args, kwargs
return tuple(args), kwargs

def function_call(
self, path: str, *args: Any, cache_result: bool = True, **kwargs: Any
Expand Down Expand Up @@ -448,7 +448,7 @@ class RemoteUserCodeFunction(RemoteFunction):

def prepare_args_and_kwargs(
self, args: list | tuple, kwargs: dict[str, Any]
) -> SyftError | tuple[tuple, dict[str, Any]]:
) -> tuple[tuple, dict[str, Any]] | SyftError:
# relative
from ..service.action.action_object import convert_to_pointers

Expand Down Expand Up @@ -479,7 +479,7 @@ def prepare_args_and_kwargs(
to_protocol=self.communication_protocol, args=args, kwargs=kwargs
)

return args, kwargs
return tuple(args), kwargs

@property
def user_code_id(self) -> UID | None:
Expand Down Expand Up @@ -643,7 +643,7 @@ def _coll_repr_(self) -> dict[str, Any]:
return {"submodule": self.submodule, "endpoints": "\n".join(self.endpoints)}


@serializable()
@serializable(canonical_name="APIModule", version=1)
class APIModule:
_modules: list[str]
path: str
Expand Down Expand Up @@ -1275,7 +1275,7 @@ def monkey_patch_getdef(self: Any, obj: Any, oname: str = "") -> str | None:
pass # nosec


@serializable()
@serializable(canonical_name="ServerIdentity", version=1)
class ServerIdentity(Identity):
server_name: str

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def get_client_type(self) -> type[SyftClient] | SyftError:


@instrument
@serializable()
@serializable(canonical_name="SyftClient", version=1)
class SyftClient:
connection: ServerConnection
metadata: ServerMetadataJSON | None
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/datasite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def add_default_uploader(
return obj


@serializable()
@serializable(canonical_name="DatasiteClient", version=1)
class DatasiteClient(SyftClient):
def __repr__(self) -> str:
return f"<DatasiteClient: {self.name}>"
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/enclave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class EnclaveMetadata(SyftObject):
route: ServerRouteType


@serializable()
@serializable(canonical_name="EnclaveClient", version=1)
class EnclaveClient(SyftClient):
# TODO: add widget repr for enclave client

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/gateway_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .connection import ServerConnection


@serializable()
@serializable(canonical_name="GatewayClient", version=1)
class GatewayClient(SyftClient):
# TODO: add widget repr for gateway client

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/sync_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..serde.serializable import serializable


@serializable()
@serializable(canonical_name="SyncDirection", version=1)
class SyncDirection(str, Enum):
LOW_TO_HIGH = "low_to_high"
HIGH_TO_LOW = "high_to_low"
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/custom_worker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class WorkerConfig(SyftBaseModel):
pass


@serializable()
@serializable(canonical_name="CustomWorkerConfig", version=1)
class CustomWorkerConfig(WorkerConfig):
build: CustomBuildConfig
version: str = "1"
Expand All @@ -107,7 +107,7 @@ def get_signature(self) -> str:
return sha256(self.json(sort_keys=True).encode()).hexdigest()


@serializable()
@serializable(canonical_name="PrebuiltWorkerConfig", version=1)
class PrebuiltWorkerConfig(WorkerConfig):
# tag that is already built and pushed in some registry
tag: str
Expand All @@ -126,7 +126,7 @@ def __hash__(self) -> int:
return hash(self.tag)


@serializable()
@serializable(canonical_name="DockerWorkerConfig", version=1)
class DockerWorkerConfig(WorkerConfig):
dockerfile: str
file_name: str | None = None
Expand Down
37 changes: 36 additions & 1 deletion packages/syft/src/syft/serde/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numpy import frombuffer

# relative
from ..types.syft_object import SYFT_OBJECT_VERSION_1
from .arrow import numpy_deserialize
from .arrow import numpy_serialize
from .recursive import recursive_serde_register
Expand Down Expand Up @@ -34,11 +35,17 @@
}

recursive_serde_register(
np.ndarray, serialize=numpy_serialize, deserialize=numpy_deserialize
np.ndarray,
serialize=numpy_serialize,
deserialize=numpy_deserialize,
canonical_name="numpy_ndarray",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np._globals._NoValueType,
canonical_name="numpy_no_value",
version=SYFT_OBJECT_VERSION_1,
)
# serialize=numpy_serialize, deserialize=numpy_deserialize

Expand All @@ -47,84 +54,112 @@
np.bool_,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.bool_)[0],
canonical_name="numpy_bool",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.int8,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.int8)[0],
canonical_name="numpy_int8",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.int16,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.int16)[0],
canonical_name="numpy_int16",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.int32,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.int32)[0],
canonical_name="numpy_int32",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.int64,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.int64)[0],
canonical_name="numpy_int64",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.uint8,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.uint8)[0],
canonical_name="numpy_uint8",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.uint16,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.uint16)[0],
canonical_name="numpy_uint16",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.uint32,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.uint32)[0],
canonical_name="numpy_uint32",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.uint64,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.uint64)[0],
canonical_name="numpy_uint64",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.single,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.single)[0],
canonical_name="numpy_single",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.double,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.double)[0],
canonical_name="numpy_double",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.float16,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.float16)[0],
canonical_name="numpy_float16",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.float32,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.float32)[0],
canonical_name="numpy_float32",
version=SYFT_OBJECT_VERSION_1,
)

recursive_serde_register(
np.float64,
serialize=lambda x: x.tobytes(),
deserialize=lambda buffer: frombuffer(buffer, dtype=np.float64)[0],
canonical_name="numpy_float64",
version=SYFT_OBJECT_VERSION_1,
)

# TODO: There is an incorrect mapping in looping,which makes it not work.
Expand Down
8 changes: 4 additions & 4 deletions packages/syft/src/syft/serde/lib_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from .serializable import serializable


@serializable()
@serializable(canonical_name="CMPCRUDPermission", version=1)
class CMPCRUDPermission(Enum):
NONE_EXECUTE = 1
ALL_EXECUTE = 2


@serializable()
@serializable(canonical_name="CMPPermission", version=1)
class CMPPermission:
@property
def permissions_string(self) -> str:
Expand All @@ -22,7 +22,7 @@ def __repr__(self) -> str:
return self.permission_string


@serializable()
@serializable(canonical_name="CMPUserPermission", version=1)
class CMPUserPermission(CMPPermission):
def __init__(self, user_id: UID, permission: CMPCRUDPermission):
self.user_id = user_id
Expand All @@ -36,7 +36,7 @@ def __repr__(self) -> str:
return self.permission_string


@serializable()
@serializable(canonical_name="CMPCompoundPermission", version=1)
class CMPCompoundPermission(CMPPermission):
def __init__(self, permission: CMPCRUDPermission):
self.permissions = permission
Expand Down
Loading

0 comments on commit d0e0ea4

Please sign in to comment.