Skip to content

Commit

Permalink
[sapphire] Remove launcher thread
Browse files Browse the repository at this point in the history
Working to reduce complexity and avoid potential deadlocks.
Also increases code coverage and adds docs.
  • Loading branch information
tysmith committed Sep 27, 2023
1 parent 7c9c768 commit 6e48e73
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 150 deletions.
222 changes: 124 additions & 98 deletions sapphire/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from logging import getLogger
from sys import exc_info
from threading import Thread, ThreadError, active_count
from time import sleep, time
from time import time
from traceback import format_exception

from .worker import Worker
Expand All @@ -18,27 +16,67 @@
class ConnectionManager:
SHUTDOWN_DELAY = 0.5 # allow extra time before closing socket if needed

__slots__ = ("_job", "_listener", "_socket", "_workers")
__slots__ = (
"_deadline",
"_deadline_exceeded",
"_job",
"_limit",
"_next_poll",
"_poll",
"_socket",
)

def __init__(self, job, sock, max_workers=1):
assert max_workers > 0
def __init__(self, job, srv_socket, limit=1, poll=0.5):
assert limit > 0
assert poll > 0
self._deadline = None
self._deadline_exceeded = False
self._job = job
self._listener = None
self._socket = sock
self._workers = max_workers
self._limit = limit
self._next_poll = 0
self._poll = poll
self._socket = srv_socket

def __enter__(self):
self.start()
return self

def __exit__(self, *exc):
self.close()

def _can_continue(self, continue_cb):
    """Check the continue callback and the serve deadline.

    Checks are rate limited by the poll interval: if the next poll time has
    not been reached, True is returned without evaluating the callback or
    the deadline.

    Args:
        continue_cb: Callable that indicates whether to continue, or None.

    Returns:
        bool: False if continue_cb() returns a false value or the deadline has
        been reached, otherwise True.
    """
    now = time()
    # not time to poll yet, skip the checks
    if self._next_poll > now:
        return True
    self._next_poll = now + self._poll
    # check if callback returns False
    if continue_cb is not None and not continue_cb():
        LOG.debug("continue_cb() returned False")
        return False
    # check for a timeout (a deadline of None means there is no timeout)
    if self._deadline is not None and self._deadline <= now:
        LOG.debug("exceeded serve deadline")
        # record the timeout so the caller can tell it apart from other exits
        self._deadline_exceeded = True
        return False
    return True

def close(self):
"""Set job state to finished and raise any errors encountered by workers.
Args:
None
Returns:
None
"""
self._job.finish()
if self._listener is not None:
self._listener.join()
self._listener = None
if not self._job.exceptions.empty():
exc_type, exc_obj, exc_tb = self._job.exceptions.get()
LOG.error(
Expand All @@ -48,108 +86,96 @@ def close(self):
# re-raise exception from worker once all workers are closed
raise exc_obj

def start(self):
@staticmethod
def _join_workers(workers, timeout=0):
    """Attempt to join workers.

    Args:
        workers: Collection of workers.
        timeout: Maximum total time in seconds to wait for all workers.

    Returns:
        list: Workers that do not join before the timeout is reached.
    """
    assert timeout >= 0
    # one deadline is shared by all workers so the total wait is bounded by
    # timeout; once it has passed the remaining workers are joined with 0
    deadline = time() + timeout
    return [w for w in workers if not w.join(timeout=max(deadline - time(), 0))]

def serve(self, timeout, continue_cb=None, shutdown_delay=SHUTDOWN_DELAY):
"""Manage workers and serve job contents.
Args:
timeout: Maximum time to serve in seconds.
continue_cb: Indicates whether to continue.
shutdown_delay: Time in seconds to wait before calling shutdown on
sockets of active workers.
Returns:
bool: True unless the timeout is exceeded.
"""
assert self._job.pending
# create the listener thread to handle incoming requests
listener = Thread(
target=self.listener,
args=(self._socket, self._job, self._workers),
kwargs={"shutdown_delay": self.SHUTDOWN_DELAY},
)
# launch listener thread and handle thread errors
for retry in reversed(range(10)):
try:
listener.start()
except ThreadError:
# thread errors can be due to low system resources while fuzzing
LOG.warning("ThreadError (listener), threads: %d", active_count())
if retry < 1:
raise
sleep(1)
continue
self._listener = listener
break

def wait(self, timeout, continue_cb=None, poll=0.5):
assert self._listener is not None
if timeout > 0:
deadline = time() + timeout
else:
deadline = None
assert self._socket.gettimeout() is not None
assert shutdown_delay >= 0
if continue_cb is not None and not callable(continue_cb):
raise TypeError("continue_cb must be callable")
# it is important to keep this loop fast because it can limit
# the total iteration rate of Grizzly
while not self._job.is_complete(wait=poll):
# check if callback returns False
if continue_cb is not None and not continue_cb():
LOG.debug("continue_cb() returned False")
break
# check for a timeout
if deadline and deadline <= time():
return False
return True

@staticmethod
def listener(serv_sock, serv_job, max_workers, shutdown_delay=0):
assert max_workers > 0
assert shutdown_delay >= 0
self._deadline_exceeded = False
start_time = time()
if not timeout:
self._deadline = None
else:
assert timeout > 0
self._deadline = start_time + timeout

launches = 0
running = 0
workers = []
start_time = time()
LOG.debug("starting listener (max workers %d)", max_workers)
LOG.debug(
"accepting requests (worker limit: %d, timeout: %r)", self._limit, timeout
)
try:
while not serv_job.is_complete():
if not serv_job.accepting.wait(0.05):
continue
worker = Worker.launch(serv_sock, serv_job)
if worker is not None:
workers.append(worker)
launches += 1
while not self._job.is_complete() and self._can_continue(continue_cb):
# launch workers
if running < self._limit:
if not self._job.accepting.wait(0.05):
# wait for accepting flag to be set
continue
worker = Worker.launch(self._socket, self._job)
if worker is not None:
workers.append(worker)
running = len(workers)
launches += 1

# manage workers
if len(workers) >= max_workers:
LOG.debug("max worker limit (%d) hit, waiting...", len(workers))
assert serv_job.worker_complete.wait(300)
serv_job.worker_complete.clear()
LOG.debug("removing completed workers...")
# sometimes the thread that triggered the event doesn't quite
# cleanup in time, so retry (10x with 0.5 second sleep on failure)
for _ in range(10):
workers = list(w for w in workers if not w.join(timeout=0))
if len(workers) < max_workers:
break
sleep(0.5) # pragma: no cover
else: # pragma: no cover
# this should never happen
LOG.error("Failed to remove workers")
raise RuntimeError("Failed to remove workers")
LOG.debug("removed completed workers (%d active)", len(workers))

except Exception: # pylint: disable=broad-except
if serv_job.exceptions.empty():
serv_job.exceptions.put(exc_info())
serv_job.finish()
if running >= self._limit:
LOG.debug("worker limit (%d) hit, waiting...", len(workers))
if self._job.worker_complete.wait(1):
self._job.worker_complete.clear()
workers = self._join_workers(workers)
running = len(workers)
LOG.debug("removed completed workers (%d active)", running)

finally:
LOG.debug("%d requests in %0.3f seconds", launches, time() - start_time)
LOG.debug("shutting down, waiting for %d worker(s)...", len(workers))
if not self._job.is_complete():
LOG.debug("job was incomplete")
self._job.finish()
# use shutdown_delay to avoid cutting off connections
deadline = time() + shutdown_delay
# wait for all running workers to exit
while time() < deadline:
serv_job.worker_complete.clear()
workers = list(w for w in workers if not w.join(timeout=0))
if not workers:
break
serv_job.worker_complete.wait(max(deadline - time(), 0))
workers = self._join_workers(workers, timeout=shutdown_delay)
# close remaining active workers
if workers:
LOG.debug("closing remaining active workers: %d", len(workers))
for worker in workers:
worker.close()
# join remaining workers
deadline = time() + 30
for worker in workers:
worker.join(timeout=max(deadline - time(), 0))
if not all(w.join(timeout=0) for w in workers): # pragma: no cover
if self._join_workers(workers, timeout=30):
LOG.error("Failed to close workers")
raise RuntimeError("Failed to close workers")

# return False only if there was a timeout
return not self._deadline_exceeded
4 changes: 2 additions & 2 deletions sapphire/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def serve_path(
job.finish()
LOG.debug("nothing to serve")
return (Served.NONE, tuple())
with ConnectionManager(job, self._socket, self._max_workers) as loadmgr:
was_timeout = not loadmgr.wait(self.timeout, continue_cb=continue_cb)
with ConnectionManager(job, self._socket, limit=self._max_workers) as mgr:
was_timeout = not mgr.serve(self.timeout, continue_cb=continue_cb)
LOG.debug("%s, timeout: %r", job.status, was_timeout)
return (Served.TIMEOUT if was_timeout else job.status, tuple(job.served))

Expand Down
Loading

0 comments on commit 6e48e73

Please sign in to comment.