From 3dc46bf781c6aeff034a51abc7c9e1b565232f15 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 18 Jun 2024 12:55:09 +0200 Subject: [PATCH] add test --- .../syft/src/syft/service/code/user_code.py | 74 +++++++++++++++---- packages/syft/src/syft/service/response.py | 4 +- .../syft/tests/syft/users/user_code_test.py | 35 +++++++++ 3 files changed, 99 insertions(+), 14 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 9ee7e09e7e9..f771cf9a9c5 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -12,6 +12,7 @@ import inspect from io import StringIO import itertools +import keyword import random import sys from textwrap import dedent @@ -26,8 +27,11 @@ # third party from IPython.display import display +from pydantic import ValidationError from pydantic import field_validator from result import Err +from result import Ok +from result import Result from typing_extensions import Self # relative @@ -90,6 +94,7 @@ from ..response import SyftNotReady from ..response import SyftSuccess from ..response import SyftWarning +from ..service import ServiceConfigRegistry from ..user.user import UserView from ..user.user_roles import ServiceRole from .code_parse import GlobalsVisitor @@ -363,6 +368,14 @@ class UserCode(SyncableSyftObject): "output_policy_state", ] + @field_validator("service_func_name", mode="after") + @classmethod + def service_func_name_is_valid(cls, value: str) -> str: + res = is_valid_usercode_name(value) + if res.is_err(): + raise ValueError(res.err_value) + return value + def __setattr__(self, key: str, value: Any) -> None: # Get the attribute from the class, it might be a descriptor or None attr = getattr(type(self), key, None) @@ -908,6 +921,14 @@ class SubmitUserCode(SyftObject): __repr_attrs__ = ["func_name", "code"] + @field_validator("func_name", mode="after") + @classmethod + def func_name_is_valid(cls, value: str) -> str: + res = is_valid_usercode_name(value) + if res.is_err(): + raise ValueError(res.err_value) + return value + @field_validator("output_policy_init_kwargs", mode="after") @classmethod def add_output_policy_ids(cls, values: Any) -> Any: @@ -1070,6 +1091,24 @@ def input_owner_verify_keys(self) -> list[str] | None: return None +def is_valid_usercode_name(func_name: str) -> Result[Any, str]: + if len(func_name) == 0: + return Err("Function name cannot be empty") + if func_name == "_": + return Err("Cannot use anonymous function as syft function") + if not str.isidentifier(func_name): + return Err("Function name must be a valid Python identifier") + if keyword.iskeyword(func_name): + return Err("Function name is a reserved python keyword") + + service_method_path = f"code.{func_name}" + if ServiceConfigRegistry.path_exists(service_method_path): + return Err( + f"Could not create syft function with name {func_name}: a service with the same name already exists" + ) + return Ok(None) + + class ArgumentType(Enum): REAL = 1 MOCK = 2 @@ -1128,19 +1167,28 @@ def syft_function( else: output_policy_type = type(output_policy) - def decorator(f: Any) -> SubmitUserCode: - res = SubmitUserCode( - code=dedent(inspect.getsource(f)), - func_name=f.__name__, - signature=inspect.signature(f), - input_policy_type=input_policy_type, - input_policy_init_kwargs=init_input_kwargs, - output_policy_type=output_policy_type, - output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}), - local_function=f, - input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount], - worker_pool_name=worker_pool_name, - ) + def decorator(f: Any) -> SubmitUserCode | SyftError: + try: + res = SubmitUserCode( + code=dedent(inspect.getsource(f)), + func_name=f.__name__, + signature=inspect.signature(f), + input_policy_type=input_policy_type, + input_policy_init_kwargs=init_input_kwargs, + output_policy_type=output_policy_type, + output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}), + local_function=f, + input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount], + worker_pool_name=worker_pool_name, + ) + except ValidationError as e: + errors = e.errors() + msg = "Failed to create syft function, encountered validation errors:\n" + for error in errors: + msg += f"\t{error['msg']}\n" + err = SyftError(message=msg) + display(err) + return err if share_results_with_owners and res.output_policy_init_kwargs is not None: res.output_policy_init_kwargs["output_readers"] = ( diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py index 37227046c5c..723970cdfff 100644 --- a/packages/syft/src/syft/service/response.py +++ b/packages/syft/src/syft/service/response.py @@ -42,7 +42,9 @@ def _repr_html_class_(self) -> str: def _repr_html_(self) -> str: return ( f'
' - + f"{type(self).__name__}: {self.message}

" + f"{type(self).__name__}: " + f'
'
+            f"{self.message}

" ) diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index f006525097e..333d246e37f 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -4,6 +4,8 @@ # third party from faker import Faker import numpy as np +from pydantic import ValidationError +import pytest # syft absolute import syft as sy @@ -12,6 +14,7 @@ from syft.service.request.request import Request from syft.service.request.request import UserCodeStatusChange from syft.service.response import SyftError +from syft.service.response import SyftSuccess from syft.service.user.user import User @@ -331,3 +334,35 @@ def compute_sum(): result = ds_client.api.services.code.compute_sum() assert result, result assert result.get() == 1 + + +def test_submit_invalid_name(worker) -> None: + client = worker.root_client + + @sy.syft_function_single_use() + def valid_name(): + pass + + res = client.code.submit(valid_name) + assert isinstance(res, SyftSuccess) + + @sy.syft_function_single_use() + def get_all(): + pass + + assert isinstance(get_all, SyftError) + + @sy.syft_function_single_use() + def _(): + pass + + assert isinstance(_, SyftError) + + # overwrite valid function name before submit, fail on serde + @sy.syft_function_single_use() + def valid_name_2(): + pass + + valid_name_2.func_name = "get_all" + with pytest.raises(ValidationError): + client.code.submit(valid_name_2)