Skip to content

Commit

Permalink
Merge pull request #8906 from kiendang/gracefully-delete-workers
Browse files Browse the repository at this point in the history
Gracefully delete workers
  • Loading branch information
shubham3121 authored Jul 4, 2024
2 parents 37c13d5 + 922787f commit 5b72e9b
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 39 deletions.
2 changes: 1 addition & 1 deletion notebooks/api/0.8/10-container-images.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@
"outputs": [],
"source": [
"worker_delete_res = domain_client.api.services.worker.delete(\n",
" uid=second_worker.id,\n",
" uid=second_worker.id, force=True\n",
")"
]
},
Expand Down
7 changes: 7 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,13 @@
"hash": "f475543ed5e0066ca09c0dfd8c903e276d4974519e9958473d8141f8d446c881",
"action": "add"
}
},
"SyftWorker": {
"3": {
"version": 3,
"hash": "e124f56ddf4565df2be056553eecd15de7c80bd5f5fd0d06e8ff7815bb05563a",
"action": "add"
}
}
}
}
Expand Down
77 changes: 57 additions & 20 deletions packages/syft/src/syft/service/queue/zmq_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
import socketserver
import sys
import threading
from threading import Event
import time
from time import sleep
from typing import Any
from typing import cast

# third party
from pydantic import field_validator
from result import Result
import zmq
from zmq import Frame
from zmq import LINGER
from zmq.error import ContextTerminated

# relative
from ...node.credentials import SyftVerifyKey
from ...serde.deserialize import _deserialize
from ...serde.serializable import serializable
from ...serde.serialize import _serialize as serialize
Expand All @@ -32,6 +36,7 @@
from ..response import SyftSuccess
from ..service import AbstractService
from ..worker.worker_pool import ConsumerState
from ..worker.worker_pool import SyftWorker
from ..worker.worker_stash import WorkerStash
from .base_queue import AbstractMessageHandler
from .base_queue import QueueClient
Expand All @@ -47,7 +52,7 @@
HEARTBEAT_INTERVAL_SEC = 2

# Thread join timeout (in seconds)
THREAD_TIMEOUT_SEC = 5
THREAD_TIMEOUT_SEC = 30

# Max duration (in ms) to wait for ZMQ poller to return
ZMQ_POLLER_TIMEOUT_MSEC = 1000
Expand Down Expand Up @@ -112,8 +117,6 @@ class Worker(SyftBaseModel):
syft_worker_id: UID | None = None
expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC)

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@field_validator("syft_worker_id", mode="before")
@classmethod
def set_syft_worker_id(cls, v: Any) -> Any:
Expand All @@ -130,6 +133,11 @@ def get_expiry(self) -> float:
def reset_expiry(self) -> None:
self.expiry_t.reset()

def _syft_worker(
self, stash: WorkerStash, credentials: SyftVerifyKey
) -> Result[SyftWorker | None, str]:
return stash.get_by_uid(credentials=credentials, uid=self.syft_worker_id)

def __str__(self) -> str:
svc = self.service.name if self.service else None
return (
Expand All @@ -156,7 +164,7 @@ def __init__(
self.worker_stash = worker_stash
self.queue_name = queue_name
self.auth_context = context
self._stop = threading.Event()
self._stop = Event()
self.post_init()

@property
Expand All @@ -182,24 +190,33 @@ def post_init(self) -> None:

def close(self) -> None:
self._stop.set()

try:
self.poll_workers.unregister(self.socket)
except Exception as e:
logger.exception("Failed to unregister poller.", exc_info=e)
finally:
if self.thread:
self.thread.join(THREAD_TIMEOUT_SEC)
if self.thread.is_alive():
logger.error(
f"ZMQProducer message sending thread join timed out during closing. "
f"Queue name {self.queue_name}, "
)
self.thread = None

if self.producer_thread:
self.producer_thread.join(THREAD_TIMEOUT_SEC)
if self.producer_thread.is_alive():
logger.error(
f"ZMQProducer queue thread join timed out during closing. "
f"Queue name {self.queue_name}, "
)
self.producer_thread = None

self.poll_workers.unregister(self.socket)
except Exception as e:
logger.exception("Failed to unregister poller.", exc_info=e)
finally:
self.socket.close()
self.context.destroy()

self._stop.clear()
# self._stop.clear()

@property
def action_service(self) -> AbstractService:
Expand Down Expand Up @@ -423,10 +440,23 @@ def purge_workers(self) -> None:
Workers are oldest to most recent, so we stop at the first alive worker.
"""
# work on a copy of the iterator
for worker in list(self.waiting):
if worker.has_expired():
for worker in self.waiting:
res = worker._syft_worker(self.worker_stash, self.auth_context.credentials)
if res.is_err() or (syft_worker := res.ok()) is None:
logger.info(f"Failed to retrieve SyftWorker {worker.syft_worker_id}")
continue

if worker.has_expired() or syft_worker.to_be_deleted:
logger.info(f"Deleting expired worker id={worker}")
self.delete_worker(worker, False)
self.delete_worker(worker, syft_worker.to_be_deleted)

# relative
from ...service.worker.worker_service import WorkerService

worker_service = cast(
WorkerService, self.auth_context.node.get_service(WorkerService)
)
worker_service._delete(self.auth_context, syft_worker)

def update_consumer_state_for_worker(
self, syft_worker_id: UID, consumer_state: ConsumerState
Expand Down Expand Up @@ -655,7 +685,7 @@ def __init__(
self.socket = None
self.verbose = verbose
self.id = UID().short()
self._stop = threading.Event()
self._stop = Event()
self.syft_worker_id = syft_worker_id
self.worker_stash = worker_stash
self.post_init()
Expand Down Expand Up @@ -692,16 +722,22 @@ def close(self) -> None:
self.disconnect_from_producer()
self._stop.set()
try:
self.poller.unregister(self.socket)
except Exception as e:
logger.exception("Failed to unregister worker.", exc_info=e)
finally:
if self.thread is not None:
self.thread.join(timeout=THREAD_TIMEOUT_SEC)
if self.thread.is_alive():
logger.error(
f"ZMQConsumer thread join timed out during closing. "
f"SyftWorker id {self.syft_worker_id}, "
f"service name {self.service_name}."
)
self.thread = None
self.poller.unregister(self.socket)
except Exception as e:
logger.error("Failed to unregister worker.", exc_info=e)
finally:
self.socket.close()
self.context.destroy()
self._stop.clear()
# self._stop.clear()

def send_to_producer(
self,
Expand Down Expand Up @@ -794,7 +830,8 @@ def _run(self) -> None:
self.reconnect_to_producer()
self.set_producer_alive()

self.send_heartbeat()
if not self._stop.is_set():
self.send_heartbeat()

except zmq.ZMQError as e:
if e.errno == zmq.ETERM:
Expand Down
51 changes: 50 additions & 1 deletion packages/syft/src/syft/service/worker/worker_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
from collections.abc import Callable
from enum import Enum
from typing import Any
from typing import cast
Expand All @@ -13,9 +14,13 @@
from ...store.linked_obj import LinkedObject
from ...types.base import SyftBaseModel
from ...types.datetime import DateTime
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SYFT_OBJECT_VERSION_3
from ...types.syft_object import SyftObject
from ...types.syft_object import short_uid
from ...types.transforms import drop
from ...types.transforms import make_set_default
from ...types.uid import UID
from ...util import options
from ...util.colors import SURFACE
Expand Down Expand Up @@ -47,7 +52,7 @@ class WorkerHealth(Enum):


@serializable()
class SyftWorker(SyftObject):
class SyftWorkerV2(SyftObject):
__canonical_name__ = "SyftWorker"
__version__ = SYFT_OBJECT_VERSION_2

Expand All @@ -74,6 +79,36 @@ class SyftWorker(SyftObject):
consumer_state: ConsumerState = ConsumerState.DETACHED
job_id: UID | None = None


@serializable()
class SyftWorker(SyftObject):
__canonical_name__ = "SyftWorker"
__version__ = SYFT_OBJECT_VERSION_3

__attr_unique__ = ["name"]
__attr_searchable__ = ["name", "container_id", "to_be_deleted"]
__repr_attrs__ = [
"name",
"container_id",
"image",
"status",
"healthcheck",
"worker_pool_name",
"created_at",
]

id: UID
name: str
container_id: str | None = None
created_at: DateTime = DateTime.now()
healthcheck: WorkerHealth | None = None
status: WorkerStatus
image: SyftWorkerImage | None = None
worker_pool_name: str
consumer_state: ConsumerState = ConsumerState.DETACHED
job_id: UID | None = None
to_be_deleted: bool = False

@property
def logs(self) -> str | SyftError:
api = APIRegistry.api_for(
Expand Down Expand Up @@ -313,3 +348,17 @@ def _get_worker_container_status(
container_status,
SyftError(message=f"Unknown container status: {container_status}"),
)


@migrate(SyftWorkerV2, SyftWorker)
def upgrade_syft_worker() -> list[Callable]:
return [
make_set_default("to_be_deleted", False),
]


@migrate(SyftWorker, SyftWorkerV2)
def downgrade_syft_worker() -> list[Callable]:
return [
drop(["to_be_deleted"]),
]
56 changes: 39 additions & 17 deletions packages/syft/src/syft/service/worker/worker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,28 +146,26 @@ def logs(

return logs if raw else logs.decode(errors="ignore")

@service_method(
path="worker.delete",
name="delete",
roles=DATA_OWNER_ROLE_LEVEL,
)
def delete(
self,
context: AuthedServiceContext,
uid: UID,
force: bool = False,
def _delete(
self, context: AuthedServiceContext, worker: SyftWorker, force: bool = False
) -> SyftSuccess | SyftError:
worker = self._get_worker(context=context, uid=uid)
if isinstance(worker, SyftError):
return worker

uid = worker.id
worker_pool_name = worker.worker_pool_name

# relative
from ...service.job.job_service import JobService
from .worker_pool_service import SyftWorkerPoolService

worker_pool_service: AbstractService = context.node.get_service(
SyftWorkerPoolService
if force and worker.job_id is not None:
job_service = cast(JobService, context.node.get_service(JobService))
res = job_service.kill(context=context, id=worker.job_id)
if isinstance(res, SyftError):
return SyftError(
message=f"Failed to terminate the job associated with worker {uid}: {res.message}"
)

worker_pool_service = cast(
SyftWorkerPoolService, context.node.get_service(SyftWorkerPoolService)
)
worker_pool_stash = worker_pool_service.stash
result = worker_pool_stash.get_by_name(
Expand Down Expand Up @@ -205,7 +203,7 @@ def delete(
if isinstance(docker_container, SyftError):
return docker_container

stopped = _stop_worker_container(worker, docker_container, force)
stopped = _stop_worker_container(worker, docker_container, force=force)
if stopped is not None:
return stopped
else:
Expand Down Expand Up @@ -235,6 +233,30 @@ def delete(
message=f"Worker with id: {uid} deleted successfully from pool: {worker_pool.name}"
)

@service_method(
path="worker.delete",
name="delete",
roles=DATA_OWNER_ROLE_LEVEL,
)
def delete(
self,
context: AuthedServiceContext,
uid: UID,
force: bool = False,
) -> SyftSuccess | SyftError:
worker = self._get_worker(context=context, uid=uid)
worker.to_be_deleted = True

res = self.stash.update(context.credentials, worker)
if isinstance(res, SyftError):
return res

if not force:
# relative
return SyftSuccess(message=f"Worker {uid} has been marked for deletion.")

return self._delete(context, worker, force=True)

def _get_worker(
self, context: AuthedServiceContext, uid: UID
) -> SyftWorker | SyftError:
Expand Down
Loading

0 comments on commit 5b72e9b

Please sign in to comment.