From 5f3c7694002f8b2b67f05a15a3f183a71067fcb5 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 26 Jun 2024 08:25:38 +0700 Subject: [PATCH 1/7] - process code on the client side before submitting - add test for when the submit code contains `global` --- .../syft/src/syft/service/code/user_code.py | 18 ++++++++++++++++++ .../syft/tests/syft/users/user_code_test.py | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b71f5aa4cc6..16fb099df4c 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1188,6 +1188,12 @@ def syft_function( def decorator(f: Any) -> SubmitUserCode: try: code = dedent(inspect.getsource(f)) + + res = process_code_client(code) + if isinstance(res, SyftError): + display(res) + return res + if name is not None: fname = name code = replace_func_name(code, fname) @@ -1233,6 +1239,18 @@ def decorator(f: Any) -> SubmitUserCode: return decorator +def process_code_client( + raw_code: str, +): + tree = ast.parse(raw_code) + # check there are no globals + v = GlobalsVisitor() + try: + v.visit(tree) + except Exception as e: + return SyftError(message=f"Failed to process code. {e}") + + def generate_unique_func_name(context: TransformContext) -> TransformContext: if context.output is not None: code_hash = context.output["code_hash"] diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 333d246e37f..68fbf7922ea 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -366,3 +366,20 @@ def valid_name_2(): valid_name_2.func_name = "get_all" with pytest.raises(ValidationError): client.code.submit(valid_name_2) + + +def test_submit_code_with_global_var(guest_client: DomainClient) -> None: + @sy.syft_function( + input_policy=sy.ExactMatch(), output_policy=sy.SingleExecutionExactOutput() + ) + def mock_syft_func_with_global(): + global x + + def example_function(): + return 1 + x + + return example_function() + + res = guest_client.code.submit(mock_syft_func_with_global) + assert isinstance(res, SyftError) + assert "No Globals allowed!" in res.message From 412efc8e9cc155a7b7bf27f011e6a94ef82ff1c1 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 26 Jun 2024 11:16:39 +0700 Subject: [PATCH 2/7] [syft/user_code] add a `global` keyword check locally for client before submitting code --- packages/syft/src/syft/client/api.py | 5 +++-- packages/syft/src/syft/orchestra.py | 5 ++++- packages/syft/src/syft/service/code/user_code.py | 15 ++++++++------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 9c8b244b129..8ebfac03c0b 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -23,6 +23,7 @@ from pydantic import TypeAdapter from result import OkErr from result import Result +from typeguard import TypeCheckError from typeguard import check_type # relative @@ -1385,7 +1386,7 @@ def validate_callable_args_and_kwargs( break # only need one to match else: check_type(arg, t) # raises Exception - except TypeError: + except TypeCheckError: t_arg = type(arg) if ( autoreload_enabled() @@ -1396,7 +1397,7 @@ def validate_callable_args_and_kwargs( pass else: _type_str = getattr(t, "__name__", str(t)) - msg = f"Arg: {arg} must be {_type_str} not {type(arg).__name__}" + msg = f"Arg is `{arg}`. \nIt must be of type `{_type_str}`, not `{type(arg).__name__}`" if msg: return SyftError(message=msg) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 1a08f594aa2..08672657762 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -8,6 +8,7 @@ from enum import Enum import getpass import inspect +import logging import os import sys from typing import Any @@ -24,6 +25,8 @@ from .service.response import SyftError from .util.util import get_random_available_port +logger = logging.getLogger(__name__) + DEFAULT_PORT = 8080 DEFAULT_URL = "http://localhost" @@ -174,7 +177,7 @@ def deploy_to_python( } if dev_mode: - print("Staging Protocol Changes...") + logger.debug("Staging Protocol Changes...") stage_protocol_changes() kwargs = { diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 0c461a8ae84..fac45223721 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1192,14 +1192,14 @@ def syft_function( else: output_policy_type = type(output_policy) - def decorator(f: Any) -> SubmitUserCode: + def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) - res = process_code_client(code) - if isinstance(res, SyftError): - display(res) - return res + global_check = _check_global(code) + if isinstance(global_check, SyftError): + display(global_check) + return global_check if name is not None: fname = name @@ -1246,9 +1246,9 @@ def decorator(f: Any) -> SubmitUserCode: return decorator -def process_code_client( +def _check_global( raw_code: str, -): +) -> None | SyftError: tree = ast.parse(raw_code) # check there are no globals v = GlobalsVisitor() @@ -1256,6 +1256,7 @@ def process_code_client( v.visit(tree) except Exception as e: return SyftError(message=f"Failed to process code. {e}") + return None def generate_unique_func_name(context: TransformContext) -> TransformContext: From be17b3bd0e5378165da53efe653c2d2dbb5857cb Mon Sep 17 00:00:00 2001 From: dk Date: Thu, 27 Jun 2024 11:22:00 +0700 Subject: [PATCH 3/7] [syft/user_code] checking global in `syft_function` and `SubmitUserCode.local_call` --- .../syft/src/syft/service/code/user_code.py | 90 ++++++++++++++++--- 1 file changed, 78 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index fac45223721..ce82a26121f 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -980,11 +980,12 @@ def local_call(self, *args: Any, **kwargs: Any) -> Any: # only run this on the client side if self.local_function: source = dedent(inspect.getsource(self.local_function)) - tree = ast.parse(source) - # check there are no globals - v = GlobalsVisitor() - v.visit(tree) + v: GlobalsVisitor | SyftWarning = _check_global(raw_code=source) + if isinstance(v, SyftWarning): # the code contains "global" keyword + return SyftError( + message=f"Error when running function locally: {v.message}" + ) # filtered_args = [] filtered_kwargs = {} @@ -1196,10 +1197,19 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) + # check that there are no globals global_check = _check_global(code) - if isinstance(global_check, SyftError): + if isinstance(global_check, SyftWarning): display(global_check) - return global_check + # err = SyftError(message=global_check.message) + # display(err) + # return err + + lint_issues = _lint_code(code) + lint_warning_msg = "" + for issue in lint_issues: + lint_warning_msg += f"{issue}\n\t" + display(SyftWarning(message=lint_warning_msg)) if name is not None: fname = name @@ -1246,17 +1256,73 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: return decorator -def _check_global( - raw_code: str, -) -> None | SyftError: +def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: tree = ast.parse(raw_code) # check there are no globals v = GlobalsVisitor() try: v.visit(tree) - except Exception as e: - return SyftError(message=f"Failed to process code. {e}") - return None + except Exception: + return SyftWarning( + message="Your code contains (a) global variable(s), which is not allowed" + ) + return v + + +# Define a linter function +def _lint_code(code: str) -> list: + # Parse the code into an AST + tree = ast.parse(code) + + # Initialize a list to collect linting issues + issues = [] + + # Define a visitor class to walk the AST + class CodeVisitor(ast.NodeVisitor): + def __init__(self) -> None: + self.globals: set = set() + self.defined_names: set = set() + self.current_scope_defined_names: set = set() + + def visit_Global(self, node: Any) -> None: + # Collect global variable names + for name in node.names: + self.globals.add(name) + self.generic_visit(node) + + def visit_FunctionDef(self, node: Any) -> None: + # Collect defined function names and handle function scope + self.defined_names.add(node.name) + self.current_scope_defined_names = set() # New scope + self.generic_visit(node) + self.current_scope_defined_names.clear() # Clear scope after visiting + + def visit_Assign(self, node: Any) -> None: + # Collect assigned variable names + for target in node.targets: + if isinstance(target, ast.Name): + self.current_scope_defined_names.add(target.id) + self.generic_visit(node) + + def visit_Name(self, node: Any) -> None: + # Check if variables are used before being defined + if isinstance(node.ctx, ast.Load): + if ( + node.id not in self.current_scope_defined_names + and node.id not in self.defined_names + and node.id not in self.globals + ): + issues.append( + f"Variable '{node.id}' used at line {node.lineno} before being defined." + ) + self.generic_visit(node) + + # Create a visitor instance and visit the AST + visitor = CodeVisitor() + visitor.visit(tree) + + # Return the collected issues + return issues def generate_unique_func_name(context: TransformContext) -> TransformContext: From 9e8aaad09da0d2eab0567455a464c6594c17362a Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 28 Jun 2024 11:13:04 +0700 Subject: [PATCH 4/7] [syft/user_code] remove unused variables linter - add a catch for syntax error when parsing the submitted code --- .../syft/src/syft/service/code/user_code.py | 76 ++----------------- 1 file changed, 8 insertions(+), 68 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 31560cb9fb1..ad51750ed19 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -96,6 +96,7 @@ from ..policy.policy import partition_by_node from ..policy.policy_service import PolicyService from ..response import SyftError +from ..response import SyftException from ..response import SyftInfo from ..response import SyftNotReady from ..response import SyftSuccess @@ -1267,15 +1268,6 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: global_check = _check_global(code) if isinstance(global_check, SyftWarning): display(global_check) - # err = SyftError(message=global_check.message) - # display(err) - # return err - - lint_issues = _lint_code(code) - lint_warning_msg = "" - for issue in lint_issues: - lint_warning_msg += f"{issue}\n\t" - display(SyftWarning(message=lint_warning_msg)) if name is not None: fname = name @@ -1335,62 +1327,6 @@ def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: return v -# Define a linter function -def _lint_code(code: str) -> list: - # Parse the code into an AST - tree = ast.parse(code) - - # Initialize a list to collect linting issues - issues = [] - - # Define a visitor class to walk the AST - class CodeVisitor(ast.NodeVisitor): - def __init__(self) -> None: - self.globals: set = set() - self.defined_names: set = set() - self.current_scope_defined_names: set = set() - - def visit_Global(self, node: Any) -> None: - # Collect global variable names - for name in node.names: - self.globals.add(name) - self.generic_visit(node) - - def visit_FunctionDef(self, node: Any) -> None: - # Collect defined function names and handle function scope - self.defined_names.add(node.name) - self.current_scope_defined_names = set() # New scope - self.generic_visit(node) - self.current_scope_defined_names.clear() # Clear scope after visiting - - def visit_Assign(self, node: Any) -> None: - # Collect assigned variable names - for target in node.targets: - if isinstance(target, ast.Name): - self.current_scope_defined_names.add(target.id) - self.generic_visit(node) - - def visit_Name(self, node: Any) -> None: - # Check if variables are used before being defined - if isinstance(node.ctx, ast.Load): - if ( - node.id not in self.current_scope_defined_names - and node.id not in self.defined_names - and node.id not in self.globals - ): - issues.append( - f"Variable '{node.id}' used at line {node.lineno} before being defined." - ) - self.generic_visit(node) - - # Create a visitor instance and visit the AST - visitor = CodeVisitor() - visitor.visit(tree) - - # Return the collected issues - return issues - - def generate_unique_func_name(context: TransformContext) -> TransformContext: if context.output is not None: code_hash = context.output["code_hash"] @@ -1413,11 +1349,15 @@ def process_code( policy_input_kwargs: list[str], function_input_kwargs: list[str], ) -> str: - tree = ast.parse(raw_code) + try: + tree = ast.parse(raw_code) + except SyntaxError as e: + raise SyftException(f"Syntax error in code: {e}") # check there are no globals - v = GlobalsVisitor() - v.visit(tree) + v = _check_global(raw_code=tree) + if isinstance(v, SyftWarning): + raise SyftException(message=f"{v.message}") f = tree.body[0] f.decorator_list = [] From e65c81c20e0678f5ff3cf84132be008f4cfced06 Mon Sep 17 00:00:00 2001 From: dk Date: Mon, 1 Jul 2024 12:11:39 +0700 Subject: [PATCH 5/7] [syft/user_code] separate code parsing out of the 'global' keyworkd check function - add some type annotations - simplify unit test for the case --- .../syft/src/syft/service/code/user_code.py | 49 +++++++++++++------ .../syft/tests/syft/users/user_code_test.py | 13 +++-- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index ad51750ed19..704f130b82d 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1042,9 +1042,13 @@ def local_call(self, *args: Any, **kwargs: Any) -> Any: # only run this on the client side if self.local_function: source = dedent(inspect.getsource(self.local_function)) - - v: GlobalsVisitor | SyftWarning = _check_global(raw_code=source) - if isinstance(v, SyftWarning): # the code contains "global" keyword + tree: ast.Module | SyftWarning = _parse_code(source) + if isinstance(tree, SyftWarning): + return SyftError( + message=f"Error when running function locally: {tree.message}" + ) + v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) + if isinstance(v, SyftWarning): return SyftError( message=f"Error when running function locally: {v.message}" ) @@ -1264,8 +1268,12 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) + tree: ast.Module | SyftWarning = _parse_code(raw_code=code) + if isinstance(tree, SyftWarning): + display(tree) + # check that there are no globals - global_check = _check_global(code) + global_check: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) if isinstance(global_check, SyftWarning): display(global_check) @@ -1314,12 +1322,13 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: return decorator -def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: - tree = ast.parse(raw_code) - # check there are no globals +def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: + """ + Check that the code does not contain any global variables + """ v = GlobalsVisitor() try: - v.visit(tree) + v.visit(code_tree) except Exception: return SyftWarning( message="Your code contains (a) global variable(s), which is not allowed" @@ -1327,6 +1336,17 @@ def _check_global(raw_code: str) -> GlobalsVisitor | SyftWarning: return v +def _parse_code(raw_code: str) -> ast.Module | SyftWarning: + """ + Parse the code into an AST tree and return a warning if there are syntax errors + """ + try: + tree = ast.parse(raw_code) + except SyntaxError as e: + return SyftWarning(message=f"Your code contains syntax error: {e}") + return tree + + def generate_unique_func_name(context: TransformContext) -> TransformContext: if context.output is not None: code_hash = context.output["code_hash"] @@ -1349,17 +1369,16 @@ def process_code( policy_input_kwargs: list[str], function_input_kwargs: list[str], ) -> str: - try: - tree = ast.parse(raw_code) - except SyntaxError as e: - raise SyftException(f"Syntax error in code: {e}") + tree: ast.Module | SyftWarning = _parse_code(raw_code=raw_code) + if isinstance(tree, SyftWarning): + raise SyftException(f"{tree.message}") # check there are no globals - v = _check_global(raw_code=tree) + v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) if isinstance(v, SyftWarning): - raise SyftException(message=f"{v.message}") + raise SyftException(f"{v.message}") - f = tree.body[0] + f: ast.stmt = tree.body[0] f.decorator_list = [] call_args = function_input_kwargs diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 13caf8f9b07..5d7504c61fc 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -401,15 +401,18 @@ def test_submit_code_with_global_var(guest_client: DomainClient) -> None: ) def mock_syft_func_with_global(): global x + return x - def example_function(): - return 1 + x + res = guest_client.code.submit(mock_syft_func_with_global) + assert isinstance(res, SyftError) - return example_function() + @sy.syft_function_single_use() + def mock_syft_func_single_use_with_global(): + global x + return x - res = guest_client.code.submit(mock_syft_func_with_global) + res = guest_client.code.submit(mock_syft_func_single_use_with_global) assert isinstance(res, SyftError) - assert "No Globals allowed!" in res.message def test_request_existing_usercodesubmit(worker) -> None: From a589ad3c0b4ad66ff7d889afd5e477a31b109307 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 1 Jul 2024 14:09:14 +0700 Subject: [PATCH 6/7] [syft/user_code] try to parse and unparse user code both on the client and server side - remove parsing user code in local call since it was already parsed before Co-authored-by: Shubham Gupta --- .../syft/src/syft/service/code/user_code.py | 110 +++++++++--------- 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 704f130b82d..cad1255553f 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1041,18 +1041,6 @@ def __call__( def local_call(self, *args: Any, **kwargs: Any) -> Any: # only run this on the client side if self.local_function: - source = dedent(inspect.getsource(self.local_function)) - tree: ast.Module | SyftWarning = _parse_code(source) - if isinstance(tree, SyftWarning): - return SyftError( - message=f"Error when running function locally: {tree.message}" - ) - v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) - if isinstance(v, SyftWarning): - return SyftError( - message=f"Error when running function locally: {v.message}" - ) - # filtered_args = [] filtered_kwargs = {} # for arg in args: @@ -1268,21 +1256,21 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: try: code = dedent(inspect.getsource(f)) - tree: ast.Module | SyftWarning = _parse_code(raw_code=code) - if isinstance(tree, SyftWarning): - display(tree) - - # check that there are no globals - global_check: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) - if isinstance(global_check, SyftWarning): - display(global_check) - if name is not None: fname = name code = replace_func_name(code, fname) else: fname = f.__name__ + input_kwargs = f.__code__.co_varnames[: f.__code__.co_argcount] + + parse_user_code( + raw_code=code, + func_name=fname, + original_func_name=f.__name__, + function_input_kwargs=input_kwargs, + ) + res = SubmitUserCode( code=code, func_name=fname, @@ -1292,7 +1280,7 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: output_policy_type=output_policy_type, output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}), local_function=f, - input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount], + input_kwargs=input_kwargs, worker_pool_name=worker_pool_name, ) @@ -1305,6 +1293,11 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: display(err) return err + except SyftException as se: + err = SyftError(message=f"Error when parsing the code: {se}") + display(err) + return err + if share_results_with_owners and res.output_policy_init_kwargs is not None: res.output_policy_init_kwargs["output_readers"] = ( res.input_owner_verify_keys @@ -1322,6 +1315,20 @@ def decorator(f: Any) -> SubmitUserCode | SyftError: return decorator +def generate_unique_func_name(context: TransformContext) -> TransformContext: + if context.output is not None: + code_hash = context.output["code_hash"] + service_func_name = context.output["func_name"] + context.output["service_func_name"] = service_func_name + func_name = f"user_func_{service_func_name}_{context.credentials}_{code_hash}" + user_unique_func_name = ( + f"user_func_{service_func_name}_{context.credentials}_{time.time()}" + ) + context.output["unique_func_name"] = func_name + context.output["user_unique_func_name"] = user_unique_func_name + return context + + def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: """ Check that the code does not contain any global variables @@ -1330,8 +1337,8 @@ def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: try: v.visit(code_tree) except Exception: - return SyftWarning( - message="Your code contains (a) global variable(s), which is not allowed" + raise SyftException( + "Your code contains (a) global variable(s), which is not allowed" ) return v @@ -1343,47 +1350,27 @@ def _parse_code(raw_code: str) -> ast.Module | SyftWarning: try: tree = ast.parse(raw_code) except SyntaxError as e: - return SyftWarning(message=f"Your code contains syntax error: {e}") + raise SyftException(f"Your code contains syntax error: {e}") return tree -def generate_unique_func_name(context: TransformContext) -> TransformContext: - if context.output is not None: - code_hash = context.output["code_hash"] - service_func_name = context.output["func_name"] - context.output["service_func_name"] = service_func_name - func_name = f"user_func_{service_func_name}_{context.credentials}_{code_hash}" - user_unique_func_name = ( - f"user_func_{service_func_name}_{context.credentials}_{time.time()}" - ) - context.output["unique_func_name"] = func_name - context.output["user_unique_func_name"] = user_unique_func_name - return context - - -def process_code( - context: TransformContext, +def parse_user_code( raw_code: str, func_name: str, original_func_name: str, - policy_input_kwargs: list[str], function_input_kwargs: list[str], ) -> str: - tree: ast.Module | SyftWarning = _parse_code(raw_code=raw_code) - if isinstance(tree, SyftWarning): - raise SyftException(f"{tree.message}") - - # check there are no globals - v: GlobalsVisitor | SyftWarning = _check_global(code_tree=tree) - if isinstance(v, SyftWarning): - raise SyftException(f"{v.message}") + # parse the code, check for syntax errors and if there are global variables + try: + tree: ast.Module = _parse_code(raw_code=raw_code) + _check_global(code_tree=tree) + except SyftException as e: + raise SyftException(f"{e}") f: ast.stmt = tree.body[0] f.decorator_list = [] call_args = function_input_kwargs - if "domain" in function_input_kwargs and context.output is not None: - context.output["uses_domain"] = True call_stmt_keywords = [ast.keyword(arg=i, value=[ast.Name(id=i)]) for i in call_args] call_stmt = ast.Assign( targets=[ast.Name(id="result")], @@ -1408,6 +1395,25 @@ def process_code( return unparse(wrapper_function) +def process_code( + context: TransformContext, + raw_code: str, + func_name: str, + original_func_name: str, + policy_input_kwargs: list[str], + function_input_kwargs: list[str], +) -> str: + if "domain" in function_input_kwargs and context.output is not None: + context.output["uses_domain"] = True + + return parse_user_code( + raw_code=raw_code, + func_name=func_name, + original_func_name=original_func_name, + function_input_kwargs=function_input_kwargs, + ) + + def new_check_code(context: TransformContext) -> TransformContext: # TODO: remove this tech debt hack if context.output is None: From 73625c877e9def0c1793d00248901748bb2e25d2 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 3 Jul 2024 16:21:20 +0700 Subject: [PATCH 7/7] [syft/user_code] refactoring utility functions for parsing user code Co-authored-by: Shubham Gupta --- .../syft/src/syft/service/code/user_code.py | 32 +++---------------- packages/syft/src/syft/service/code/utils.py | 28 ++++++++++++++++ 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 01114b49c0a..88b3d5109bd 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -105,9 +105,10 @@ from ..service import ServiceConfigRegistry from ..user.user import UserView from ..user.user_roles import ServiceRole -from .code_parse import GlobalsVisitor from .code_parse import LaunchJobVisitor from .unparse import unparse +from .utils import check_for_global_vars +from .utils import parse_code from .utils import submit_subjobs_code if TYPE_CHECKING: @@ -1327,31 +1328,6 @@ def generate_unique_func_name(context: TransformContext) -> TransformContext: return context -def _check_global(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: - """ - Check that the code does not contain any global variables - """ - v = GlobalsVisitor() - try: - v.visit(code_tree) - except Exception: - raise SyftException( - "Your code contains (a) global variable(s), which is not allowed" - ) - return v - - -def _parse_code(raw_code: str) -> ast.Module | SyftWarning: - """ - Parse the code into an AST tree and return a warning if there are syntax errors - """ - try: - tree = ast.parse(raw_code) - except SyntaxError as e: - raise SyftException(f"Your code contains syntax error: {e}") - return tree - - def parse_user_code( raw_code: str, func_name: str, @@ -1360,8 +1336,8 @@ def parse_user_code( ) -> str: # parse the code, check for syntax errors and if there are global variables try: - tree: ast.Module = _parse_code(raw_code=raw_code) - _check_global(code_tree=tree) + tree: ast.Module = parse_code(raw_code=raw_code) + check_for_global_vars(code_tree=tree) except SyftException as e: raise SyftException(f"{e}") diff --git a/packages/syft/src/syft/service/code/utils.py b/packages/syft/src/syft/service/code/utils.py index fccc5314c43..a3d59cbe161 100644 --- a/packages/syft/src/syft/service/code/utils.py +++ b/packages/syft/src/syft/service/code/utils.py @@ -6,6 +6,9 @@ from IPython import get_ipython # relative +from ..response import SyftException +from ..response import SyftWarning +from .code_parse import GlobalsVisitor from .code_parse import LaunchJobVisitor @@ -36,3 +39,28 @@ def submit_subjobs_code(submit_user_code, ep_client) -> None: # type: ignore # fetch if specs["type_name"] == "SubmitUserCode": ep_client.code.submit(ipython.ev(call)) + + +def check_for_global_vars(code_tree: ast.Module) -> GlobalsVisitor | SyftWarning: + """ + Check that the code does not contain any global variables + """ + v = GlobalsVisitor() + try: + v.visit(code_tree) + except Exception: + raise SyftException( + "Your code contains (a) global variable(s), which is not allowed" + ) + return v + + +def parse_code(raw_code: str) -> ast.Module | SyftWarning: + """ + Parse the code into an AST tree and return a warning if there are syntax errors + """ + try: + tree = ast.parse(raw_code) + except SyntaxError as e: + raise SyftException(f"Your code contains syntax error: {e}") + return tree