Skip to content

Commit

Permalink
[sapphire] Add socket timeout for worker
Browse files Browse the repository at this point in the history
Also improve the worker thread cleanup code.
  • Loading branch information
tysmith committed Sep 26, 2023
1 parent 6f89c25 commit 7c9c768
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 62 deletions.
23 changes: 16 additions & 7 deletions sapphire/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,16 @@ def listener(serv_sock, serv_job, max_workers, shutdown_delay=0):
# sometimes the thread that triggered the event doesn't quite
# cleanup in time, so retry (10x with 0.5 second sleep on failure)
for _ in range(10):
workers = list(w for w in workers if not w.done)
workers = list(w for w in workers if not w.join(timeout=0))
if len(workers) < max_workers:
break
sleep(0.5) # pragma: no cover
else: # pragma: no cover
# this should never happen
raise RuntimeError("Failed remove workers!")
LOG.error("Failed to remove workers")
raise RuntimeError("Failed to remove workers")
LOG.debug("removed completed workers (%d active)", len(workers))

except Exception: # pylint: disable=broad-except
if serv_job.exceptions.empty():
serv_job.exceptions.put(exc_info())
Expand All @@ -135,12 +137,19 @@ def listener(serv_sock, serv_job, max_workers, shutdown_delay=0):
# wait for all running workers to exit
while time() < deadline:
serv_job.worker_complete.clear()
if all(w.done for w in workers):
workers = list(w for w in workers if not w.join(timeout=0))
if not workers:
break
serv_job.worker_complete.wait(max(deadline - time(), 0))
else: # pragma: no cover
# reached deadline force close workers
workers = list(w for w in workers if not w.done)
LOG.debug("closing remaining %d worker(s)", len(workers))
# close remaining active workers
if workers:
LOG.debug("closing remaining active workers: %d", len(workers))
for worker in workers:
worker.close()
# join remaining workers
deadline = time() + 30
for worker in workers:
worker.join(timeout=max(deadline - time(), 0))
if not all(w.join(timeout=0) for w in workers): # pragma: no cover
LOG.error("Failed to close workers")
raise RuntimeError("Failed to close workers")
70 changes: 32 additions & 38 deletions sapphire/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,38 @@
# pylint: disable=protected-access

import socket
import threading
from threading import Thread, ThreadError

from pytest import mark, raises
from pytest import mark

from .job import Job
from .worker import Request, Worker, WorkerError
from .worker import Request, Worker


def test_worker_01(mocker):
"""test simple Worker in running state"""
wthread = mocker.Mock(spec_set=threading.Thread)
wthread.is_alive.return_value = True
worker = Worker(mocker.Mock(spec_set=socket.socket), wthread)
assert worker._conn is not None
assert worker._thread is not None
"""test a Worker"""
wthread = mocker.Mock(spec_set=Thread)
wsocket = mocker.Mock(spec_set=socket.socket)
worker = Worker(wsocket, wthread)
# it is assumed that launch() has already been called at this point
assert not worker.done
assert wthread.join.call_count == 0
assert worker.is_alive()
assert wthread.is_alive.call_count == 1
worker.join(timeout=0)
assert not worker.join(timeout=0)
assert wthread.join.call_count == 1
assert wthread.is_alive.call_count == 2
assert worker._conn.close.call_count == 0
# simulator closing a worker that is alive
# have shutdown raise OSError for coverage
wsocket.shutdown.side_effect = (OSError("test"),)
worker.close()
assert wsocket.shutdown.call_count == 1
assert wsocket.close.call_count == 1
# at this point the worker should be complete
wthread.is_alive.return_value = False
assert not worker.is_alive()
assert worker.join(timeout=0)
# calling a close when the worker is not alive should do nothing
worker.close()
assert worker._conn.close.call_count == 1
assert worker._thread is None
assert worker.done


def test_worker_02(mocker):
"""test simple Worker fails to close"""
worker = Worker(
mocker.Mock(spec_set=socket.socket), mocker.Mock(spec_set=threading.Thread)
)
# it is assumed that launch() has already been called at this point
worker._thread.is_alive.return_value = True
with raises(WorkerError, match="Worker thread failed to join!"):
worker.close()
assert wsocket.shutdown.call_count == 1
assert wsocket.close.call_count == 1


@mark.parametrize(
Expand All @@ -52,7 +45,7 @@ def test_worker_02(mocker):
OSError("test"),
],
)
def test_worker_03(mocker, exc):
def test_worker_02(mocker, exc):
"""test Worker.launch() socket exception cases"""
mocker.patch("sapphire.worker.Thread", autospec=True)
serv_con = mocker.Mock(spec_set=socket.socket)
Expand All @@ -63,10 +56,10 @@ def test_worker_03(mocker, exc):
assert serv_job.accepting.set.call_count == 0


def test_worker_04(mocker):
def test_worker_03(mocker):
"""test Worker.launch() thread exception case"""
mocker.patch("sapphire.worker.sleep", autospec=True)
mocker.patch("sapphire.worker.Thread", side_effect=threading.ThreadError("test"))
mocker.patch("sapphire.worker.Thread", side_effect=ThreadError("test"))
serv_con = mocker.Mock(spec_set=socket.socket)
serv_job = mocker.Mock(spec_set=Job)
conn = mocker.Mock(spec_set=socket.socket)
Expand All @@ -87,7 +80,7 @@ def test_worker_04(mocker):
"http://sub.host:1234/testfile",
],
)
def test_worker_05(mocker, tmp_path, url):
def test_worker_04(mocker, tmp_path, url):
"""test Worker.launch()"""
(tmp_path / "testfile").touch()
job = Job(tmp_path)
Expand All @@ -103,10 +96,10 @@ def test_worker_05(mocker, tmp_path, url):
worker.close()
if not job.exceptions.empty():
raise job.exceptions.get()[1]
assert worker.done
assert worker.join(timeout=10)
assert clnt_sock.sendall.called
assert serv_sock.accept.call_count == 1
assert clnt_sock.close.call_count == 2
assert clnt_sock.close.call_count == 1


@mark.parametrize(
Expand All @@ -116,7 +109,7 @@ def test_worker_05(mocker, tmp_path, url):
(b"BAD / HTTP/1.1", b"405 Method Not Allowed"),
],
)
def test_worker_06(mocker, tmp_path, req, response):
def test_worker_05(mocker, tmp_path, req, response):
"""test Worker.launch() with invalid/unsupported requests"""
(tmp_path / "testfile").touch()
job = Job(tmp_path)
Expand All @@ -126,24 +119,25 @@ def test_worker_06(mocker, tmp_path, req, response):
serv_sock.accept.return_value = (clnt_sock, None)
worker = Worker.launch(serv_sock, job)
assert worker is not None
worker.join(timeout=1)
assert worker.join(timeout=10)
worker.close()
if not job.exceptions.empty():
raise job.exceptions.get()[1]
assert serv_sock.accept.call_count == 1
assert clnt_sock.close.call_count == 2
assert clnt_sock.close.call_count == 1
assert clnt_sock.sendall.called
assert response in clnt_sock.sendall.call_args[0][0]


def test_worker_07(mocker):
def test_worker_06(mocker):
"""test Worker.handle_request() socket errors"""
serv_con = mocker.Mock(spec_set=socket.socket)
serv_con.recv.side_effect = OSError
serv_job = mocker.Mock(spec_set=Job)
Worker.handle_request(serv_con, serv_job)
assert serv_job.accepting.set.call_count == 1
assert serv_con.sendall.call_count == 0
assert serv_con.shutdown.call_count == 0
assert serv_con.close.call_count == 1


Expand Down
36 changes: 19 additions & 17 deletions sapphire/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,25 @@ def _4xx_page(code, hdr_msg, close=-1, encoding="ascii"):
return data.encode(encoding)

def close(self):
if not self.done:
LOG.debug("closing socket while thread is running!")
# workers that are no longer running will have had close() called
if self.is_alive():
# shutdown socket to avoid hang
self._conn.shutdown(SHUT_RDWR)
self._conn.close()
self.join(timeout=60)
if self._thread is not None and self._thread.is_alive():
# this is here to catch unexpected hangs
raise WorkerError("Worker thread failed to join!")
LOG.debug("closing socket while thread is running!")
try:
self._conn.shutdown(SHUT_RDWR)
except OSError as exc:
LOG.debug("close - shutdown(): %s", exc)
self._conn.close()

@property
def done(self):
if self._thread is not None and not self._thread.is_alive():
self.join()
self._thread = None
return self._thread is None
def is_alive(self):
return self._thread is not None and self._thread.is_alive()

@classmethod
def handle_request(cls, conn, serv_job):
finish_job = False # call finish() on return
try:
# socket operations should not block forever
assert conn.gettimeout() is not None
# receive incoming request data
raw_request = conn.recv(cls.DEFAULT_REQUEST_LIMIT)
if not raw_request:
Expand Down Expand Up @@ -244,6 +242,7 @@ def handle_request(cls, conn, serv_job):
serv_job.accepting.set()

except Exception: # pylint: disable=broad-except
LOG.debug("worker thread exception")
# set finish_job to abort immediately
finish_job = True
if serv_job.exceptions.empty():
Expand All @@ -255,19 +254,22 @@ def handle_request(cls, conn, serv_job):
serv_job.finish()
serv_job.worker_complete.set()

def join(self, timeout=None):
def join(self, timeout=30):
assert timeout >= 0
if self._thread is not None:
self._thread.join(timeout=timeout)
if not self._thread.is_alive():
self._thread = None
return self._thread is None

@classmethod
def launch(cls, listen_sock, job):
def launch(cls, listen_sock, job, timeout=30):
assert timeout >= 0
assert job.accepting.is_set()
conn = None
try:
conn, _ = listen_sock.accept()
conn.settimeout(None)
conn.settimeout(timeout)
# create a worker thread to handle client request
w_thread = Thread(target=cls.handle_request, args=(conn, job))
job.accepting.clear()
Expand Down

0 comments on commit 7c9c768

Please sign in to comment.