Skip to content

Commit

Permalink
Merge pull request #8926 from OpenMined/eelco/validate-usercode-funcname
Browse files Browse the repository at this point in the history
Add validator for UserCode names
  • Loading branch information
koenvanderveen authored Jun 18, 2024
2 parents 144bbbf + 3dc46bf commit 6ad280b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 14 deletions.
74 changes: 61 additions & 13 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import inspect
from io import StringIO
import itertools
import keyword
import random
import sys
from textwrap import dedent
Expand All @@ -26,8 +27,11 @@

# third party
from IPython.display import display
from pydantic import ValidationError
from pydantic import field_validator
from result import Err
from result import Ok
from result import Result
from typing_extensions import Self

# relative
Expand Down Expand Up @@ -90,6 +94,7 @@
from ..response import SyftNotReady
from ..response import SyftSuccess
from ..response import SyftWarning
from ..service import ServiceConfigRegistry
from ..user.user import UserView
from ..user.user_roles import ServiceRole
from .code_parse import GlobalsVisitor
Expand Down Expand Up @@ -363,6 +368,14 @@ class UserCode(SyncableSyftObject):
"output_policy_state",
]

@field_validator("service_func_name", mode="after")
@classmethod
def service_func_name_is_valid(cls, value: str) -> str:
res = is_valid_usercode_name(value)
if res.is_err():
raise ValueError(res.err_value)
return value

def __setattr__(self, key: str, value: Any) -> None:
# Get the attribute from the class, it might be a descriptor or None
attr = getattr(type(self), key, None)
Expand Down Expand Up @@ -908,6 +921,14 @@ class SubmitUserCode(SyftObject):

__repr_attrs__ = ["func_name", "code"]

@field_validator("func_name", mode="after")
@classmethod
def func_name_is_valid(cls, value: str) -> str:
res = is_valid_usercode_name(value)
if res.is_err():
raise ValueError(res.err_value)
return value

@field_validator("output_policy_init_kwargs", mode="after")
@classmethod
def add_output_policy_ids(cls, values: Any) -> Any:
Expand Down Expand Up @@ -1070,6 +1091,24 @@ def input_owner_verify_keys(self) -> list[str] | None:
return None


def is_valid_usercode_name(func_name: str) -> Result[Any, str]:
if len(func_name) == 0:
return Err("Function name cannot be empty")
if func_name == "_":
return Err("Cannot use anonymous function as syft function")
if not str.isidentifier(func_name):
return Err("Function name must be a valid Python identifier")
if keyword.iskeyword(func_name):
return Err("Function name is a reserved python keyword")

service_method_path = f"code.{func_name}"
if ServiceConfigRegistry.path_exists(service_method_path):
return Err(
f"Could not create syft function with name {func_name}: a service with the same name already exists"
)
return Ok(None)


class ArgumentType(Enum):
REAL = 1
MOCK = 2
Expand Down Expand Up @@ -1128,19 +1167,28 @@ def syft_function(
else:
output_policy_type = type(output_policy)

def decorator(f: Any) -> SubmitUserCode:
res = SubmitUserCode(
code=dedent(inspect.getsource(f)),
func_name=f.__name__,
signature=inspect.signature(f),
input_policy_type=input_policy_type,
input_policy_init_kwargs=init_input_kwargs,
output_policy_type=output_policy_type,
output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}),
local_function=f,
input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount],
worker_pool_name=worker_pool_name,
)
def decorator(f: Any) -> SubmitUserCode | SyftError:
try:
res = SubmitUserCode(
code=dedent(inspect.getsource(f)),
func_name=f.__name__,
signature=inspect.signature(f),
input_policy_type=input_policy_type,
input_policy_init_kwargs=init_input_kwargs,
output_policy_type=output_policy_type,
output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}),
local_function=f,
input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount],
worker_pool_name=worker_pool_name,
)
except ValidationError as e:
errors = e.errors()
msg = "Failed to create syft function, encountered validation errors:\n"
for error in errors:
msg += f"\t{error['msg']}\n"
err = SyftError(message=msg)
display(err)
return err

if share_results_with_owners and res.output_policy_init_kwargs is not None:
res.output_policy_init_kwargs["output_readers"] = (
Expand Down
4 changes: 3 additions & 1 deletion packages/syft/src/syft/service/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def _repr_html_class_(self) -> str:
def _repr_html_(self) -> str:
return (
f'<div class="{self._repr_html_class_}" style="padding:5px;">'
+ f"<strong>{type(self).__name__}</strong>: {self.message}</div><br />"
f"<strong>{type(self).__name__}</strong>: "
f'<pre class="{self._repr_html_class_}" style="display:inline; font-family:inherit;">'
f"{self.message}</pre></div><br/>"
)


Expand Down
35 changes: 35 additions & 0 deletions packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# third party
from faker import Faker
import numpy as np
from pydantic import ValidationError
import pytest

# syft absolute
import syft as sy
Expand All @@ -12,6 +14,7 @@
from syft.service.request.request import Request
from syft.service.request.request import UserCodeStatusChange
from syft.service.response import SyftError
from syft.service.response import SyftSuccess
from syft.service.user.user import User


Expand Down Expand Up @@ -331,3 +334,35 @@ def compute_sum():
result = ds_client.api.services.code.compute_sum()
assert result, result
assert result.get() == 1


def test_submit_invalid_name(worker) -> None:
client = worker.root_client

@sy.syft_function_single_use()
def valid_name():
pass

res = client.code.submit(valid_name)
assert isinstance(res, SyftSuccess)

@sy.syft_function_single_use()
def get_all():
pass

assert isinstance(get_all, SyftError)

@sy.syft_function_single_use()
def _():
pass

assert isinstance(_, SyftError)

# overwrite valid function name before submit, fail on serde
@sy.syft_function_single_use()
def valid_name_2():
pass

valid_name_2.func_name = "get_all"
with pytest.raises(ValidationError):
client.code.submit(valid_name_2)

0 comments on commit 6ad280b

Please sign in to comment.