Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: ignore jobs created by custom endpoints #8836

Merged
merged 8 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from ..service.job.job_service import JobService
from ..service.job.job_stash import Job
from ..service.job.job_stash import JobStash
from ..service.job.job_stash import JobType
from ..service.log.log_service import LogService
from ..service.metadata.metadata_service import MetadataService
from ..service.metadata.node_metadata import NodeMetadataV3
Expand Down Expand Up @@ -1288,10 +1289,10 @@ def add_api_endpoint_execution_to_queue(

action = Action.from_api_endpoint_execution()
return self.add_queueitem_to_queue(
queue_item,
credentials,
action,
None,
queue_item=queue_item,
credentials=credentials,
action=action,
job_type=JobType.TWINAPIJOB,
)

def get_worker_pool_ref_by_name(
Expand Down Expand Up @@ -1360,16 +1361,22 @@ def add_action_to_queue(
)

return self.add_queueitem_to_queue(
queue_item, credentials, action, parent_job_id, user_id
queue_item=queue_item,
credentials=credentials,
action=action,
parent_job_id=parent_job_id,
user_id=user_id,
)

def add_queueitem_to_queue(
self,
*,
queue_item: QueueItem,
credentials: SyftVerifyKey,
action: Action | None = None,
parent_job_id: UID | None = None,
user_id: UID | None = None,
job_type: JobType = JobType.JOB,
) -> Job | SyftError:
log_id = UID()
role = self.get_role_for_credentials(credentials=credentials)
Expand Down Expand Up @@ -1403,6 +1410,7 @@ def add_queueitem_to_queue(
parent_job_id=parent_job_id,
action=action,
requested_by=user_id,
job_type=job_type,
)

# 🟡 TODO 36: Needs distributed lock
Expand Down Expand Up @@ -1505,8 +1513,8 @@ def add_api_call_to_queue(
worker_pool=worker_pool_ref,
)
return self.add_queueitem_to_queue(
queue_item,
api_call.credentials,
queue_item=queue_item,
credentials=api_call.credentials,
action=None,
parent_job_id=parent_job_id,
)
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
},
"5": {
"version": 5,
"hash": "82ee08442b09797ed7a3710c31de633bb308b1d2215f51b58a3e01a4c201055d",
"hash": "95a2367bce2e4deb5f8c807561779876c1ec010dbf4d4f68abb526e4eca4487e",
"action": "add"
}
},
Expand Down
14 changes: 12 additions & 2 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ...store.document_store import UIDPartitionKey
from ...types.datetime import DateTime
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_5
from ...types.syft_object import SYFT_OBJECT_VERSION_6
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
from ...types.uid import UID
Expand Down Expand Up @@ -73,10 +73,19 @@ def center_content(text: Any) -> str:
return center_div


@serializable()
class JobType(str, Enum):
    """Kind of job stored in the ``job_type`` field of a ``Job``.

    The class mixes in ``str``, so each member is itself a string and
    compares equal to its raw value (``JobType.JOB == "job"``).
    """

    # Default kind: used when no explicit job_type is supplied.
    JOB = "job"
    # Kind assigned to jobs queued for an API endpoint execution.
    TWINAPIJOB = "twinapijob"

    def __str__(self) -> str:
        """Return the member's raw string value."""
        return self.value


@serializable()
class Job(SyncableSyftObject):
__canonical_name__ = "JobItem"
__version__ = SYFT_OBJECT_VERSION_5
__version__ = SYFT_OBJECT_VERSION_6

id: UID
node_uid: UID
Expand All @@ -94,6 +103,7 @@ class Job(SyncableSyftObject):
updated_at: DateTime | None = None
user_code_id: UID | None = None
requested_by: UID | None = None
job_type: JobType = JobType.JOB

__attr_searchable__ = ["parent_job_id", "job_worker_id", "status", "user_code_id"]
__repr_attrs__ = [
Expand Down
8 changes: 7 additions & 1 deletion packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ..code.user_code import UserCode
from ..code.user_code import UserCodeStatusCollection
from ..job.job_stash import Job
from ..job.job_stash import JobType
from ..log.log import SyftLog
from ..output.output_service import ExecutionOutput
from ..request.request import Request
Expand Down Expand Up @@ -1288,7 +1289,12 @@ def hierarchies(
# TODO: Figure out nested user codes, do we even need that?

root_ids.append(diff.object_id) # type: ignore
elif isinstance(diff_obj, Job) and diff_obj.parent_job_id is None: # type: ignore
elif (
isinstance(diff_obj, Job) # type: ignore
and diff_obj.parent_job_id is None
# ignore Job objects created by TwinAPIEndpoint
and diff_obj.job_type != JobType.TWINAPIJOB
):
root_ids.append(diff.object_id) # type: ignore

for root_uid in root_ids:
Expand Down
2 changes: 2 additions & 0 deletions packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@
SYFT_OBJECT_VERSION_3 = 3
SYFT_OBJECT_VERSION_4 = 4
SYFT_OBJECT_VERSION_5 = 5
SYFT_OBJECT_VERSION_6 = 6

supported_object_versions = [
SYFT_OBJECT_VERSION_1,
SYFT_OBJECT_VERSION_2,
SYFT_OBJECT_VERSION_3,
SYFT_OBJECT_VERSION_4,
SYFT_OBJECT_VERSION_5,
SYFT_OBJECT_VERSION_6,
]

HIGHEST_SYFT_OBJECT_VERSION = max(supported_object_versions)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/local/twin_api_sync_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compute(query):

# verify that ds cannot access private job
assert client_low_ds.api.services.job.get(private_job_id) is None
assert low_client.api.services.job.get(private_job_id) is not None
assert low_client.api.services.job.get(private_job_id) is None

# we only sync the mock function, we never sync the private function to the low side
mock_res = low_client.api.services.testapi.query.mock()
Expand Down
Loading