Skip to content

Commit

Permalink
Merge the gateway handlers into the standard handlers. (#1261)
Browse files Browse the repository at this point in the history
Co-authored-by: Steven Silvester <[email protected]>
  • Loading branch information
ojarjur and blink1073 authored May 10, 2023
1 parent 54d7292 commit 35ffe5d
Show file tree
Hide file tree
Showing 7 changed files with 293 additions and 14 deletions.
6 changes: 6 additions & 0 deletions docs/source/api/jupyter_server.gateway.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Submodules
----------


.. automodule:: jupyter_server.gateway.connections
:members:
:undoc-members:
:show-inheritance:


.. automodule:: jupyter_server.gateway.gateway_client
:members:
:undoc-members:
Expand Down
175 changes: 175 additions & 0 deletions jupyter_server/gateway/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Gateway connection classes."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio
import logging
import random
from typing import Any, cast

import tornado.websocket as tornado_websocket
from tornado.concurrent import Future
from tornado.escape import json_decode, url_escape, utf8
from tornado.httpclient import HTTPRequest
from tornado.ioloop import IOLoop
from traitlets import Bool, Instance, Int

from ..services.kernels.connection.base import BaseKernelWebsocketConnection
from ..utils import url_path_join
from .managers import GatewayClient


class GatewayWebSocketConnection(BaseKernelWebsocketConnection):
"""Web socket connection that proxies to a kernel/enterprise gateway."""

ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True)

ws_future = Instance(default_value=Future(), klass=Future)

disconnected = Bool(False)

retry = Int(0)

async def connect(self):
"""Connect to the socket."""
# websocket is initialized before connection
self.ws = None
ws_url = url_path_join(
GatewayClient.instance().ws_url,
GatewayClient.instance().kernels_endpoint,
url_escape(self.kernel_id),
"channels",
)
self.log.info(f"Connecting to {ws_url}")
kwargs: dict = {}
kwargs = GatewayClient.instance().load_connection_args(**kwargs)

request = HTTPRequest(ws_url, **kwargs)
self.ws_future = cast(Future, tornado_websocket.websocket_connect(request))
self.ws_future.add_done_callback(self._connection_done)

loop = IOLoop.current()
loop.add_future(self.ws_future, lambda future: self._read_messages())

def _connection_done(self, fut):
"""Handle a finished connection."""
if (
not self.disconnected and fut.exception() is None
): # prevent concurrent.futures._base.CancelledError
self.ws = fut.result()
self.retry = 0
self.log.debug(f"Connection is ready: ws: {self.ws}")
else:
self.log.warning(
"Websocket connection has been closed via client disconnect or due to error. "
"Kernel with ID '{}' may not be terminated on GatewayClient: {}".format(
self.kernel_id, GatewayClient.instance().url
)
)

def disconnect(self):
"""Handle a disconnect."""
self.disconnected = True
if self.ws is not None:
# Close connection
self.ws.close()
elif not self.ws_future.done():
# Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
self.ws_future.cancel()
self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")

async def _read_messages(self):
"""Read messages from gateway server."""
while self.ws is not None:
message = None
if not self.disconnected:
try:
message = await self.ws.read_message()
except Exception as e:
self.log.error(
f"Exception reading message from websocket: {e}"
) # , exc_info=True)
if message is None:
if not self.disconnected:
self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
break
self.handle_outgoing_message(
message
) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
else: # ws cancelled - stop reading
break

# NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max:
jitter = random.randint(10, 100) * 0.01 # noqa
retry_interval = (
min(
GatewayClient.instance().gateway_retry_interval * (2**self.retry),
GatewayClient.instance().gateway_retry_interval_max,
)
+ jitter
)
self.retry += 1
self.log.info(
"Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
retry_interval,
self.retry,
GatewayClient.instance().gateway_retry_max,
self.kernel_id,
)
await asyncio.sleep(retry_interval)
loop = IOLoop.current()
loop.spawn_callback(self.connect)

def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None:
"""Send message to the notebook client."""
try:
self.websocket_handler.write_message(incoming_msg)
except tornado_websocket.WebSocketClosedError:
if self.log.isEnabledFor(logging.DEBUG):
msg_summary = GatewayWebSocketConnection._get_message_summary(
json_decode(utf8(incoming_msg))
)
self.log.debug(
"Notebook client closed websocket connection - message dropped: {}".format(
msg_summary
)
)

def handle_incoming_message(self, message: str) -> None:
"""Send message to gateway server."""
if self.ws is None:
loop = IOLoop.current()
loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message))
else:
self._write_message(message)

def _write_message(self, message):
"""Send message to gateway server."""
try:
if not self.disconnected and self.ws is not None:
self.ws.write_message(message)
except Exception as e:
self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True)

@staticmethod
def _get_message_summary(message):
"""Get a summary of a message."""
summary = []
message_type = message["msg_type"]
summary.append(f"type: {message_type}")

if message_type == "status":
summary.append(", state: {}".format(message["content"]["execution_state"]))
elif message_type == "error":
summary.append(
", {}:{}:{}".format(
message["content"]["ename"],
message["content"]["evalue"],
message["content"]["traceback"],
)
)
else:
summary.append(", ...") # don't display potentially sensitive data

return "".join(summary)
8 changes: 8 additions & 0 deletions jupyter_server/gateway/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import mimetypes
import os
import random
import warnings
from typing import Optional, cast

from jupyter_client.session import Session
Expand All @@ -21,6 +22,13 @@
from ..utils import url_path_join
from .managers import GatewayClient

warnings.warn(
"The jupyter_server.gateway.handlers module is deprecated and will not be supported in Jupyter Server 3.0",
DeprecationWarning,
stacklevel=2,
)


# Keepalive ping interval (default: 30 seconds)
GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30"))

Expand Down
22 changes: 22 additions & 0 deletions jupyter_server/kernelspecs/handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Kernelspecs API Handlers."""
import mimetypes

from jupyter_core.utils import ensure_async
from tornado import web

Expand Down Expand Up @@ -27,6 +29,26 @@ async def get(self, kernel_name, path, include_body=True):
ksm = self.kernel_spec_manager
if path.lower().endswith(".png"):
self.set_header("Cache-Control", f"max-age={60*60*24*30}")
ksm = self.kernel_spec_manager
if hasattr(ksm, "get_kernel_spec_resource"):
# If the kernel spec manager defines a method to get kernelspec resources,
# then use that instead of trying to read from disk.
kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path)
if kernel_spec_res is not None:
# We have to explicitly specify the `absolute_path` attribute so that
# the underlying StaticFileHandler methods can calculate an etag.
self.absolute_path = path
mimetype: str = mimetypes.guess_type(path)[0] or "text/plain"
self.set_header("Content-Type", mimetype)
self.finish(kernel_spec_res)
return
else:
self.log.warning(
"Kernelspec resource '{}' for '{}' not found. Kernel spec manager may"
" not support resource serving. Falling back to reading from disk".format(
path, kernel_name
)
)
try:
kspec = await ensure_async(ksm.get_kernel_spec(kernel_name))
self.root = kspec.resource_dir
Expand Down
34 changes: 21 additions & 13 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from jupyter_server.extension.config import ExtensionConfigManager
from jupyter_server.extension.manager import ExtensionManager
from jupyter_server.extension.serverextension import ServerExtensionApp
from jupyter_server.gateway.connections import GatewayWebSocketConnection
from jupyter_server.gateway.managers import (
GatewayClient,
GatewayKernelSpecManager,
Expand Down Expand Up @@ -433,17 +434,6 @@ def init_handlers(self, default_services, settings):
# And from identity provider
handlers.extend(settings["identity_provider"].get_handlers())

# If gateway mode is enabled, replace appropriate handlers to perform redirection
if GatewayClient.instance().gateway_enabled:
# for each handler required for gateway, locate its pattern
# in the current list and replace that entry...
gateway_handlers = load_handlers("jupyter_server.gateway.handlers")
for _, gwh in enumerate(gateway_handlers):
for j, h in enumerate(handlers):
if gwh[0] == h[0]:
handlers[j] = (gwh[0], gwh[1])
break

# register base handlers last
handlers.extend(load_handlers("jupyter_server.base.handlers"))

Expand Down Expand Up @@ -796,6 +786,7 @@ class ServerApp(JupyterApp):
GatewayMappingKernelManager,
GatewayKernelSpecManager,
GatewaySessionManager,
GatewayWebSocketConnection,
GatewayClient,
Authorizer,
EventLogger,
Expand Down Expand Up @@ -1505,12 +1496,17 @@ def _default_session_manager_class(self):
return SessionManager

kernel_websocket_connection_class = Type(
default_value=ZMQChannelsWebsocketConnection,
klass=BaseKernelWebsocketConnection,
config=True,
help=_i18n("The kernel websocket connection class to use."),
)

@default("kernel_websocket_connection_class")
def _default_kernel_websocket_connection_class(self):
if self.gateway_config.gateway_enabled:
return "jupyter_server.gateway.connections.GatewayWebSocketConnection"
return ZMQChannelsWebsocketConnection

config_manager_class = Type(
default_value=ConfigManager,
config=True,
Expand Down Expand Up @@ -2876,7 +2872,19 @@ async def _cleanup(self):
self.remove_browser_open_files()
await self.cleanup_extensions()
await self.cleanup_kernels()
await self.kernel_websocket_connection_class.close_all()
try:
await self.kernel_websocket_connection_class.close_all()
except AttributeError:
# This can happen in two different scenarios:
#
# 1. During tests, where the _cleanup method is invoked without
# the corresponding initialize method having been invoked.
# 2. If the provided `kernel_websocket_connection_class` does not
# implement the `close_all` class method.
#
# In either case, we don't need to do anything and just want to treat
# the raised error as a no-op.
pass
if getattr(self, "kernel_manager", None):
self.kernel_manager.__del__()
if getattr(self, "session_manager", None):
Expand Down
60 changes: 60 additions & 0 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
import pytest
import tornado
from jupyter_core.utils import ensure_async
from tornado.concurrent import Future
from tornado.httpclient import HTTPRequest, HTTPResponse
from tornado.httputil import HTTPServerRequest
from tornado.queues import Queue
from tornado.web import HTTPError
from traitlets import Int, Unicode
from traitlets.config import Config

from jupyter_server.gateway.connections import GatewayWebSocketConnection
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler

from .utils import expected_http_error

Expand Down Expand Up @@ -659,6 +664,61 @@ async def test_channel_queue_get_msg_when_response_router_had_finished():
await queue.get_msg()


class MockWebSocketClientConnection(tornado.websocket.WebSocketClientConnection):
def __init__(self, *args, **kwargs):
self._msgs: Queue = Queue(2)
self._msgs.put_nowait('{"msg_type": "status", "content": {"execution_state": "starting"}}')

def write_message(self, message, *args, **kwargs):
return self._msgs.put(message)

def read_message(self, *args, **kwargs):
return self._msgs.get()


def mock_websocket_connect():
def helper(request):
fut: Future = Future()
mock_client = MockWebSocketClientConnection()
fut.set_result(mock_client)
return fut

return helper


@patch("tornado.websocket.websocket_connect", mock_websocket_connect())
async def test_websocket_connection_closed(init_gateway, jp_serverapp, jp_fetch, caplog):
# Create the kernel and get the kernel manager...
kernel_id = await create_kernel(jp_fetch, "kspec_foo")
km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)

# Create the KernelWebsocketHandler...
request = HTTPServerRequest("foo", "GET")
request.connection = MagicMock()
handler = KernelWebsocketHandler(jp_serverapp.web_app, request)

# Force the websocket handler to raise a closed error if we try to write a message
# to the client.
handler.ws_connection = MagicMock()
handler.ws_connection.is_closing = lambda: True

# Create the GatewayWebSocketConnection and attach it to the handler...
conn = GatewayWebSocketConnection(parent=km, websocket_handler=handler)
handler.connection = conn
await conn.connect()

# Processing websocket messages happens in separate coroutines and any
# errors in that process will show up in logs, but not bubble up to the
# caller.
#
# To check for these, we wait for the server to stop and then check the
# logs for errors.
await jp_serverapp._cleanup()
for _, level, message in caplog.record_tuples:
if level >= logging.ERROR:
pytest.fail(f"Logs contain an error: {message}")


#
# Test methods below...
#
Expand Down
Loading

0 comments on commit 35ffe5d

Please sign in to comment.