diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 9c8b244b129..8ebfac03c0b 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -23,6 +23,7 @@ from pydantic import TypeAdapter from result import OkErr from result import Result +from typeguard import TypeCheckError from typeguard import check_type # relative @@ -1385,7 +1386,7 @@ def validate_callable_args_and_kwargs( break # only need one to match else: check_type(arg, t) # raises Exception - except TypeError: + except TypeCheckError: t_arg = type(arg) if ( autoreload_enabled() @@ -1396,7 +1397,7 @@ def validate_callable_args_and_kwargs( pass else: _type_str = getattr(t, "__name__", str(t)) - msg = f"Arg: {arg} must be {_type_str} not {type(arg).__name__}" + msg = f"Arg is `{arg}`. \nIt must be of type `{_type_str}`, not `{type(arg).__name__}`" if msg: return SyftError(message=msg) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 1a08f594aa2..08672657762 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -8,6 +8,7 @@ from enum import Enum import getpass import inspect +import logging import os import sys from typing import Any @@ -24,6 +25,8 @@ from .service.response import SyftError from .util.util import get_random_available_port +logger = logging.getLogger(__name__) + DEFAULT_PORT = 8080 DEFAULT_URL = "http://localhost" @@ -174,7 +177,7 @@ def deploy_to_python( } if dev_mode: - print("Staging Protocol Changes...") + logger.debug("Staging Protocol Changes...") stage_protocol_changes() kwargs = { diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 5c1dd137cd0..88b3d5109bd 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -97,6 +97,7 @@ from ..policy.policy import partition_by_node from ..policy.policy_service import PolicyService from ..response import SyftError +from ..response import SyftException from ..response import SyftInfo from ..response import SyftNotReady from ..response import SyftSuccess @@ -104,9 +105,10 @@ from ..service import ServiceConfigRegistry from ..user.user import UserView from ..user.user_roles import ServiceRole -from .code_parse import GlobalsVisitor from .code_parse import LaunchJobVisitor from .unparse import unparse +from .utils import check_for_global_vars +from .utils import parse_code from .utils import submit_subjobs_code if TYPE_CHECKING: @@ -1038,13 +1040,6 @@ def __call__( def local_call(self, *args: Any, **kwargs: Any) -> Any: # only run this on the client side if self.local_function: - source = dedent(inspect.getsource(self.local_function)) - tree = ast.parse(source) - - # check there are no globals - v = GlobalsVisitor() - v.visit(tree) - # filtered_args = [] filtered_kwargs = {} # for arg in args: @@ -1256,15 +1251,25 @@ def syft_function( else: output_policy_type = type(output_policy) - def decorator(f: Any) -> SubmitUserCode: + def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) + if name is not None: fname = name code = replace_func_name(code, fname) else: fname = f.__name__ + input_kwargs = f.__code__.co_varnames[: f.__code__.co_argcount] + + parse_user_code( + raw_code=code, + func_name=fname, + original_func_name=f.__name__, + function_input_kwargs=input_kwargs, + ) + res = SubmitUserCode( code=code, func_name=fname, @@ -1274,7 +1279,7 @@ def decorator(f: Any) -> SubmitUserCode: output_policy_type=output_policy_type, output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}), local_function=f, - input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount], + input_kwargs=input_kwargs, worker_pool_name=worker_pool_name, ) @@ -1287,6 +1292,11 @@ def decorator(f: Any) -> SubmitUserCode: display(err) return err + except SyftException as se: + err = SyftError(message=f"Error when parsing the code: {se}") + display(err) + return err + if share_results_with_owners and res.output_policy_init_kwargs is not None: res.output_policy_init_kwargs["output_readers"] = ( res.input_owner_verify_keys @@ -1318,26 +1328,23 @@ def generate_unique_func_name(context: TransformContext) -> TransformContext: return context -def process_code( - context: TransformContext, +def parse_user_code( raw_code: str, func_name: str, original_func_name: str, - policy_input_kwargs: list[str], function_input_kwargs: list[str], ) -> str: - tree = ast.parse(raw_code) - - # check there are no globals - v = GlobalsVisitor() - v.visit(tree) + # parse the code, check for syntax errors and if there are global variables + try: + tree: ast.Module = parse_code(raw_code=raw_code) + check_for_global_vars(code_tree=tree) + except SyftException as e: + raise SyftException(f"{e}") - f = tree.body[0] + f: ast.stmt = tree.body[0] f.decorator_list = [] call_args = function_input_kwargs - if "domain" in function_input_kwargs and context.output is not None: - context.output["uses_domain"] = True call_stmt_keywords = [ast.keyword(arg=i, value=[ast.Name(id=i)]) for i in call_args] call_stmt = ast.Assign( targets=[ast.Name(id="result")], @@ -1362,6 +1369,25 @@ def process_code( return unparse(wrapper_function) +def process_code( + context: TransformContext, + raw_code: str, + func_name: str, + original_func_name: str, + policy_input_kwargs: list[str], + function_input_kwargs: list[str], +) -> str: + if "domain" in function_input_kwargs and context.output is not None: + context.output["uses_domain"] = True + + return parse_user_code( + raw_code=raw_code, + func_name=func_name, + original_func_name=original_func_name, + function_input_kwargs=function_input_kwargs, + ) + + def new_check_code(context: TransformContext) -> TransformContext: # TODO: remove this tech debt hack if context.output is None: diff --git a/packages/syft/src/syft/service/code/utils.py b/packages/syft/src/syft/service/code/utils.py index fccc5314c43..a3d59cbe161 100644 --- a/packages/syft/src/syft/service/code/utils.py +++ b/packages/syft/src/syft/service/code/utils.py @@ -6,6 +6,9 @@ from IPython import get_ipython # relative +from ..response import SyftException +from ..response import SyftWarning +from .code_parse import GlobalsVisitor from .code_parse import LaunchJobVisitor @@ -36,3 +39,28 @@ def submit_subjobs_code(submit_user_code, ep_client) -> None: # type: ignore # fetch if specs["type_name"] == "SubmitUserCode": ep_client.code.submit(ipython.ev(call)) + + +def check_for_global_vars(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: + """ + Check that the code does not contain any global variables + """ + v = GlobalsVisitor() + try: + v.visit(code_tree) + except Exception: + raise SyftException( + "Your code contains (a) global variable(s), which is not allowed" + ) + return v + + +def parse_code(raw_code: str) -> ast.Module | SyftWarning: + """ + Parse the code into an AST tree and return a warning if there are syntax errors + """ + try: + tree = ast.parse(raw_code) + except SyntaxError as e: + raise SyftException(f"Your code contains syntax error: {e}") + return tree diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index c7a56550d55..4dd038f47e2 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -395,6 +395,26 @@ def valid_name_2(): client.code.submit(valid_name_2) +def test_submit_code_with_global_var(guest_client: DomainClient) -> None: + @sy.syft_function( + input_policy=sy.ExactMatch(), output_policy=sy.SingleExecutionExactOutput() + ) + def mock_syft_func_with_global(): + global x + return x + + res = guest_client.code.submit(mock_syft_func_with_global) + assert isinstance(res, SyftError) + + @sy.syft_function_single_use() + def mock_syft_func_single_use_with_global(): + global x + return x + + res = guest_client.code.submit(mock_syft_func_single_use_with_global) + assert isinstance(res, SyftError) + + def test_request_existing_usercodesubmit(worker) -> None: root_domain_client = worker.root_client