diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 34c10dd9efe..64d10f92930 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -62,6 +62,7 @@ from ..service.job.job_service import JobService from ..service.job.job_stash import Job from ..service.job.job_stash import JobStash +from ..service.job.job_stash import JobType from ..service.log.log_service import LogService from ..service.metadata.metadata_service import MetadataService from ..service.metadata.node_metadata import NodeMetadataV3 @@ -1288,10 +1289,10 @@ def add_api_endpoint_execution_to_queue( action = Action.from_api_endpoint_execution() return self.add_queueitem_to_queue( - queue_item, - credentials, - action, - None, + queue_item=queue_item, + credentials=credentials, + action=action, + job_type=JobType.TWINAPIJOB, ) def get_worker_pool_ref_by_name( @@ -1360,16 +1361,22 @@ def add_action_to_queue( ) return self.add_queueitem_to_queue( - queue_item, credentials, action, parent_job_id, user_id + queue_item=queue_item, + credentials=credentials, + action=action, + parent_job_id=parent_job_id, + user_id=user_id, ) def add_queueitem_to_queue( self, + *, queue_item: QueueItem, credentials: SyftVerifyKey, action: Action | None = None, parent_job_id: UID | None = None, user_id: UID | None = None, + job_type: JobType = JobType.JOB, ) -> Job | SyftError: log_id = UID() role = self.get_role_for_credentials(credentials=credentials) @@ -1403,6 +1410,7 @@ def add_queueitem_to_queue( parent_job_id=parent_job_id, action=action, requested_by=user_id, + job_type=job_type, ) # 🟡 TODO 36: Needs distributed lock @@ -1505,8 +1513,8 @@ def add_api_call_to_queue( worker_pool=worker_pool_ref, ) return self.add_queueitem_to_queue( - queue_item, - api_call.credentials, + queue_item=queue_item, + credentials=api_call.credentials, action=None, parent_job_id=parent_job_id, ) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index e30f48dfd5a..5ed629b5151 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -54,7 +54,7 @@ }, "5": { "version": 5, - "hash": "82ee08442b09797ed7a3710c31de633bb308b1d2215f51b58a3e01a4c201055d", + "hash": "95a2367bce2e4deb5f8c807561779876c1ec010dbf4d4f68abb526e4eca4487e", "action": "add" } }, diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index d7aa3aca00b..ec5dcfd19a1 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -29,7 +29,7 @@ from ...store.document_store import UIDPartitionKey from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.syft_object import SYFT_OBJECT_VERSION_5 +from ...types.syft_object import SYFT_OBJECT_VERSION_6 from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID @@ -73,10 +73,19 @@ def center_content(text: Any) -> str: return center_div +@serializable() +class JobType(str, Enum): + JOB = "job" + TWINAPIJOB = "twinapijob" + + def __str__(self) -> str: + return self.value + + @serializable() class Job(SyncableSyftObject): __canonical_name__ = "JobItem" - __version__ = SYFT_OBJECT_VERSION_5 + __version__ = SYFT_OBJECT_VERSION_6 id: UID node_uid: UID @@ -94,6 +103,7 @@ class Job(SyncableSyftObject): updated_at: DateTime | None = None user_code_id: UID | None = None requested_by: UID | None = None + job_type: JobType = JobType.JOB __attr_searchable__ = ["parent_job_id", "job_worker_id", "status", "user_code_id"] __repr_attrs__ = [ diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 014e33f5bc8..cde79262c24 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -45,6 +45,7 @@ from ..code.user_code import UserCode from ..code.user_code import UserCodeStatusCollection from ..job.job_stash import Job +from ..job.job_stash import JobType from ..log.log import SyftLog from ..output.output_service import ExecutionOutput from ..request.request import Request @@ -1288,7 +1289,12 @@ def hierarchies( # TODO: Figure out nested user codes, do we even need that? root_ids.append(diff.object_id) # type: ignore - elif isinstance(diff_obj, Job) and diff_obj.parent_job_id is None: # type: ignore + elif ( + isinstance(diff_obj, Job) # type: ignore + and diff_obj.parent_job_id is None + # ignore Job objects created by TwinAPIEndpoint + and diff_obj.job_type != JobType.TWINAPIJOB + ): root_ids.append(diff.object_id) # type: ignore for root_uid in root_ids: diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index a290e4ff080..3ec9c073165 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -61,6 +61,7 @@ SYFT_OBJECT_VERSION_3 = 3 SYFT_OBJECT_VERSION_4 = 4 SYFT_OBJECT_VERSION_5 = 5 +SYFT_OBJECT_VERSION_6 = 6 supported_object_versions = [ SYFT_OBJECT_VERSION_1, @@ -68,6 +69,7 @@ SYFT_OBJECT_VERSION_3, SYFT_OBJECT_VERSION_4, SYFT_OBJECT_VERSION_5, + SYFT_OBJECT_VERSION_6, ] HIGHEST_SYFT_OBJECT_VERSION = max(supported_object_versions) diff --git a/tests/integration/local/twin_api_sync_test.py b/tests/integration/local/twin_api_sync_test.py index d39066ade9a..e09c82001d1 100644 --- a/tests/integration/local/twin_api_sync_test.py +++ b/tests/integration/local/twin_api_sync_test.py @@ -111,7 +111,7 @@ def compute(query): # verify that ds cannot access private job assert client_low_ds.api.services.job.get(private_job_id) is None - assert low_client.api.services.job.get(private_job_id) is not None + assert low_client.api.services.job.get(private_job_id) is None # we only sync the mock function, we never sync the private function to the low side mock_res = low_client.api.services.testapi.query.mock()