diff --git a/sapphire/connection_manager.py b/sapphire/connection_manager.py
index 92b1de96..3f053a4b 100644
--- a/sapphire/connection_manager.py
+++ b/sapphire/connection_manager.py
@@ -2,9 +2,7 @@
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 from logging import getLogger
-from sys import exc_info
-from threading import Thread, ThreadError, active_count
-from time import sleep, time
+from time import time
 from traceback import format_exception
 
 from .worker import Worker
@@ -18,27 +16,67 @@ class ConnectionManager:
 
     SHUTDOWN_DELAY = 0.5  # allow extra time before closing socket if needed
 
-    __slots__ = ("_job", "_listener", "_socket", "_workers")
+    __slots__ = (
+        "_deadline",
+        "_deadline_exceeded",
+        "_job",
+        "_limit",
+        "_next_poll",
+        "_poll",
+        "_socket",
+    )
 
-    def __init__(self, job, sock, max_workers=1):
-        assert max_workers > 0
+    def __init__(self, job, srv_socket, limit=1, poll=0.5):
+        assert limit > 0
+        assert poll > 0
+        self._deadline = None
+        self._deadline_exceeded = False
         self._job = job
-        self._listener = None
-        self._socket = sock
-        self._workers = max_workers
+        self._limit = limit
+        self._next_poll = 0
+        self._poll = poll
+        self._socket = srv_socket
 
     def __enter__(self):
-        self.start()
         return self
 
     def __exit__(self, *exc):
         self.close()
 
+    def _can_continue(self, continue_cb):
+        """Check timeout and callback status.
+
+        Args:
+            continue_cb: Callback that returns False to indicate serving should stop.
+
+        Returns:
+            bool: True if the callback returns True and the timeout has not been
+                hit, otherwise False.
+        """
+        now = time()
+        if self._next_poll > now:
+            return True
+        self._next_poll = now + self._poll
+        # check if callback returns False
+        if continue_cb is not None and not continue_cb():
+            LOG.debug("continue_cb() returned False")
+            return False
+        # check for a timeout
+        if self._deadline and self._deadline <= now:
+            LOG.debug("exceeded serve deadline")
+            self._deadline_exceeded = True
+            return False
+        return True
+
     def close(self):
+        """Set job state to finished and raise any errors encountered by workers.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
         self._job.finish()
-        if self._listener is not None:
-            self._listener.join()
-            self._listener = None
         if not self._job.exceptions.empty():
             exc_type, exc_obj, exc_tb = self._job.exceptions.get()
             LOG.error(
@@ -48,108 +86,96 @@ def close(self):
             # re-raise exception from worker once all workers are closed
             raise exc_obj
 
-    def start(self):
+    @staticmethod
+    def _join_workers(workers, timeout=0):
+        """Attempt to join workers.
+
+        Args:
+            workers: Collection of workers.
+            timeout: Maximum time in seconds to wait.
+
+        Returns:
+            list: Workers that do not join before the timeout is reached.
+        """
+        assert timeout >= 0
+        alive = []
+        deadline = time() + timeout
+        for worker in workers:
+            if not worker.join(timeout=max(deadline - time(), 0)):
+                alive.append(worker)
+        return alive
+
+    def serve(self, timeout, continue_cb=None, shutdown_delay=SHUTDOWN_DELAY):
+        """Manage workers and serve job contents.
+
+        Args:
+            timeout: Maximum time to serve in seconds.
+            continue_cb: Callback that returns False to indicate serving should stop.
+            shutdown_delay: Time in seconds to wait before calling shutdown on
+                sockets of active workers.
+
+        Returns:
+            bool: True unless the timeout is exceeded.
+        """
         assert self._job.pending
-        # create the listener thread to handle incoming requests
-        listener = Thread(
-            target=self.listener,
-            args=(self._socket, self._job, self._workers),
-            kwargs={"shutdown_delay": self.SHUTDOWN_DELAY},
-        )
-        # launch listener thread and handle thread errors
-        for retry in reversed(range(10)):
-            try:
-                listener.start()
-            except ThreadError:
-                # thread errors can be due to low system resources while fuzzing
-                LOG.warning("ThreadError (listener), threads: %d", active_count())
-                if retry < 1:
-                    raise
-                sleep(1)
-                continue
-            self._listener = listener
-            break
-
-    def wait(self, timeout, continue_cb=None, poll=0.5):
-        assert self._listener is not None
-        if timeout > 0:
-            deadline = time() + timeout
-        else:
-            deadline = None
+        assert self._socket.gettimeout() is not None
+        assert shutdown_delay >= 0
         if continue_cb is not None and not callable(continue_cb):
             raise TypeError("continue_cb must be callable")
-        # it is important to keep this loop fast because it can limit
-        # the total iteration rate of Grizzly
-        while not self._job.is_complete(wait=poll):
-            # check if callback returns False
-            if continue_cb is not None and not continue_cb():
-                LOG.debug("continue_cb() returned False")
-                break
-            # check for a timeout
-            if deadline and deadline <= time():
-                return False
-        return True
-
-    @staticmethod
-    def listener(serv_sock, serv_job, max_workers, shutdown_delay=0):
-        assert max_workers > 0
-        assert shutdown_delay >= 0
+        self._deadline_exceeded = False
+        start_time = time()
+        if not timeout:
+            self._deadline = None
+        else:
+            assert timeout > 0
+            self._deadline = start_time + timeout
+        launches = 0
+        running = 0
         workers = []
-        start_time = time()
-        LOG.debug("starting listener (max workers %d)", max_workers)
+        LOG.debug(
+            "accepting requests (worker limit: %d, timeout: %r)", self._limit, timeout
+        )
         try:
-            while not serv_job.is_complete():
-                if not serv_job.accepting.wait(0.05):
-                    continue
-                worker = Worker.launch(serv_sock, serv_job)
-                if worker is not None:
-                    workers.append(worker)
-                    launches += 1
+            while not self._job.is_complete() and self._can_continue(continue_cb):
+                # launch workers
+                if running < self._limit:
+                    if not self._job.accepting.wait(0.05):
+                        # wait for accepting flag to be set
+                        continue
+                    worker = Worker.launch(self._socket, self._job)
+                    if worker is not None:
+                        workers.append(worker)
+                        running = len(workers)
+                        launches += 1
+                # manage workers
-                if len(workers) >= max_workers:
-                    LOG.debug("max worker limit (%d) hit, waiting...", len(workers))
-                    assert serv_job.worker_complete.wait(300)
-                    serv_job.worker_complete.clear()
-                    LOG.debug("removing completed workers...")
-                    # sometimes the thread that triggered the event doesn't quite
-                    # cleanup in time, so retry (10x with 0.5 second sleep on failure)
-                    for _ in range(10):
-                        workers = list(w for w in workers if not w.join(timeout=0))
-                        if len(workers) < max_workers:
-                            break
-                        sleep(0.5)  # pragma: no cover
-                    else:  # pragma: no cover
-                        # this should never happen
-                        LOG.error("Failed to remove workers")
-                        raise RuntimeError("Failed to remove workers")
-                    LOG.debug("removed completed workers (%d active)", len(workers))
-
-        except Exception:  # pylint: disable=broad-except
-            if serv_job.exceptions.empty():
-                serv_job.exceptions.put(exc_info())
-            serv_job.finish()
+                if running >= self._limit:
+                    LOG.debug("worker limit (%d) hit, waiting...", len(workers))
+                    if self._job.worker_complete.wait(1):
+                        self._job.worker_complete.clear()
+                    workers = self._join_workers(workers)
+                    running = len(workers)
+                    LOG.debug("removed completed workers (%d active)", running)
+
         finally:
             LOG.debug("%d requests in %0.3f seconds", launches, time() - start_time)
             LOG.debug("shutting down, waiting for %d worker(s)...", len(workers))
+            if not self._job.is_complete():
+                LOG.debug("job was incomplete")
+                self._job.finish()
             # use shutdown_delay to avoid cutting off connections
-            deadline = time() + shutdown_delay
-            # wait for all running workers to exit
-            while time() < deadline:
-                serv_job.worker_complete.clear()
-                workers = list(w for w in workers if not w.join(timeout=0))
-                if not workers:
-                    break
-                serv_job.worker_complete.wait(max(deadline - time(), 0))
+            workers = self._join_workers(workers, timeout=shutdown_delay)
             # close remaining active workers
             if workers:
                 LOG.debug("closing remaining active workers: %d", len(workers))
                 for worker in workers:
                     worker.close()
                 # join remaining workers
-                deadline = time() + 30
-                for worker in workers:
-                    worker.join(timeout=max(deadline - time(), 0))
-                if not all(w.join(timeout=0) for w in workers):  # pragma: no cover
+                if self._join_workers(workers, timeout=30):
                     LOG.error("Failed to close workers")
                     raise RuntimeError("Failed to close workers")
+
+        # return False only if there was a timeout
+        return not self._deadline_exceeded
diff --git a/sapphire/core.py b/sapphire/core.py
index 675f657f..872c0a48 100644
--- a/sapphire/core.py
+++ b/sapphire/core.py
@@ -222,8 +222,8 @@ def serve_path(
             job.finish()
             LOG.debug("nothing to serve")
             return (Served.NONE, tuple())
-        with ConnectionManager(job, self._socket, self._max_workers) as loadmgr:
-            was_timeout = not loadmgr.wait(self.timeout, continue_cb=continue_cb)
+        with ConnectionManager(job, self._socket, limit=self._max_workers) as mgr:
+            was_timeout = not mgr.serve(self.timeout, continue_cb=continue_cb)
         LOG.debug("%s, timeout: %r", job.status, was_timeout)
         return (Served.TIMEOUT if was_timeout else job.status, tuple(job.served))
diff --git a/sapphire/test_connection_manager.py b/sapphire/test_connection_manager.py
index b3080b72..82c4aff7 100644
--- a/sapphire/test_connection_manager.py
+++ b/sapphire/test_connection_manager.py
@@ -5,15 +5,16 @@
 from itertools import count
 from socket import socket
-from threading import ThreadError
 
-from pytest import raises
+from pytest import mark, raises
 
 from .connection_manager import ConnectionManager
 from .job import Job
+from .worker import Worker
 
 
-def test_connection_manager_01(mocker, tmp_path):
+@mark.parametrize("timeout", [10, 0])
+def test_connection_manager_01(mocker, tmp_path, timeout):
     """test basic ConnectionManager"""
     (tmp_path / "testfile").write_bytes(b"test")
     job = Job(tmp_path)
@@ -22,29 +23,16 @@
     serv_sock = mocker.Mock(spec_set=socket)
     serv_sock.accept.return_value = (clnt_sock, None)
     assert not job.is_complete()
-    with ConnectionManager(job, serv_sock) as loadmgr:
-        assert loadmgr.wait(1)
+    with ConnectionManager(job, serv_sock) as mgr:
+        assert mgr.serve(timeout)
     assert clnt_sock.close.call_count == 1
     assert job.is_complete()
     assert not job.accepting.is_set()
     assert job.exceptions.empty()
 
 
-def test_connection_manager_02(mocker):
-    """test ConnectionManager.start() failure"""
-    mocker.patch("sapphire.connection_manager.sleep", autospec=True)
-    fake_thread = mocker.patch("sapphire.connection_manager.Thread", autospec=True)
-    fake_thread.return_value.start.side_effect = ThreadError
-    job = mocker.Mock(spec_set=Job)
-    job.pending = True
-    loadmgr = ConnectionManager(job, None)
-    with raises(ThreadError):
-        loadmgr.start()
-    loadmgr.close()
-    assert job.is_complete()
-
-
-def test_connection_manager_03(mocker, tmp_path):
+@mark.parametrize("worker_limit", [1, 2, 10])
+def test_connection_manager_02(mocker, tmp_path, worker_limit):
     """test ConnectionManager multiple files and requests"""
     (tmp_path / "test1").touch()
     (tmp_path / "test2").touch()
@@ -64,58 +52,89 @@
     serv_sock = mocker.Mock(spec_set=socket)
     serv_sock.accept.return_value = (clnt_sock, None)
     assert not job.is_complete()
-    with ConnectionManager(job, serv_sock, max_workers=2) as loadmgr:
-        assert loadmgr.wait(1)
+    with ConnectionManager(job, serv_sock, limit=worker_limit) as mgr:
+        assert mgr.serve(10)
     assert clnt_sock.close.call_count == 8
     assert job.is_complete()
 
 
-def test_connection_manager_04(mocker, tmp_path):
-    """test ConnectionManager.wait()"""
+def test_connection_manager_03(mocker, tmp_path):
+    """test ConnectionManager re-raise worker exceptions"""
     (tmp_path / "test1").touch()
     job = Job(tmp_path)
     clnt_sock = mocker.Mock(spec_set=socket)
-    clnt_sock.recv.return_value = b""
+    clnt_sock.recv.side_effect = Exception("worker exception")
     serv_sock = mocker.Mock(spec_set=socket)
     serv_sock.accept.return_value = (clnt_sock, None)
-    with ConnectionManager(job, serv_sock, max_workers=10) as loadmgr:
+    with raises(Exception, match="worker exception"):
+        with ConnectionManager(job, serv_sock) as mgr:
+            mgr.serve(10)
+    assert clnt_sock.close.call_count == 1
+    assert job.is_complete()
+    assert job.exceptions.empty()
+
+
+def test_connection_manager_04(mocker, tmp_path):
+    """test ConnectionManager.serve() with callback"""
+    (tmp_path / "test1").touch()
+    job = Job(tmp_path)
+    with ConnectionManager(job, mocker.Mock(spec_set=socket), poll=0.01) as mgr:
         # invalid callback
         with raises(TypeError, match="continue_cb must be callable"):
-            loadmgr.wait(0, continue_cb="test")
+            mgr.serve(10, continue_cb="test")
+        # job did not start
+        assert not job.is_complete()
         # callback abort
-        assert loadmgr.wait(1, continue_cb=lambda: False, poll=0.01)
-    # timeout
-    job = Job(tmp_path)
-    fake_time = mocker.patch("sapphire.connection_manager.time", autospec=True)
-    fake_time.side_effect = count()
-    with ConnectionManager(job, serv_sock, max_workers=10) as loadmgr:
-        assert not loadmgr.wait(1, continue_cb=lambda: True, poll=0.01)
+        assert mgr.serve(10, continue_cb=lambda: False)
+    assert job.is_complete()
 
 
 def test_connection_manager_05(mocker, tmp_path):
-    """test ConnectionManager re-raise worker exceptions"""
+    """test ConnectionManager.serve() with timeout"""
+    mocker.patch("sapphire.connection_manager.time", autospec=True, side_effect=count())
     (tmp_path / "test1").touch()
     job = Job(tmp_path)
     clnt_sock = mocker.Mock(spec_set=socket)
-    clnt_sock.recv.side_effect = Exception("worker exception")
+    clnt_sock.recv.return_value = b""
     serv_sock = mocker.Mock(spec_set=socket)
     serv_sock.accept.return_value = (clnt_sock, None)
-    with raises(Exception, match="worker exception"):
-        with ConnectionManager(job, serv_sock) as loadmgr:
-            loadmgr.wait(1)
-    assert clnt_sock.close.call_count == 1
-    assert job.is_complete()
-    assert job.exceptions.empty()
+    with ConnectionManager(job, serv_sock, poll=0.01) as mgr:
+        assert not mgr.serve(10)
+    assert job.is_complete()
 
 
 def test_connection_manager_06(mocker, tmp_path):
-    """test ConnectionManager re-raise launcher exceptions"""
+    """test ConnectionManager.serve() worker fails to exit"""
+    mocker.patch("sapphire.worker.Thread", autospec=True)
+    mocker.patch("sapphire.connection_manager.time", autospec=True, side_effect=count())
     (tmp_path / "test1").touch()
-    job = Job(tmp_path)
+    clnt_sock = mocker.Mock(spec_set=socket)
     serv_sock = mocker.Mock(spec_set=socket)
-    serv_sock.accept.side_effect = Exception("launcher exception")
-    with raises(Exception, match="launcher exception"):
-        with ConnectionManager(job, serv_sock) as loadmgr:
-            loadmgr.wait(1)
-    assert job.is_complete()
-    assert job.exceptions.empty()
+    serv_sock.accept.return_value = (clnt_sock, None)
+    job = Job(tmp_path)
+    mocker.patch.object(job, "worker_complete")
+    with ConnectionManager(job, serv_sock) as mgr:
+        with raises(RuntimeError, match="Failed to close workers"):
+            mgr.serve(10)
+    assert job.is_complete()
+    assert clnt_sock.close.call_count == 1
+
+
+def test_connection_manager_07(mocker):
+    """test ConnectionManager._join_workers()"""
+    # no workers
+    assert not ConnectionManager._join_workers([])
+    # worker fails to join, without timeout
+    fake_worker = mocker.Mock(spec_set=Worker)
+    fake_worker.join.return_value = False
+    assert ConnectionManager._join_workers([fake_worker], timeout=0)
+    assert fake_worker.join.call_count == 1
+    fake_worker.reset_mock()
+    # worker fails to join, with timeout
+    assert ConnectionManager._join_workers([fake_worker], timeout=1)
+    assert fake_worker.join.call_count == 1
+    fake_worker.reset_mock()
+    # worker joins
+    fake_worker.join.return_value = True
+    assert not ConnectionManager._join_workers([fake_worker], timeout=0)
+    assert fake_worker.join.call_count == 1
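
Usage note (reviewer aid, not part of the patch): the rework replaces the listener thread plus start()/wait() with a single synchronous serve() call that launches and joins Worker threads itself. Below is a minimal sketch of the new API, mirroring the serve_path() call site in sapphire/core.py; the socket setup, port, and "wwwroot" path are illustrative assumptions, not code from this change.

    from pathlib import Path
    from socket import create_server

    from sapphire.connection_manager import ConnectionManager
    from sapphire.job import Job

    # serve() asserts that the server socket has a timeout set
    srv_sock = create_server(("127.0.0.1", 8080))
    srv_sock.settimeout(1)
    job = Job(Path("wwwroot"))  # directory containing files to serve
    with ConnectionManager(job, srv_sock, limit=2, poll=0.5) as mgr:
        # serve() returns False only when the 30 second timeout is exceeded;
        # continue_cb is polled (every `poll` seconds) and can abort serving early
        was_timeout = not mgr.serve(30, continue_cb=lambda: True)
    srv_sock.close()

Exiting the with block calls close(), which marks the job finished and re-raises any exception collected from a worker thread.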