diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index ef65d8e8987..b8df87fc619 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1014,9 +1014,6 @@ def add_output_policy_ids(cls, values: Any) -> Any: values["id"] = UID() return values - def get_code_hash(self) -> str: - return hashlib.sha256(self.code.encode()).hexdigest() - @property def kwargs(self) -> dict[Any, Any] | None: return self.input_policy_init_kwargs @@ -1172,6 +1169,11 @@ def input_owner_verify_keys(self) -> list[str] | None: return None +def get_code_hash(code: str, user_verify_key: SyftVerifyKey) -> str: + full_str = f"{code}{user_verify_key}" + return hashlib.sha256(full_str.encode()).hexdigest() + + def is_valid_usercode_name(func_name: str) -> Result[Any, str]: if len(func_name) == 0: return Err("Function name cannot be empty") @@ -1448,7 +1450,7 @@ def hash_code(context: TransformContext) -> TransformContext: code = context.output["code"] context.output["raw_code"] = code - code_hash = context.obj.get_code_hash() + code_hash = get_code_hash(code, context.credentials) context.output["code_hash"] = code_hash return context diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index cd947cf325a..41aba58575b 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -43,6 +43,7 @@ from .user_code import UserCode from .user_code import UserCodeStatus from .user_code import UserCodeUpdate +from .user_code import get_code_hash from .user_code import load_approved_policy_code from .user_code_stash import UserCodeStash @@ -89,7 +90,7 @@ def _submit( """ existing_code_or_err = self.stash.get_by_code_hash( context.credentials, - code_hash=submit_code.get_code_hash(), + code_hash=get_code_hash(submit_code.code, context.credentials), ) if existing_code_or_err.is_err(): diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 3e9cb975580..69f69ab76d4 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -457,3 +457,48 @@ def my_func(): assert len(ds_client.code.get_all()) == 1 assert len(ds_client.requests.get_all()) == 1 + + +def test_submit_existing_code_different_user(worker): + root_domain_client = worker.root_client + + root_domain_client.register( + name="data-scientist", + email="test_user@openmined.org", + password="0000", + password_verify="0000", + ) + ds_client_1 = root_domain_client.login( + email="test_user@openmined.org", + password="0000", + ) + + root_domain_client.register( + name="data-scientist-2", + email="test_user_2@openmined.org", + password="0000", + password_verify="0000", + ) + ds_client_2 = root_domain_client.login( + email="test_user_2@openmined.org", + password="0000", + ) + + @sy.syft_function_single_use() + def my_func(): + return 42 + + res_submit = ds_client_1.api.services.code.submit(my_func) + assert isinstance(res_submit, SyftSuccess) + res_resubmit = ds_client_1.api.services.code.submit(my_func) + assert isinstance(res_resubmit, SyftError) + + # Resubmit with different user + res_submit = ds_client_2.api.services.code.submit(my_func) + assert isinstance(res_submit, SyftSuccess) + res_resubmit = ds_client_2.api.services.code.submit(my_func) + assert isinstance(res_resubmit, SyftError) + + assert len(ds_client_1.code.get_all()) == 1 + assert len(ds_client_2.code.get_all()) == 1 + assert len(root_domain_client.code.get_all()) == 2