Skip to content

Commit

Permalink
[sapphire] Add socket timeout for worker
Browse files Browse the repository at this point in the history
Also improve the worker thread cleanup code.
  • Loading branch information
tysmith committed Sep 26, 2023
1 parent 6f89c25 commit f6d45bd
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 59 deletions.
23 changes: 16 additions & 7 deletions sapphire/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,16 @@ def listener(serv_sock, serv_job, max_workers, shutdown_delay=0):
# sometimes the thread that triggered the event doesn't quite
# clean up in time, so retry (10x with 0.5 second sleep on failure)
for _ in range(10):
workers = list(w for w in workers if not w.done)
workers = list(w for w in workers if not w.join(timeout=0))
if len(workers) < max_workers:
break
sleep(0.5) # pragma: no cover
else: # pragma: no cover
# this should never happen
raise RuntimeError("Failed remove workers!")
LOG.error("Failed to remove workers")
raise RuntimeError("Failed to remove workers")
LOG.debug("removed completed workers (%d active)", len(workers))

except Exception: # pylint: disable=broad-except
if serv_job.exceptions.empty():
serv_job.exceptions.put(exc_info())
Expand All @@ -135,12 +137,19 @@ def listener(serv_sock, serv_job, max_workers, shutdown_delay=0):
# wait for all running workers to exit
while time() < deadline:
serv_job.worker_complete.clear()
if all(w.done for w in workers):
workers = list(w for w in workers if not w.join(timeout=0))
if not workers:
break
serv_job.worker_complete.wait(max(deadline - time(), 0))
else: # pragma: no cover
# reached deadline force close workers
workers = list(w for w in workers if not w.done)
LOG.debug("closing remaining %d worker(s)", len(workers))
# close remaining active workers
if workers:
LOG.debug("closing remaining active workers: %d", len(workers))
for worker in workers:
worker.close()
# join remaining workers
deadline = time() + 30
for worker in workers:
worker.join(timeout=max(deadline - time(), 0))
if not all(w.join(timeout=0) for w in workers): # pragma: no cover
LOG.error("Failed to close workers")
raise RuntimeError("Failed to close workers")
64 changes: 29 additions & 35 deletions sapphire/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,36 @@
import socket
import threading

from pytest import mark, raises
from pytest import mark

from .job import Job
from .worker import Request, Worker, WorkerError
from .worker import Request, Worker


def test_worker_01(mocker):
"""test simple Worker in running state"""
"""test a Worker"""
wthread = mocker.Mock(spec_set=threading.Thread)
wthread.is_alive.return_value = True
worker = Worker(mocker.Mock(spec_set=socket.socket), wthread)
assert worker._conn is not None
assert worker._thread is not None
wsocket = mocker.Mock(spec_set=socket.socket)
worker = Worker(wsocket, wthread)
# it is assumed that launch() has already been called at this point
assert not worker.done
assert wthread.join.call_count == 0
assert worker.is_alive()
assert wthread.is_alive.call_count == 1
worker.join(timeout=0)
assert not worker.join(timeout=0)
assert wthread.join.call_count == 1
assert wthread.is_alive.call_count == 2
assert worker._conn.close.call_count == 0
# simulate closing a worker that is alive
# have shutdown raise OSError for coverage
wsocket.shutdown.side_effect = (OSError("test"),)
worker.close()
assert wsocket.shutdown.call_count == 1
assert wsocket.close.call_count == 1
# at this point the worker should be complete
wthread.is_alive.return_value = False
assert not worker.is_alive()
assert worker.join(timeout=0)
# calling close() when the worker is not alive should do nothing
worker.close()
assert worker._conn.close.call_count == 1
assert worker._thread is None
assert worker.done


def test_worker_02(mocker):
"""test simple Worker fails to close"""
worker = Worker(
mocker.Mock(spec_set=socket.socket), mocker.Mock(spec_set=threading.Thread)
)
# it is assumed that launch() has already been called at this point
worker._thread.is_alive.return_value = True
with raises(WorkerError, match="Worker thread failed to join!"):
worker.close()
assert wsocket.shutdown.call_count == 1
assert wsocket.close.call_count == 1


@mark.parametrize(
Expand All @@ -52,7 +45,7 @@ def test_worker_02(mocker):
OSError("test"),
],
)
def test_worker_03(mocker, exc):
def test_worker_02(mocker, exc):
"""test Worker.launch() socket exception cases"""
mocker.patch("sapphire.worker.Thread", autospec=True)
serv_con = mocker.Mock(spec_set=socket.socket)
Expand All @@ -63,7 +56,7 @@ def test_worker_03(mocker, exc):
assert serv_job.accepting.set.call_count == 0


def test_worker_04(mocker):
def test_worker_03(mocker):
"""test Worker.launch() thread exception case"""
mocker.patch("sapphire.worker.sleep", autospec=True)
mocker.patch("sapphire.worker.Thread", side_effect=threading.ThreadError("test"))
Expand All @@ -87,7 +80,7 @@ def test_worker_04(mocker):
"http://sub.host:1234/testfile",
],
)
def test_worker_05(mocker, tmp_path, url):
def test_worker_04(mocker, tmp_path, url):
"""test Worker.launch()"""
(tmp_path / "testfile").touch()
job = Job(tmp_path)
Expand All @@ -103,10 +96,10 @@ def test_worker_05(mocker, tmp_path, url):
worker.close()
if not job.exceptions.empty():
raise job.exceptions.get()[1]
assert worker.done
assert worker.join(timeout=0)
assert clnt_sock.sendall.called
assert serv_sock.accept.call_count == 1
assert clnt_sock.close.call_count == 2
assert clnt_sock.close.call_count == 1


@mark.parametrize(
Expand All @@ -116,7 +109,7 @@ def test_worker_05(mocker, tmp_path, url):
(b"BAD / HTTP/1.1", b"405 Method Not Allowed"),
],
)
def test_worker_06(mocker, tmp_path, req, response):
def test_worker_05(mocker, tmp_path, req, response):
"""test Worker.launch() with invalid/unsupported requests"""
(tmp_path / "testfile").touch()
job = Job(tmp_path)
Expand All @@ -126,24 +119,25 @@ def test_worker_06(mocker, tmp_path, req, response):
serv_sock.accept.return_value = (clnt_sock, None)
worker = Worker.launch(serv_sock, job)
assert worker is not None
worker.join(timeout=1)
assert worker.join(timeout=0)
worker.close()
if not job.exceptions.empty():
raise job.exceptions.get()[1]
assert serv_sock.accept.call_count == 1
assert clnt_sock.close.call_count == 2
assert clnt_sock.close.call_count == 1
assert clnt_sock.sendall.called
assert response in clnt_sock.sendall.call_args[0][0]


def test_worker_07(mocker):
def test_worker_06(mocker):
"""test Worker.handle_request() socket errors"""
serv_con = mocker.Mock(spec_set=socket.socket)
serv_con.recv.side_effect = OSError
serv_job = mocker.Mock(spec_set=Job)
Worker.handle_request(serv_con, serv_job)
assert serv_job.accepting.set.call_count == 1
assert serv_con.sendall.call_count == 0
assert serv_con.shutdown.call_count == 0
assert serv_con.close.call_count == 1


Expand Down
36 changes: 19 additions & 17 deletions sapphire/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,25 @@ def _4xx_page(code, hdr_msg, close=-1, encoding="ascii"):
return data.encode(encoding)

def close(self):
if not self.done:
LOG.debug("closing socket while thread is running!")
# workers that are no longer running will have had close() called
if self.is_alive():
# shutdown socket to avoid hang
self._conn.shutdown(SHUT_RDWR)
self._conn.close()
self.join(timeout=60)
if self._thread is not None and self._thread.is_alive():
# this is here to catch unexpected hangs
raise WorkerError("Worker thread failed to join!")
LOG.debug("closing socket while thread is running!")
try:
self._conn.shutdown(SHUT_RDWR)
except OSError as exc:
LOG.debug("close - shutdown(): %s", exc)
self._conn.close()

@property
def done(self):
if self._thread is not None and not self._thread.is_alive():
self.join()
self._thread = None
return self._thread is None
def is_alive(self):
return self._thread is not None and self._thread.is_alive()

@classmethod
def handle_request(cls, conn, serv_job):
finish_job = False # call finish() on return
try:
# socket operations should not block forever
assert conn.gettimeout() is not None
# receive incoming request data
raw_request = conn.recv(cls.DEFAULT_REQUEST_LIMIT)
if not raw_request:
Expand Down Expand Up @@ -244,6 +242,7 @@ def handle_request(cls, conn, serv_job):
serv_job.accepting.set()

except Exception: # pylint: disable=broad-except
LOG.debug("worker thread exception")
# set finish_job to abort immediately
finish_job = True
if serv_job.exceptions.empty():
Expand All @@ -255,19 +254,22 @@ def handle_request(cls, conn, serv_job):
serv_job.finish()
serv_job.worker_complete.set()

def join(self, timeout=None):
def join(self, timeout=30):
assert timeout >= 0
if self._thread is not None:
self._thread.join(timeout=timeout)
if not self._thread.is_alive():
self._thread = None
return self._thread is None

@classmethod
def launch(cls, listen_sock, job):
def launch(cls, listen_sock, job, timeout=30):
assert timeout >= 0
assert job.accepting.is_set()
conn = None
try:
conn, _ = listen_sock.accept()
conn.settimeout(None)
conn.settimeout(timeout)
# create a worker thread to handle client request
w_thread = Thread(target=cls.handle_request, args=(conn, job))
job.accepting.clear()
Expand Down

0 comments on commit f6d45bd

Please sign in to comment.