Skip to content

Commit

Permalink
Merge pull request #8715 from OpenMined/eelco/serializing-large-objs
Browse files Browse the repository at this point in the history
[WIP] serializing large objs
  • Loading branch information
koenvanderveen committed Apr 18, 2024
2 parents 74b6410 + 91ab7a0 commit 891c550
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 39 deletions.
32 changes: 18 additions & 14 deletions packages/syft/src/syft/serde/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,24 @@


def arrow_serialize(obj: np.ndarray) -> bytes:
original_dtype = obj.dtype
apache_arrow = pa.Tensor.from_numpy(obj=obj)
sink = pa.BufferOutputStream()
pa.ipc.write_tensor(apache_arrow, sink)
buffer = sink.getvalue()
if flags.APACHE_ARROW_COMPRESSION is ApacheArrowCompression.NONE:
numpy_bytes = buffer.to_pybytes()
else:
numpy_bytes = pa.compress(
buffer, asbytes=True, codec=flags.APACHE_ARROW_COMPRESSION.value
)
dtype = original_dtype.name

return cast(bytes, _serialize((numpy_bytes, buffer.size, dtype), to_bytes=True))
# inner function to make sure variables go out of scope after this
def inner(obj: np.ndarray) -> tuple:
original_dtype = obj.dtype
apache_arrow = pa.Tensor.from_numpy(obj=obj)
sink = pa.BufferOutputStream()
pa.ipc.write_tensor(apache_arrow, sink)
buffer = sink.getvalue()
if flags.APACHE_ARROW_COMPRESSION is ApacheArrowCompression.NONE:
numpy_bytes = buffer.to_pybytes()
else:
numpy_bytes = pa.compress(
buffer, asbytes=True, codec=flags.APACHE_ARROW_COMPRESSION.value
)
dtype = original_dtype.name
return (numpy_bytes, buffer.size, dtype)

m = inner(obj)
return cast(bytes, _serialize(m, to_bytes=True))


def arrow_deserialize(
Expand Down
44 changes: 31 additions & 13 deletions packages/syft/src/syft/serde/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum
from enum import EnumMeta
import sys
import tempfile
import types
from typing import Any

Expand All @@ -22,6 +23,8 @@

recursive_scheme = get_capnp_schema("recursive_serde.capnp").RecursiveSerde

SPOOLED_FILE_MAX_SIZE_SERDE = 50 * (1024**2) # 50MB


def get_types(cls: type, keys: list[str] | None = None) -> list[type] | None:
if keys is None:
Expand Down Expand Up @@ -161,16 +164,28 @@ def recursive_serde_register(


def chunk_bytes(
data: bytes, field_name: str | int, builder: _DynamicStructBuilder
field_obj: Any,
ser_func: Callable,
field_name: str | int,
builder: _DynamicStructBuilder,
) -> None:
CHUNK_SIZE = int(5.12e8) # capnp max for a List(Data) field
list_size = len(data) // CHUNK_SIZE + 1
data_lst = builder.init(field_name, list_size)
END_INDEX = CHUNK_SIZE
for idx in range(list_size):
START_INDEX = idx * CHUNK_SIZE
END_INDEX = min(START_INDEX + CHUNK_SIZE, len(data))
data_lst[idx] = data[START_INDEX:END_INDEX]
data = ser_func(field_obj)
size_of_data = len(data)
with tempfile.SpooledTemporaryFile(
max_size=SPOOLED_FILE_MAX_SIZE_SERDE
) as tmp_file:
# Write data to a file to save RAM
tmp_file.write(data)
tmp_file.seek(0)
del data

CHUNK_SIZE = int(5.12e8) # capnp max for a List(Data) field
list_size = size_of_data // CHUNK_SIZE + 1
data_lst = builder.init(field_name, list_size)
for idx in range(list_size):
bytes_to_read = min(CHUNK_SIZE, size_of_data)
data_lst[idx] = tmp_file.read(bytes_to_read)
size_of_data -= CHUNK_SIZE


def combine_bytes(capnp_list: list[bytes]) -> bytes:
Expand All @@ -195,7 +210,6 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild
if fqn not in TYPE_BANK:
# third party
raise Exception(f"{fqn} not in TYPE_BANK")

msg.fullyQualifiedName = fqn
(
nonrecursive,
Expand All @@ -215,7 +229,7 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild
raise Exception(
f"Cant serialize {type(self)} nonrecursive without serialize."
)
chunk_bytes(serialize(self), "nonrecursiveBlob", msg)
chunk_bytes(self, serialize, "nonrecursiveBlob", msg)
return msg

if attribute_list is None:
Expand Down Expand Up @@ -248,9 +262,13 @@ def rs_object2proto(self: Any, for_hashing: bool = False) -> _DynamicStructBuild
if isinstance(field_obj, types.FunctionType):
continue

serialized = sy.serialize(field_obj, to_bytes=True, for_hashing=for_hashing)
msg.fieldsName[idx] = attr_name
chunk_bytes(serialized, idx, msg.fieldsData)
chunk_bytes(
field_obj,
lambda x: sy.serialize(x, to_bytes=True, for_hashing=for_hashing),
idx,
msg.fieldsData,
)

return msg

Expand Down
28 changes: 22 additions & 6 deletions packages/syft/src/syft/serde/recursive_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pathlib
from pathlib import PurePath
import sys
import tempfile
from types import MappingProxyType
from types import UnionType
from typing import Any
Expand All @@ -26,9 +27,11 @@

# relative
from .capnp import get_capnp_schema
from .recursive import SPOOLED_FILE_MAX_SIZE_SERDE
from .recursive import chunk_bytes
from .recursive import combine_bytes
from .recursive import recursive_serde_register
from .util import compatible_with_large_file_writes_capnp

iterable_schema = get_capnp_schema("iterable.capnp").Iterable
kv_iterable_schema = get_capnp_schema("kv_iterable.capnp").KVIterable
Expand All @@ -43,10 +46,23 @@ def serialize_iterable(iterable: Collection) -> bytes:
message.init("values", len(iterable))

for idx, it in enumerate(iterable):
serialized = _serialize(it, to_bytes=True)
chunk_bytes(serialized, idx, message.values)

return message.to_bytes()
# serialized = _serialize(it, to_bytes=True)
chunk_bytes(it, lambda x: _serialize(x, to_bytes=True), idx, message.values)

if compatible_with_large_file_writes_capnp():
with tempfile.SpooledTemporaryFile(
max_size=SPOOLED_FILE_MAX_SIZE_SERDE
) as tmp_file:
# Write data to a file to save RAM
message.write(tmp_file)
del message
tmp_file.seek(0)
res = tmp_file.read()
return res
else:
res = message.to_bytes()
del message
return res


def deserialize_iterable(iterable_type: type, blob: bytes) -> Collection:
Expand Down Expand Up @@ -80,8 +96,8 @@ def _serialize_kv_pairs(size: int, kv_pairs: Iterable[tuple[_KT, _VT]]) -> bytes

for index, (k, v) in enumerate(kv_pairs):
message.keys[index] = _serialize(k, to_bytes=True)
serialized = _serialize(v, to_bytes=True)
chunk_bytes(serialized, index, message.values)
# serialized = _serialize(v, to_bytes=True)
chunk_bytes(v, lambda x: _serialize(x, to_bytes=True), index, message.values)

return message.to_bytes()

Expand Down
22 changes: 20 additions & 2 deletions packages/syft/src/syft/serde/serialize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# stdlib
import tempfile
from typing import Any

# relative
from .recursive import SPOOLED_FILE_MAX_SIZE_SERDE
from .util import compatible_with_large_file_writes_capnp


def _serialize(
obj: object,
Expand All @@ -12,9 +17,22 @@ def _serialize(
from .recursive import rs_object2proto

proto = rs_object2proto(obj, for_hashing=for_hashing)

if to_bytes:
return proto.to_bytes()
if compatible_with_large_file_writes_capnp():
with tempfile.SpooledTemporaryFile(
max_size=SPOOLED_FILE_MAX_SIZE_SERDE
) as tmp_file:
# Write data to a file to save RAM

proto.write(tmp_file)
# proto in memory, and bytes in file
del proto
# bytes in file
tmp_file.seek(0)
return tmp_file.read()
else:
res = proto.to_bytes()
return res

if to_proto:
return proto
6 changes: 6 additions & 0 deletions packages/syft/src/syft/serde/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# stdlib
from sys import platform


def compatible_with_large_file_writes_capnp() -> bool:
return platform not in ["darwin", "win32"]
29 changes: 25 additions & 4 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,21 @@ class ActionObjectPointer:
"__str__",
]

methods_to_check_in_cache = [
"_ipython_display_",
"_repr_mimebundle_",
"_repr_latex_",
"_repr_javascript_",
"_repr_html_",
"_repr_jpeg_",
"_repr_png_",
"_repr_svg_",
"_repr_pretty_",
"_repr_pdf_",
"_repr_json_",
"_repr_markdown_",
]


class PreHookContext(SyftBaseObject):
__canonical_name__ = "PreHookContext"
Expand Down Expand Up @@ -1577,8 +1592,6 @@ def _syft_get_attr_context(self, name: str) -> Any:
"""Find which instance - Syft ActionObject or the original object - has the requested attribute."""
defined_on_self = name in self.__dict__ or name in self.__private_attributes__

debug(">> ", name, ", defined_on_self = ", defined_on_self)

# use the custom defined version
context_self = self
if not defined_on_self:
Expand Down Expand Up @@ -1807,6 +1820,10 @@ def __getattribute__(self, name: str) -> Any:
name: str
The name of the attribute to access.
"""
# bypass ipython canary verification
if name == "_ipython_canary_method_should_not_exist_":
return None

# bypass certain attrs to prevent recursion issues
if name.startswith("_syft") or name.startswith("syft"):
return object.__getattribute__(self, name)
Expand All @@ -1817,13 +1834,17 @@ def __getattribute__(self, name: str) -> Any:
# third party
if name in self._syft_passthrough_attrs():
return object.__getattribute__(self, name)
context_self = self._syft_get_attr_context(name)

# Handle bool operator on nonbools
if name == "__bool__" and not self.syft_has_bool_attr:
return self._syft_wrap_attribute_for_bool_on_nonbools(name)

# check cache first
if name in methods_to_check_in_cache:
return getattr(self.syft_action_data_cache, name, None)

# Handle Properties
context_self = self._syft_get_attr_context(name)
if self.syft_is_property(context_self, name):
return self._syft_wrap_attribute_for_properties(name)

Expand Down Expand Up @@ -1880,7 +1901,7 @@ def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str:
else self.syft_action_data_cache.__repr__()
)

return f"```python\n{res}\n```\n{data_repr_}"
return f"```python\n{res}\n{data_repr_}```\n"

def _data_repr(self) -> str | None:
if isinstance(self.syft_action_data_cache, ActionDataEmpty):
Expand Down

0 comments on commit 891c550

Please sign in to comment.