Skip to content

Commit

Permalink
Merge branch 'dev' into eelco/deposit-result-l2
Browse files Browse the repository at this point in the history
  • Loading branch information
eelcovdw authored Aug 12, 2024
2 parents b79363e + f315b98 commit 716276e
Show file tree
Hide file tree
Showing 19 changed files with 459 additions and 142 deletions.
9 changes: 5 additions & 4 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
certifi>=2023.7.22 # not directly required, pinned by Snyk to avoid a vulnerability
certifi>=2024.7.4 # not directly required, pinned by Snyk to avoid a vulnerability
idna>=3.7 # not directly required, pinned by Snyk to avoid a vulnerability
ipython==8.10.0
jinja2>=3.1.3 # not directly required, pinned by Snyk to avoid a vulnerability
jinja2>=3.1.4 # not directly required, pinned by Snyk to avoid a vulnerability
markupsafe==2.0.1
pydata-sphinx-theme==0.7.2
pygments>=2.15.0 # not directly required, pinned by Snyk to avoid a vulnerability
requests>=2.32.0 # not directly required, pinned by Snyk to avoid a vulnerability
setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
requests>=2.32.2 # not directly required, pinned by Snyk to avoid a vulnerability
setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability
sphinx==4.3.0
sphinx-autoapi==1.8.4
sphinx-code-include==1.1.1
Expand Down
65 changes: 37 additions & 28 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import types
from typing import Any
from typing import TYPE_CHECKING
from typing import _GenericAlias
from typing import cast
from typing import get_args
from typing import get_origin
Expand All @@ -19,12 +18,9 @@
from nacl.exceptions import BadSignatureError
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import EmailStr
from pydantic import TypeAdapter
from result import OkErr
from result import Result
from typeguard import TypeCheckError
from typeguard import check_type

# relative
from ..abstract_server import AbstractServer
Expand All @@ -47,6 +43,8 @@
from ..service.response import SyftSuccess
from ..service.service import UserLibConfigRegistry
from ..service.service import UserServiceConfigRegistry
from ..service.service import _format_signature
from ..service.service import _signature_error_message
from ..service.user.user_roles import ServiceRole
from ..service.warnings import APIEndpointWarning
from ..service.warnings import WarningContext
Expand Down Expand Up @@ -97,6 +95,23 @@ def _has_config_dict(t: Any) -> bool:
)


_config_dict = ConfigDict(arbitrary_types_allowed=True)


def _check_type(v: object, t: Any) -> Any:
# TypeAdapter only accepts `config` arg if `t` does not
# already contain a ConfigDict
# i.e model_config in BaseModel and __pydantic_config__ in
# other types.
type_adapter = (
TypeAdapter(t, config=_config_dict)
if not _has_config_dict(t)
else TypeAdapter(t)
)

return type_adapter.validate_python(v)


class APIRegistry:
__api_registry__: dict[tuple, SyftAPI] = OrderedDict()

Expand Down Expand Up @@ -1308,7 +1323,10 @@ def validate_callable_args_and_kwargs(
for key, value in kwargs.items():
if key not in signature.parameters:
return SyftError(
message=f"""Invalid parameter: `{key}`. Valid Parameters: {list(signature.parameters)}"""
message=(
f"Invalid parameter: `{key}`.\n"
f"{_signature_error_message(_format_signature(signature))}"
)
)
param = signature.parameters[key]
if isinstance(param.annotation, str):
Expand All @@ -1320,21 +1338,15 @@ def validate_callable_args_and_kwargs(

if t is not inspect.Parameter.empty:
try:
config_kw = (
{"config": ConfigDict(arbitrary_types_allowed=True)}
if not _has_config_dict(t)
else {}
)

# TypeAdapter only accepts `config` arg if `t` does not
# already contain a ConfigDict
# i.e model_config in BaseModel and __pydantic_config__ in
# other types.
TypeAdapter(t, **config_kw).validate_python(value)
except Exception:
_check_type(value, t)
except ValueError:
_type_str = getattr(t, "__name__", str(t))

return SyftError(
message=f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`"
message=(
f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`\n"
f"{_signature_error_message(_format_signature(signature))}"
)
)

_valid_kwargs[key] = value
Expand All @@ -1353,15 +1365,8 @@ def validate_callable_args_and_kwargs(
msg = None
try:
if t is not inspect.Parameter.empty:
if isinstance(t, _GenericAlias) and type(None) in t.__args__:
for v in t.__args__:
if issubclass(v, EmailStr):
v = str
check_type(arg, v) # raises Exception
break # only need one to match
else:
check_type(arg, t) # raises Exception
except TypeCheckError:
_check_type(arg, t)
except ValueError:
t_arg = type(arg)
if (
autoreload_enabled()
Expand All @@ -1372,7 +1377,11 @@ def validate_callable_args_and_kwargs(
pass
else:
_type_str = getattr(t, "__name__", str(t))
msg = f"Arg is `{arg}`. \nIt must be of type `{_type_str}`, not `{type(arg).__name__}`"

msg = (
f"Arg is `{arg}`. \nIt must be of type `{_type_str}`, not `{type(arg).__name__}`\n"
f"{_signature_error_message(_format_signature(signature))}"
)

if msg:
return SyftError(message=msg)
Expand Down
3 changes: 3 additions & 0 deletions packages/syft/src/syft/client/datasite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..service.dataset.dataset import CreateDataset
from ..service.migration.object_migration_state import MigrationData
from ..service.response import SyftError
from ..service.response import SyftException
from ..service.response import SyftSuccess
from ..service.response import SyftWarning
from ..service.sync.diff_state import ResolvedSyncState
Expand Down Expand Up @@ -146,6 +147,8 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:
res = twin._save_to_blob_storage(allow_empty=contains_empty)
if isinstance(res, SyftError):
return res
except SyftException as se:
return SyftError(message=f"{se}")
except Exception as e:
tqdm.write(f"Failed to create twin for {asset.name}. {e}")
return SyftError(message=f"Failed to create twin. {e}")
Expand Down
45 changes: 30 additions & 15 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftBaseObject
from ...types.syft_object import SyftObject
from ...types.syft_object_registry import SyftObjectRegistry
from ...types.syncable_object import SyncableSyftObject
from ...types.uid import LineageID
from ...types.uid import UID
Expand Down Expand Up @@ -913,21 +914,23 @@ def syft_lineage_id(self) -> LineageID:

@model_validator(mode="before")
@classmethod
def __check_action_data(cls, values: dict) -> dict:
v = values.get("syft_action_data_cache")
if values.get("syft_action_data_type", None) is None:
values["syft_action_data_type"] = type(v)
if not isinstance(v, ActionDataEmpty):
if inspect.isclass(v):
values["syft_action_data_repr_"] = truncate_str(repr_cls(v))
else:
values["syft_action_data_repr_"] = truncate_str(
v._repr_markdown_()
if v is not None and hasattr(v, "_repr_markdown_")
else v.__repr__()
)
values["syft_action_data_str_"] = truncate_str(str(v))
values["syft_has_bool_attr"] = hasattr(v, "__bool__")
def __check_action_data(cls, values: Any) -> dict:
if isinstance(values, dict):
v = values.get("syft_action_data_cache")
if values.get("syft_action_data_type", None) is None:
values["syft_action_data_type"] = type(v)
if not isinstance(v, ActionDataEmpty):
if inspect.isclass(v):
values["syft_action_data_repr_"] = truncate_str(repr_cls(v))
else:
values["syft_action_data_repr_"] = truncate_str(
v._repr_markdown_()
if v is not None and hasattr(v, "_repr_markdown_")
else v.__repr__()
)
values["syft_action_data_str_"] = truncate_str(str(v))
values["syft_has_bool_attr"] = hasattr(v, "__bool__")

return values

@property
Expand Down Expand Up @@ -1410,6 +1413,18 @@ def from_obj(
if id is not None and syft_lineage_id is not None and id != syft_lineage_id.id:
raise ValueError("UID and LineageID should match")

# check if the object's type is supported
try:
canonical_name, version = SyftObjectRegistry.get_canonical_name_version(
syft_action_data
)
except Exception:
obj_type = type(syft_action_data)
raise SyftException(
f"Error when creating action object for {syft_action_data}.\n"
f"Unsupported data type: '{obj_type.__module__}.{obj_type.__name__}'"
)

action_type = action_type_for_object(syft_action_data)
action_object = action_type(syft_action_data_cache=syft_action_data)
action_object.syft_blob_storage_entry_id = syft_blob_storage_entry_id
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/service/blob_storage/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

# relative
from ...service.response import SyftException
from ...util.util import get_mb_serialized_size
from ..metadata.server_metadata import ServerMetadata
from ..metadata.server_metadata import ServerMetadataJSON
Expand All @@ -16,4 +17,7 @@ def min_size_for_blob_storage_upload(
def can_upload_to_blob_storage(
data: Any, metadata: ServerMetadata | ServerMetadataJSON
) -> bool:
return get_mb_serialized_size(data) >= min_size_for_blob_storage_upload(metadata)
serialized_size = get_mb_serialized_size(data)
if serialized_size.is_err():
raise SyftException(f"{serialized_size.err()}")
return serialized_size.ok() >= min_size_for_blob_storage_upload(metadata)
Loading

0 comments on commit 716276e

Please sign in to comment.