Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add linter and error for code submissions that use global scope variables / methods #8974

Merged
merged 15 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pydantic import TypeAdapter
from result import OkErr
from result import Result
from typeguard import TypeCheckError
from typeguard import check_type

# relative
Expand Down Expand Up @@ -1385,7 +1386,7 @@ def validate_callable_args_and_kwargs(
break # only need one to match
else:
check_type(arg, t) # raises Exception
except TypeError:
except TypeCheckError:
t_arg = type(arg)
if (
autoreload_enabled()
Expand All @@ -1396,7 +1397,7 @@ def validate_callable_args_and_kwargs(
pass
else:
_type_str = getattr(t, "__name__", str(t))
msg = f"Arg: {arg} must be {_type_str} not {type(arg).__name__}"
msg = f"Arg is `{arg}`. \nIt must be of type `{_type_str}`, not `{type(arg).__name__}`"

if msg:
return SyftError(message=msg)
Expand Down
5 changes: 4 additions & 1 deletion packages/syft/src/syft/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from enum import Enum
import getpass
import inspect
import logging
import os
import sys
from typing import Any
Expand All @@ -24,6 +25,8 @@
from .service.response import SyftError
from .util.util import get_random_available_port

logger = logging.getLogger(__name__)

DEFAULT_PORT = 8080
DEFAULT_URL = "http://localhost"

Expand Down Expand Up @@ -174,7 +177,7 @@ def deploy_to_python(
}

if dev_mode:
print("Staging Protocol Changes...")
logger.debug("Staging Protocol Changes...")
stage_protocol_changes()

kwargs = {
Expand Down
68 changes: 47 additions & 21 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,18 @@
from ..policy.policy import partition_by_node
from ..policy.policy_service import PolicyService
from ..response import SyftError
from ..response import SyftException
from ..response import SyftInfo
from ..response import SyftNotReady
from ..response import SyftSuccess
from ..response import SyftWarning
from ..service import ServiceConfigRegistry
from ..user.user import UserView
from ..user.user_roles import ServiceRole
from .code_parse import GlobalsVisitor
from .code_parse import LaunchJobVisitor
from .unparse import unparse
from .utils import check_for_global_vars
from .utils import parse_code
from .utils import submit_subjobs_code

if TYPE_CHECKING:
Expand Down Expand Up @@ -1038,13 +1040,6 @@ def __call__(
def local_call(self, *args: Any, **kwargs: Any) -> Any:
# only run this on the client side
if self.local_function:
source = dedent(inspect.getsource(self.local_function))
tree = ast.parse(source)

# check there are no globals
v = GlobalsVisitor()
v.visit(tree)

# filtered_args = []
filtered_kwargs = {}
# for arg in args:
Expand Down Expand Up @@ -1256,15 +1251,25 @@ def syft_function(
else:
output_policy_type = type(output_policy)

def decorator(f: Any) -> SubmitUserCode:
def decorator(f: Any) -> SubmitUserCode | SyftError:
try:
code = dedent(inspect.getsource(f))

if name is not None:
fname = name
code = replace_func_name(code, fname)
else:
fname = f.__name__

input_kwargs = f.__code__.co_varnames[: f.__code__.co_argcount]

parse_user_code(
raw_code=code,
func_name=fname,
original_func_name=f.__name__,
function_input_kwargs=input_kwargs,
)

res = SubmitUserCode(
code=code,
func_name=fname,
Expand All @@ -1274,7 +1279,7 @@ def decorator(f: Any) -> SubmitUserCode:
output_policy_type=output_policy_type,
output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}),
local_function=f,
input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount],
input_kwargs=input_kwargs,
worker_pool_name=worker_pool_name,
)

Expand All @@ -1287,6 +1292,11 @@ def decorator(f: Any) -> SubmitUserCode:
display(err)
return err

except SyftException as se:
err = SyftError(message=f"Error when parsing the code: {se}")
display(err)
return err

if share_results_with_owners and res.output_policy_init_kwargs is not None:
res.output_policy_init_kwargs["output_readers"] = (
res.input_owner_verify_keys
Expand Down Expand Up @@ -1318,26 +1328,23 @@ def generate_unique_func_name(context: TransformContext) -> TransformContext:
return context


def process_code(
context: TransformContext,
def parse_user_code(
raw_code: str,
func_name: str,
original_func_name: str,
policy_input_kwargs: list[str],
function_input_kwargs: list[str],
) -> str:
tree = ast.parse(raw_code)

# check there are no globals
v = GlobalsVisitor()
v.visit(tree)
# parse the code, check for syntax errors and if there are global variables
try:
tree: ast.Module = parse_code(raw_code=raw_code)
check_for_global_vars(code_tree=tree)
except SyftException as e:
raise SyftException(f"{e}")

f = tree.body[0]
f: ast.stmt = tree.body[0]
f.decorator_list = []

call_args = function_input_kwargs
if "domain" in function_input_kwargs and context.output is not None:
context.output["uses_domain"] = True
call_stmt_keywords = [ast.keyword(arg=i, value=[ast.Name(id=i)]) for i in call_args]
call_stmt = ast.Assign(
targets=[ast.Name(id="result")],
Expand All @@ -1362,6 +1369,25 @@ def process_code(
return unparse(wrapper_function)


def process_code(
    context: TransformContext,
    raw_code: str,
    func_name: str,
    original_func_name: str,
    policy_input_kwargs: list[str],
    function_input_kwargs: list[str],
) -> str:
    """
    Record domain usage on the transform context, then wrap the user code.

    Args:
        context: Transform context; when its ``output`` is not None and the
            function takes a ``domain`` argument, ``output["uses_domain"]``
            is set to True.
        raw_code: Source text of the user function.
        func_name: Name passed through to ``parse_user_code``.
        original_func_name: Name passed through to ``parse_user_code``.
        policy_input_kwargs: Accepted for interface compatibility; not used
            in this function.
        function_input_kwargs: Argument names of the user function.

    Returns:
        Whatever ``parse_user_code`` returns for the given code.
    """
    if "domain" in function_input_kwargs:
        if context.output is not None:
            context.output["uses_domain"] = True

    wrapped_source = parse_user_code(
        raw_code=raw_code,
        func_name=func_name,
        original_func_name=original_func_name,
        function_input_kwargs=function_input_kwargs,
    )
    return wrapped_source


def new_check_code(context: TransformContext) -> TransformContext:
# TODO: remove this tech debt hack
if context.output is None:
Expand Down
28 changes: 28 additions & 0 deletions packages/syft/src/syft/service/code/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from IPython import get_ipython

# relative
from ..response import SyftException
from ..response import SyftWarning
from .code_parse import GlobalsVisitor
from .code_parse import LaunchJobVisitor


Expand Down Expand Up @@ -36,3 +39,28 @@ def submit_subjobs_code(submit_user_code, ep_client) -> None: # type: ignore
# fetch
if specs["type_name"] == "SubmitUserCode":
ep_client.code.submit(ipython.ev(call))


def check_for_global_vars(code_tree: ast.Module) -> GlobalsVisitor:
    """
    Check that the code does not contain any global variables.

    Walks ``code_tree`` with a ``GlobalsVisitor``. Any exception raised during
    the walk is reported as a ``SyftException`` with a fixed message.

    Args:
        code_tree: The parsed module to inspect.

    Returns:
        The ``GlobalsVisitor`` instance after a successful walk.

    Raises:
        SyftException: If the visitor raises while walking the tree. The
            original exception is attached as ``__cause__``.
    """
    v = GlobalsVisitor()
    try:
        v.visit(code_tree)
    except Exception as e:
        # NOTE(review): presumably GlobalsVisitor signals a global statement by
        # raising a plain Exception, hence the broad catch -- confirm against
        # code_parse.py. Chaining keeps the underlying cause visible.
        raise SyftException(
            "Your code contains (a) global variable(s), which is not allowed"
        ) from e
    return v


def parse_code(raw_code: str) -> ast.Module:
    """
    Parse the code into an AST tree.

    Args:
        raw_code: The source text to parse.

    Returns:
        The module tree produced by ``ast.parse``.

    Raises:
        SyftException: If ``raw_code`` contains a syntax error. The original
            ``SyntaxError`` is attached as ``__cause__``.
    """
    try:
        return ast.parse(raw_code)
    except SyntaxError as e:
        # Chain the SyntaxError so line/offset details are not lost.
        raise SyftException(f"Your code contains syntax error: {e}") from e
20 changes: 20 additions & 0 deletions packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,26 @@ def valid_name_2():
client.code.submit(valid_name_2)


def test_submit_code_with_global_var(guest_client: DomainClient) -> None:
    """Submitting a syft function that declares ``global x`` yields a SyftError."""

    # Case 1: decorated with explicit input and output policies.
    @sy.syft_function(
        input_policy=sy.ExactMatch(), output_policy=sy.SingleExecutionExactOutput()
    )
    def mock_syft_func_with_global():
        global x
        return x

    res = guest_client.code.submit(mock_syft_func_with_global)
    assert isinstance(res, SyftError)

    # Case 2: the same function body, decorated with the single-use decorator.
    @sy.syft_function_single_use()
    def mock_syft_func_single_use_with_global():
        global x
        return x

    res = guest_client.code.submit(mock_syft_func_single_use_with_global)
    assert isinstance(res, SyftError)


def test_request_existing_usercodesubmit(worker) -> None:
root_domain_client = worker.root_client

Expand Down
Loading