Skip to content

Commit

Permalink
Merge pull request #9027 from OpenMined/fix-saving-nested-action-obj-…
Browse files Browse the repository at this point in the history
…queue

Move Flattening of Nested Action Objects to User code exec
  • Loading branch information
teo-milea committed Jul 15, 2024
2 parents 1790644 + 43a5ebd commit 077fc56
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 81 deletions.
81 changes: 81 additions & 0 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ def _user_code_execute(
input_policy = code_item.get_input_policy(context)
output_policy = code_item.get_output_policy(context)

# Unwrap nested ActionObjects
for _k, arg in kwargs.items():
self.flatten_action_arg(context, arg) if isinstance(arg, UID) else None

if not override_execution_permission:
if input_policy is None:
if not code_item.is_output_policy_approved(context):
Expand Down Expand Up @@ -751,6 +755,83 @@ def call_method(
else:
return execute_object(self, context, resolved_self, action) # type:ignore[unreachable]

def unwrap_nested_actionobjects(
self, context: AuthedServiceContext, data: Any
) -> Any:
"""recursively unwraps nested action objects"""

if isinstance(data, list):
return [self.unwrap_nested_actionobjects(context, obj) for obj in data]
if isinstance(data, dict):
return {
key: self.unwrap_nested_actionobjects(context, obj)
for key, obj in data.items()
}
if isinstance(data, ActionObject):
res = self.get(context=context, uid=data.id)
res = res.ok() if res.is_ok() else res.err()
if not isinstance(res, ActionObject):
return SyftError(message=f"{res}")
else:
nested_res = res.syft_action_data
if isinstance(nested_res, ActionObject):
raise ValueError(
"More than double nesting of ActionObjects is currently not supported"
)
return nested_res
return data

def contains_nested_actionobjects(self, data: Any) -> bool:
"""
returns if this is a list/set/dict that contains ActionObjects
"""

def unwrap_collection(col: set | dict | list) -> [Any]: # type: ignore
return_values = []
if isinstance(col, dict):
values = list(col.values()) + list(col.keys())
else:
values = list(col)
for v in values:
if isinstance(v, list | dict | set):
return_values += unwrap_collection(v)
else:
return_values.append(v)
return return_values

if isinstance(data, list | dict | set):
values = unwrap_collection(data)
has_action_object = any(isinstance(x, ActionObject) for x in values)
return has_action_object
elif isinstance(data, ActionObject):
return True
return False

def flatten_action_arg(self, context: AuthedServiceContext, arg: UID) -> UID | None:
""" "If the argument is a collection (of collections) of ActionObjects,
We want to flatten the collection and upload a new ActionObject that contains
its values. E.g. [[ActionObject1, ActionObject2],[ActionObject3, ActionObject4]]
-> [[value1, value2],[value3, value4]]
"""
res = self.get(context=context, uid=arg)
if res.is_err():
return arg

action_object = res.ok()
data = action_object.syft_action_data

if self.contains_nested_actionobjects(data):
new_data = self.unwrap_nested_actionobjects(context, data)
# Update existing action object with the new flattened data
action_object.syft_action_data_cache = new_data
action_object._save_to_blob_storage()
res = self._set(
context=context,
action_object=action_object,
)

return None

@service_method(path="action.execute", name="execute", roles=GUEST_ROLE_LEVEL)
def execute(
self, context: AuthedServiceContext, action: Action
Expand Down
80 changes: 0 additions & 80 deletions packages/syft/src/syft/service/queue/zmq_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,82 +261,6 @@ def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bo
logger.exception("Failed to resolve action objects.", exc_info=e)
return True

def unwrap_nested_actionobjects(self, data: Any) -> Any:
"""recursively unwraps nested action objects"""

if isinstance(data, list):
return [self.unwrap_nested_actionobjects(obj) for obj in data]
if isinstance(data, dict):
return {
key: self.unwrap_nested_actionobjects(obj) for key, obj in data.items()
}
if isinstance(data, ActionObject):
res = self.action_service.get(self.auth_context, data.id)
res = res.ok() if res.is_ok() else res.err()
if not isinstance(res, ActionObject):
return SyftError(message=f"{res}")
else:
nested_res = res.syft_action_data
if isinstance(nested_res, ActionObject):
raise ValueError(
"More than double nesting of ActionObjects is currently not supported"
)
return nested_res
return data

def contains_nested_actionobjects(self, data: Any) -> bool:
"""
returns if this is a list/set/dict that contains ActionObjects
"""

def unwrap_collection(col: set | dict | list) -> [Any]: # type: ignore
return_values = []
if isinstance(col, dict):
values = list(col.values()) + list(col.keys())
else:
values = list(col)
for v in values:
if isinstance(v, list | dict | set):
return_values += unwrap_collection(v)
else:
return_values.append(v)
return return_values

if isinstance(data, list | dict | set):
values = unwrap_collection(data)
has_action_object = any(isinstance(x, ActionObject) for x in values)
return has_action_object
elif isinstance(data, ActionObject):
return True
return False

def preprocess_action_arg(self, arg: UID) -> UID | None:
""" "If the argument is a collection (of collections) of ActionObjects,
We want to flatten the collection and upload a new ActionObject that contains
its values. E.g. [[ActionObject1, ActionObject2],[ActionObject3, ActionObject4]]
-> [[value1, value2],[value3, value4]]
"""
res = self.action_service.get(context=self.auth_context, uid=arg)
if res.is_err():
return arg
action_object = res.ok()
data = action_object.syft_action_data
if self.contains_nested_actionobjects(data):
new_data = self.unwrap_nested_actionobjects(data)

new_action_object = ActionObject.from_obj(
new_data,
id=action_object.id,
syft_blob_storage_entry_id=action_object.syft_blob_storage_entry_id,
syft_server_location=action_object.syft_server_location,
syft_client_verify_key=action_object.syft_client_verify_key,
)
new_action_object._save_to_blob_storage()
res = self.action_service._set(
context=self.auth_context, action_object=new_action_object
)
return None

def read_items(self) -> None:
while True:
if self._stop.is_set():
Expand Down Expand Up @@ -369,10 +293,6 @@ def read_items(self) -> None:
action.args
) or self.contains_unresolved_action_objects(action.kwargs):
continue
for arg in action.args:
self.preprocess_action_arg(arg)
for _, arg in action.kwargs.items():
self.preprocess_action_arg(arg)

msg_bytes = serialize(item, to_bytes=True)
worker_pool = item.worker_pool.resolve_with_context(
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/local/job_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# stdlib
from secrets import token_hex
import time

# third party
import pytest
Expand Down Expand Up @@ -87,6 +86,7 @@ def job(server):
@syft_function()
def process_batch():
# stdlib
import time

while time.sleep(1) is None:
...
Expand All @@ -96,6 +96,7 @@ def process_batch():
@syft_function_single_use()
def process_all(datasite):
# stdlib
import time

_ = datasite.launch_job(process_batch)
_ = datasite.launch_job(process_batch)
Expand Down

0 comments on commit 077fc56

Please sign in to comment.