Skip to content

Commit

Permalink
Merge pull request #8976 from OpenMined/eelco/codehash
Browse files Browse the repository at this point in the history
Make code hash specific to user
  • Loading branch information
eelcovdw authored Jun 26, 2024
2 parents a6ca826 + ee17d32 commit 71be405
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 5 deletions.
10 changes: 6 additions & 4 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,9 +1014,6 @@ def add_output_policy_ids(cls, values: Any) -> Any:
values["id"] = UID()
return values

def get_code_hash(self) -> str:
return hashlib.sha256(self.code.encode()).hexdigest()

@property
def kwargs(self) -> dict[Any, Any] | None:
return self.input_policy_init_kwargs
Expand Down Expand Up @@ -1172,6 +1169,11 @@ def input_owner_verify_keys(self) -> list[str] | None:
return None


def get_code_hash(code: str, user_verify_key: SyftVerifyKey) -> str:
full_str = f"{code}{user_verify_key}"
return hashlib.sha256(full_str.encode()).hexdigest()


def is_valid_usercode_name(func_name: str) -> Result[Any, str]:
if len(func_name) == 0:
return Err("Function name cannot be empty")
Expand Down Expand Up @@ -1448,7 +1450,7 @@ def hash_code(context: TransformContext) -> TransformContext:

code = context.output["code"]
context.output["raw_code"] = code
code_hash = context.obj.get_code_hash()
code_hash = get_code_hash(code, context.credentials)
context.output["code_hash"] = code_hash

return context
Expand Down
3 changes: 2 additions & 1 deletion packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .user_code import UserCode
from .user_code import UserCodeStatus
from .user_code import UserCodeUpdate
from .user_code import get_code_hash
from .user_code import load_approved_policy_code
from .user_code_stash import UserCodeStash

Expand Down Expand Up @@ -89,7 +90,7 @@ def _submit(
"""
existing_code_or_err = self.stash.get_by_code_hash(
context.credentials,
code_hash=submit_code.get_code_hash(),
code_hash=get_code_hash(submit_code.code, context.credentials),
)

if existing_code_or_err.is_err():
Expand Down
45 changes: 45 additions & 0 deletions packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,48 @@ def my_func():

assert len(ds_client.code.get_all()) == 1
assert len(ds_client.requests.get_all()) == 1


def test_submit_existing_code_different_user(worker):
root_domain_client = worker.root_client

root_domain_client.register(
name="data-scientist",
email="[email protected]",
password="0000",
password_verify="0000",
)
ds_client_1 = root_domain_client.login(
email="[email protected]",
password="0000",
)

root_domain_client.register(
name="data-scientist-2",
email="[email protected]",
password="0000",
password_verify="0000",
)
ds_client_2 = root_domain_client.login(
email="[email protected]",
password="0000",
)

@sy.syft_function_single_use()
def my_func():
return 42

res_submit = ds_client_1.api.services.code.submit(my_func)
assert isinstance(res_submit, SyftSuccess)
res_resubmit = ds_client_1.api.services.code.submit(my_func)
assert isinstance(res_resubmit, SyftError)

# Resubmit with different user
res_submit = ds_client_2.api.services.code.submit(my_func)
assert isinstance(res_submit, SyftSuccess)
res_resubmit = ds_client_2.api.services.code.submit(my_func)
assert isinstance(res_resubmit, SyftError)

assert len(ds_client_1.code.get_all()) == 1
assert len(ds_client_2.code.get_all()) == 1
assert len(root_domain_client.code.get_all()) == 2

0 comments on commit 71be405

Please sign in to comment.