From 884b72e788ecfc89556787d87b71a42a9927622b Mon Sep 17 00:00:00 2001 From: Zachary Sailer Date: Sun, 20 Nov 2022 10:14:11 -0800 Subject: [PATCH] New configurable/overridable kernel ZMQ+Websocket connection API (#1047) * add new configurable websocket api * cleaning up unit tests * more updates for unit tests * all loop to ensure kernel is alive before connecting working unit tests * fix pre-commit errors * handle trait deprecation * cleanup from code review * ignore deprecation warning from zmqhandlers module * move base websocket mixin into its own module --- jupyter_server/base/websocket.py | 126 +++ jupyter_server/base/zmqhandlers.py | 367 +------- jupyter_server/serverapp.py | 151 ++-- .../services/kernels/connection/__init__.py | 0 .../services/kernels/connection/abc.py | 33 + .../services/kernels/connection/base.py | 156 ++++ .../services/kernels/connection/channels.py | 809 ++++++++++++++++++ jupyter_server/services/kernels/handlers.py | 716 +--------------- jupyter_server/services/kernels/websocket.py | 81 ++ pyproject.toml | 6 +- tests/services/sessions/test_api.py | 10 +- tests/test_serialize.py | 2 +- 12 files changed, 1324 insertions(+), 1133 deletions(-) create mode 100644 jupyter_server/base/websocket.py create mode 100644 jupyter_server/services/kernels/connection/__init__.py create mode 100644 jupyter_server/services/kernels/connection/abc.py create mode 100644 jupyter_server/services/kernels/connection/base.py create mode 100644 jupyter_server/services/kernels/connection/channels.py create mode 100644 jupyter_server/services/kernels/websocket.py diff --git a/jupyter_server/base/websocket.py b/jupyter_server/base/websocket.py new file mode 100644 index 0000000000..15cd6ea77e --- /dev/null +++ b/jupyter_server/base/websocket.py @@ -0,0 +1,126 @@ +import re +from typing import Optional, no_type_check +from urllib.parse import urlparse + +from tornado import ioloop +from tornado.iostream import IOStream + +# ping interval for keeping websockets alive (30 seconds) +WS_PING_INTERVAL = 30000 + + +class WebSocketMixin: + """Mixin for common websocket options""" + + ping_callback = None + last_ping = 0.0 + last_pong = 0.0 + stream = None # type: Optional[IOStream] + + @property + def ping_interval(self): + """The interval for websocket keep-alive pings. + + Set ws_ping_interval = 0 to disable pings. + """ + return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined] + + @property + def ping_timeout(self): + """If no ping is received in this many milliseconds, + close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). + Default is max of 3 pings or 30 seconds. + """ + return self.settings.get( # type:ignore[attr-defined] + "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL) + ) + + @no_type_check + def check_origin(self, origin: Optional[str] = None) -> bool: + """Check Origin == Host or Access-Control-Allow-Origin. + + Tornado >= 4 calls this method automatically, raising 403 if it returns False. + """ + + if self.allow_origin == "*" or ( + hasattr(self, "skip_check_origin") and self.skip_check_origin() + ): + return True + + host = self.request.headers.get("Host") + if origin is None: + origin = self.get_origin() + + # If no origin or host header is provided, assume from script + if origin is None or host is None: + return True + + origin = origin.lower() + origin_host = urlparse(origin).netloc + + # OK if origin matches host + if origin_host == host: + return True + + # Check CORS headers + if self.allow_origin: + allow = self.allow_origin == origin + elif self.allow_origin_pat: + allow = bool(re.match(self.allow_origin_pat, origin)) + else: + # No CORS headers deny the request + allow = False + if not allow: + self.log.warning( + "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s", + origin, + host, + ) + return allow + + def clear_cookie(self, *args, **kwargs): + """meaningless for websockets""" + pass + + @no_type_check + def open(self, *args, **kwargs): + self.log.debug("Opening websocket %s", self.request.path) + + # start the pinging + if self.ping_interval > 0: + loop = ioloop.IOLoop.current() + self.last_ping = loop.time() # Remember time of last ping + self.last_pong = self.last_ping + self.ping_callback = ioloop.PeriodicCallback( + self.send_ping, + self.ping_interval, + ) + self.ping_callback.start() + return super().open(*args, **kwargs) + + @no_type_check + def send_ping(self): + """send a ping to keep the websocket alive""" + if self.ws_connection is None and self.ping_callback is not None: + self.ping_callback.stop() + return + + if self.ws_connection.client_terminated: + self.close() + return + + # check for timeout on pong. Make sure that we really have sent a recent ping in + # case the machine with both server and client has been suspended since the last ping. + now = ioloop.IOLoop.current().time() + since_last_pong = 1e3 * (now - self.last_pong) + since_last_ping = 1e3 * (now - self.last_ping) + if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout: + self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong) + self.close() + return + + self.ping(b"") + self.last_ping = now + + def on_pong(self, data): + self.last_pong = ioloop.IOLoop.current().time() diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py index 91d0f12c8d..454848a9fe 100644 --- a/jupyter_server/base/zmqhandlers.py +++ b/jupyter_server/base/zmqhandlers.py @@ -1,350 +1,17 @@ -"""Tornado handlers for WebSocket <-> ZMQ sockets.""" -# Copyright (c) Jupyter Development Team. -# Distributed under the terms of the Modified BSD License. -import json -import re -import struct -from typing import Optional, no_type_check -from urllib.parse import urlparse - -import tornado - -try: - from jupyter_client.jsonutil import json_default -except ImportError: - from jupyter_client.jsonutil import date_default as json_default - -from jupyter_client.jsonutil import extract_dates -from jupyter_client.session import Session -from tornado import ioloop, web -from tornado.iostream import IOStream -from tornado.websocket import WebSocketClosedError, WebSocketHandler - -from .handlers import JupyterHandler - - -def serialize_binary_message(msg): - """serialize a message as a binary blob - - Header: - - 4 bytes: number of msg parts (nbufs) as 32b int - 4 * nbufs bytes: offset for each buffer as integer as 32b int - - Offsets are from the start of the buffer, including the header. - - Returns - ------- - The message serialized to bytes. - - """ - # don't modify msg or buffer list in-place - msg = msg.copy() - buffers = list(msg.pop("buffers")) - bmsg = json.dumps(msg, default=json_default).encode("utf8") - buffers.insert(0, bmsg) - nbufs = len(buffers) - offsets = [4 * (nbufs + 1)] - for buf in buffers[:-1]: - offsets.append(offsets[-1] + len(buf)) - offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets) - buffers.insert(0, offsets_buf) - return b"".join(buffers) - - -def deserialize_binary_message(bmsg): - """deserialize a message from a binary blog - - Header: - - 4 bytes: number of msg parts (nbufs) as 32b int - 4 * nbufs bytes: offset for each buffer as integer as 32b int - - Offsets are from the start of the buffer, including the header. - - Returns - ------- - message dictionary - """ - nbufs = struct.unpack("!i", bmsg[:4])[0] - offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)])) - offsets.append(None) - bufs = [] - for start, stop in zip(offsets[:-1], offsets[1:]): - bufs.append(bmsg[start:stop]) - msg = json.loads(bufs[0].decode("utf8")) - msg["header"] = extract_dates(msg["header"]) - msg["parent_header"] = extract_dates(msg["parent_header"]) - msg["buffers"] = bufs[1:] - return msg - - -def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None): - if pack: - msg_list = [ - pack(msg_or_list["header"]), - pack(msg_or_list["parent_header"]), - pack(msg_or_list["metadata"]), - pack(msg_or_list["content"]), - ] - else: - msg_list = msg_or_list - channel = channel.encode("utf-8") - offsets: list = [] - offsets.append(8 * (1 + 1 + len(msg_list) + 1)) - offsets.append(len(channel) + offsets[-1]) - for msg in msg_list: - offsets.append(len(msg) + offsets[-1]) - offset_number = len(offsets).to_bytes(8, byteorder="little") - offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets] - bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list) - return bin_msg - - -def deserialize_msg_from_ws_v1(ws_msg): - offset_number = int.from_bytes(ws_msg[:8], "little") - offsets = [ - int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number) - ] - channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8") - msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)] - return channel, msg_list - - -# ping interval for keeping websockets alive (30 seconds) -WS_PING_INTERVAL = 30000 - - -class WebSocketMixin: - """Mixin for common websocket options""" - - ping_callback = None - last_ping = 0.0 - last_pong = 0.0 - stream = None # type: Optional[IOStream] - - @property - def ping_interval(self): - """The interval for websocket keep-alive pings. - - Set ws_ping_interval = 0 to disable pings. - """ - return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined] - - @property - def ping_timeout(self): - """If no ping is received in this many milliseconds, - close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). - Default is max of 3 pings or 30 seconds. - """ - return self.settings.get( # type:ignore[attr-defined] - "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL) - ) - - @no_type_check - def check_origin(self, origin: Optional[str] = None) -> bool: - """Check Origin == Host or Access-Control-Allow-Origin. - - Tornado >= 4 calls this method automatically, raising 403 if it returns False. - """ - - if self.allow_origin == "*" or ( - hasattr(self, "skip_check_origin") and self.skip_check_origin() - ): - return True - - host = self.request.headers.get("Host") - if origin is None: - origin = self.get_origin() - - # If no origin or host header is provided, assume from script - if origin is None or host is None: - return True - - origin = origin.lower() - origin_host = urlparse(origin).netloc - - # OK if origin matches host - if origin_host == host: - return True - - # Check CORS headers - if self.allow_origin: - allow = self.allow_origin == origin - elif self.allow_origin_pat: - allow = bool(re.match(self.allow_origin_pat, origin)) - else: - # No CORS headers deny the request - allow = False - if not allow: - self.log.warning( - "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s", - origin, - host, - ) - return allow - - def clear_cookie(self, *args, **kwargs): - """meaningless for websockets""" - pass - - @no_type_check - def open(self, *args, **kwargs): - self.log.debug("Opening websocket %s", self.request.path) - - # start the pinging - if self.ping_interval > 0: - loop = ioloop.IOLoop.current() - self.last_ping = loop.time() # Remember time of last ping - self.last_pong = self.last_ping - self.ping_callback = ioloop.PeriodicCallback( - self.send_ping, - self.ping_interval, - ) - self.ping_callback.start() - return super().open(*args, **kwargs) - - @no_type_check - def send_ping(self): - """send a ping to keep the websocket alive""" - if self.ws_connection is None and self.ping_callback is not None: - self.ping_callback.stop() - return - - if self.ws_connection.client_terminated: - self.close() - return - - # check for timeout on pong. Make sure that we really have sent a recent ping in - # case the machine with both server and client has been suspended since the last ping. - now = ioloop.IOLoop.current().time() - since_last_pong = 1e3 * (now - self.last_pong) - since_last_ping = 1e3 * (now - self.last_ping) - if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout: - self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong) - self.close() - return - - self.ping(b"") - self.last_ping = now - - def on_pong(self, data): - self.last_pong = ioloop.IOLoop.current().time() - - -class ZMQStreamHandler(WebSocketMixin, WebSocketHandler): - - if tornado.version_info < (4, 1): - """Backport send_error from tornado 4.1 to 4.0""" - - def send_error(self, *args, **kwargs): - if self.stream is None: - super(WebSocketHandler, self).send_error(*args, **kwargs) - else: - # If we get an uncaught exception during the handshake, - # we have no choice but to abruptly close the connection. - # TODO: for uncaught exceptions after the handshake, - # we can close the connection more gracefully. - self.stream.close() - - def _reserialize_reply(self, msg_or_list, channel=None): - """Reserialize a reply message using JSON. - - msg_or_list can be an already-deserialized msg dict or the zmq buffer list. - If it is the zmq list, it will be deserialized with self.session. - - This takes the msg list from the ZMQ socket and serializes the result for the websocket. - This method should be used by self._on_zmq_reply to build messages that can - be sent back to the browser. - - """ - if isinstance(msg_or_list, dict): - # already unpacked - msg = msg_or_list - else: - idents, msg_list = self.session.feed_identities(msg_or_list) - msg = self.session.deserialize(msg_list) - if channel: - msg["channel"] = channel - if msg["buffers"]: - buf = serialize_binary_message(msg) - return buf - else: - return json.dumps(msg, default=json_default) - - def select_subprotocol(self, subprotocols): - preferred_protocol = self.settings.get("kernel_ws_protocol") - if preferred_protocol is None: - preferred_protocol = "v1.kernel.websocket.jupyter.org" - elif preferred_protocol == "": - preferred_protocol = None - selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None - # None is the default, "legacy" protocol - return selected_subprotocol - - def _on_zmq_reply(self, stream, msg_list): - # Sometimes this gets triggered when the on_close method is scheduled in the - # eventloop but hasn't been called. - if self.ws_connection is None or stream.closed(): - self.log.warning("zmq message arrived on closed channel") - self.close() - return - channel = getattr(stream, "channel", None) - if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": - bin_msg = serialize_msg_to_ws_v1(msg_list, channel) - self.write_message(bin_msg, binary=True) - else: - try: - msg = self._reserialize_reply(msg_list, channel=channel) - except Exception: - self.log.critical("Malformed message: %r" % msg_list, exc_info=True) - else: - try: - self.write_message(msg, binary=isinstance(msg, bytes)) - except WebSocketClosedError as e: - self.log.warning(str(e)) - - -class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler): - def set_default_headers(self): - """Undo the set_default_headers in JupyterHandler - - which doesn't make sense for websockets - """ - pass - - def pre_get(self): - """Run before finishing the GET request - - Extend this method to add logic that should fire before - the websocket finishes completing. - """ - # authenticate the request before opening the websocket - user = self.current_user - if user is None: - self.log.warning("Couldn't authenticate WebSocket connection") - raise web.HTTPError(403) - - # authorize the user. - if not self.authorizer.is_authorized(self, user, "execute", "kernels"): - raise web.HTTPError(403) - - if self.get_argument("session_id", None): - self.session.session = self.get_argument("session_id") - else: - self.log.warning("No session ID specified") - - async def get(self, *args, **kwargs): - # pre_get can be a coroutine in subclasses - # assign and yield in two step to avoid tornado 3 issues - res = self.pre_get() - await res - res = super().get(*args, **kwargs) - await res - - def initialize(self): - self.log.debug("Initializing websocket connection %s", self.request.path) - self.session = Session(config=self.config) - - def get_compression_options(self): - return self.settings.get("websocket_compression_options", None) +"""This module is deprecated in Jupyter Server 2.0""" +# Raise a warning that this module is deprecated. +import warnings + +from jupyter_server.base.websocket import WebSocketMixin +from jupyter_server.services.kernels.connection.base import ( + deserialize_binary_message, + deserialize_msg_from_ws_v1, + serialize_binary_message, + serialize_msg_to_ws_v1, +) + +warnings.warn( + "jupyter_server.base.zmqhandlers module is deprecated in Jupyter Server 2.0", + DeprecationWarning, + stacklevel=2, +) diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index e4e8d83880..7ce2e003d0 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -26,36 +26,6 @@ import webbrowser from base64 import encodebytes -try: - import resource -except ImportError: - # Windows - resource = None # type:ignore[assignment] - -from jinja2 import Environment, FileSystemLoader -from jupyter_core.paths import secure_write - -from jupyter_server.services.kernels.handlers import ZMQChannelsHandler -from jupyter_server.transutils import _i18n, trans -from jupyter_server.utils import ensure_async, pathname2url, urljoin - -# the minimum viable tornado version: needs to be kept in sync with setup.py -MIN_TORNADO = (6, 1, 0) - -try: - import tornado - - assert tornado.version_info >= MIN_TORNADO -except (ImportError, AttributeError, AssertionError) as e: # pragma: no cover - raise ImportError(_i18n("The Jupyter Server requires tornado >=%s.%s.%s") % MIN_TORNADO) from e - -from tornado import httpserver, ioloop, web -from tornado.httputil import url_concat -from tornado.log import LogFormatter, access_log, app_log, gen_log - -if not sys.platform.startswith("win"): - from tornado.netutil import bind_unix_socket - from jupyter_client.kernelspec import KernelSpecManager from jupyter_client.manager import KernelManager from jupyter_client.session import Session @@ -63,6 +33,13 @@ from jupyter_core.paths import jupyter_runtime_dir from jupyter_events.logger import EventLogger from nbformat.sign import NotebookNotary +from tornado import httpserver, ioloop, web +from tornado.httputil import url_concat +from tornado.log import LogFormatter, access_log, app_log, gen_log + +if not sys.platform.startswith("win"): + from tornado.netutil import bind_unix_socket + from traitlets import ( Any, Bool, @@ -127,6 +104,12 @@ AsyncContentsManager, ContentsManager, ) +from jupyter_server.services.kernels.connection.base import ( + BaseKernelWebsocketConnection, +) +from jupyter_server.services.kernels.connection.channels import ( + ZMQChannelsWebsocketConnection, +) from jupyter_server.services.kernels.kernelmanager import ( AsyncMappingKernelManager, MappingKernelManager, @@ -141,6 +124,34 @@ urlencode_unix_socket_path, ) +try: + import resource +except ImportError: + # Windows + resource = None # type:ignore[assignment] + +from jinja2 import Environment, FileSystemLoader +from jupyter_core.paths import secure_write + +from jupyter_server.transutils import _i18n, trans +from jupyter_server.utils import ensure_async, pathname2url, urljoin + +# the minimum viable tornado version: needs to be kept in sync with setup.py +MIN_TORNADO = (6, 1, 0) + +try: + import tornado + + assert tornado.version_info >= MIN_TORNADO +except (ImportError, AttributeError, AssertionError) as e: # pragma: no cover + raise ImportError(_i18n("The Jupyter Server requires tornado >=%s.%s.%s") % MIN_TORNADO) from e + +try: + import resource +except ImportError: + # Windows + resource = None # type:ignore[assignment] + # ----------------------------------------------------------------------------- # Module globals # ----------------------------------------------------------------------------- @@ -157,7 +168,10 @@ config=["jupyter_server.services.config.handlers"], contents=["jupyter_server.services.contents.handlers"], files=["jupyter_server.files.handlers"], - kernels=["jupyter_server.services.kernels.handlers"], + kernels=[ + "jupyter_server.services.kernels.handlers", + "jupyter_server.services.kernels.websocket", + ], kernelspecs=[ "jupyter_server.kernelspecs.handlers", "jupyter_server.services.kernelspecs.handlers", @@ -224,6 +238,7 @@ def __init__( *, authorizer=None, identity_provider=None, + kernel_websocket_connection_class=None, ): if identity_provider is None: warnings.warn( @@ -259,6 +274,7 @@ def __init__( jinja_env_options, authorizer=authorizer, identity_provider=identity_provider, + kernel_websocket_connection_class=kernel_websocket_connection_class, ) handlers = self.init_handlers(default_services, settings) @@ -282,6 +298,7 @@ def init_settings( *, authorizer=None, identity_provider=None, + kernel_websocket_connection_class=None, ): _template_path = settings_overrides.get( @@ -362,6 +379,7 @@ def init_settings( authorizer=authorizer, identity_provider=identity_provider, event_logger=event_logger, + kernel_websocket_connection_class=kernel_websocket_connection_class, # handlers extra_services=extra_services, # Jupyter stuff @@ -772,6 +790,7 @@ class ServerApp(JupyterApp): GatewayClient, Authorizer, EventLogger, + ZMQChannelsWebsocketConnection, ] subcommands = dict( @@ -1462,6 +1481,13 @@ def _default_session_manager_class(self): return "jupyter_server.gateway.managers.GatewaySessionManager" return SessionManager + kernel_websocket_connection_class = Type( + default_value=ZMQChannelsWebsocketConnection, + klass=BaseKernelWebsocketConnection, + config=True, + help=_i18n("The kernel websocket connection class to use."), + ) + config_manager_class = Type( default_value=ConfigManager, config=True, @@ -1707,57 +1733,55 @@ def _update_server_extensions(self, change): ) kernel_ws_protocol = Unicode( - None, allow_none=True, config=True, - help=_i18n( - "Preferred kernel message protocol over websocket to use (default: None). " - "If an empty string is passed, select the legacy protocol. If None, " - "the selected protocol will depend on what the front-end supports " - "(usually the most recent protocol supported by the back-end and the " - "front-end)." - ), + help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.kernel_ws_protocol"), ) + @observe("kernel_ws_protocol") + def _deprecated_kernel_ws_protocol(self, change): + self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") + limit_rate = Bool( - True, + allow_none=True, config=True, - help=_i18n( - "Whether to limit the rate of IOPub messages (default: True). " - "If True, use iopub_msg_rate_limit, iopub_data_rate_limit and/or rate_limit_window " - "to tune the rate." - ), + help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.limit_rate"), ) + @observe("limit_rate") + def _deprecated_limit_rate(self, change): + self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") + iopub_msg_rate_limit = Float( - 1000, + allow_none=True, config=True, - help=_i18n( - """(msgs/sec) - Maximum rate at which messages can be sent on iopub before they are - limited.""" - ), + help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.iopub_msg_rate_limit"), ) + @observe("iopub_msg_rate_limit") + def _deprecated_iopub_msg_rate_limit(self, change): + self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") + iopub_data_rate_limit = Float( - 1000000, + allow_none=True, config=True, - help=_i18n( - """(bytes/sec) - Maximum rate at which stream output can be sent on iopub before they are - limited.""" - ), + help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.iopub_data_rate_limit"), ) + @observe("iopub_data_rate_limit") + def _deprecated_iopub_data_rate_limit(self, change): + self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") + rate_limit_window = Float( - 3, + allow_none=True, config=True, - help=_i18n( - """(sec) Time window used to - check the message and data rate limits.""" - ), + help=_i18n("DEPRECATED. Use ZMQChannelsWebsocketConnection.rate_limit_window"), ) + @observe("rate_limit_window") + def _deprecated_rate_limit_window(self, change): + self._warn_deprecated_config(change, "ZMQChannelsWebsocketConnection") + shutdown_no_activity_timeout = Integer( 0, config=True, @@ -2025,6 +2049,7 @@ def init_webapp(self): self.jinja_environment_options, authorizer=self.authorizer, identity_provider=self.identity_provider, + kernel_websocket_connection_class=self.kernel_websocket_connection_class, ) if self.certfile: self.ssl_options["certfile"] = self.certfile @@ -2820,7 +2845,7 @@ async def _cleanup(self): self.remove_browser_open_files() await self.cleanup_extensions() await self.cleanup_kernels() - await ZMQChannelsHandler.close_all() + await self.kernel_websocket_connection_class.close_all() if getattr(self, "kernel_manager", None): self.kernel_manager.__del__() if getattr(self, "session_manager", None): diff --git a/jupyter_server/services/kernels/connection/__init__.py b/jupyter_server/services/kernels/connection/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/services/kernels/connection/abc.py b/jupyter_server/services/kernels/connection/abc.py new file mode 100644 index 0000000000..4bdf6e3edc --- /dev/null +++ b/jupyter_server/services/kernels/connection/abc.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class KernelWebsocketConnectionABC(ABC): + """ + This class defines a minimal interface that should + be used to bridge the connection between Jupyter + Server's websocket API and a kernel's ZMQ socket + interface. + """ + + websocket_handler: Any + + @abstractmethod + async def connect(self): + """Connect the kernel websocket to the kernel ZMQ connections""" + ... + + @abstractmethod + async def disconnect(self): + """Disconnect the kernel websocket from the kernel ZMQ connections""" + ... + + @abstractmethod + def handle_incoming_message(self, incoming_msg: str) -> None: + """Broker the incoming websocket message to the appropriate ZMQ channel.""" + ... + + @abstractmethod + def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None: + """Broker outgoing ZMQ messages to the kernel websocket.""" + ... diff --git a/jupyter_server/services/kernels/connection/base.py b/jupyter_server/services/kernels/connection/base.py new file mode 100644 index 0000000000..880c5e69d3 --- /dev/null +++ b/jupyter_server/services/kernels/connection/base.py @@ -0,0 +1,156 @@ +import json +import struct +import sys + +from jupyter_client.session import Session +from tornado.websocket import WebSocketHandler +from traitlets import Float, Instance, default +from traitlets.config import LoggingConfigurable + +try: + from jupyter_client.jsonutil import json_default +except ImportError: + from jupyter_client.jsonutil import date_default as json_default + +from jupyter_client.jsonutil import extract_dates + +from .abc import KernelWebsocketConnectionABC + + +def serialize_binary_message(msg): + """serialize a message as a binary blob + + Header: + + 4 bytes: number of msg parts (nbufs) as 32b int + 4 * nbufs bytes: offset for each buffer as integer as 32b int + + Offsets are from the start of the buffer, including the header. + + Returns + ------- + The message serialized to bytes. + + """ + # don't modify msg or buffer list in-place + msg = msg.copy() + buffers = list(msg.pop("buffers")) + if sys.version_info < (3, 4): + buffers = [x.tobytes() for x in buffers] + bmsg = json.dumps(msg, default=json_default).encode("utf8") + buffers.insert(0, bmsg) + nbufs = len(buffers) + offsets = [4 * (nbufs + 1)] + for buf in buffers[:-1]: + offsets.append(offsets[-1] + len(buf)) + offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets) + buffers.insert(0, offsets_buf) + return b"".join(buffers) + + +def deserialize_binary_message(bmsg): + """deserialize a message from a binary blog + + Header: + + 4 bytes: number of msg parts (nbufs) as 32b int + 4 * nbufs bytes: offset for each buffer as integer as 32b int + + Offsets are from the start of the buffer, including the header. + + Returns + ------- + message dictionary + """ + nbufs = struct.unpack("!i", bmsg[:4])[0] + offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)])) + offsets.append(None) + bufs = [] + for start, stop in zip(offsets[:-1], offsets[1:]): + bufs.append(bmsg[start:stop]) + msg = json.loads(bufs[0].decode("utf8")) + msg["header"] = extract_dates(msg["header"]) + msg["parent_header"] = extract_dates(msg["parent_header"]) + msg["buffers"] = bufs[1:] + return msg + + +def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None): + if pack: + msg_list = [ + pack(msg_or_list["header"]), + pack(msg_or_list["parent_header"]), + pack(msg_or_list["metadata"]), + pack(msg_or_list["content"]), + ] + else: + msg_list = msg_or_list + channel = channel.encode("utf-8") + offsets: list = [] + offsets.append(8 * (1 + 1 + len(msg_list) + 1)) + offsets.append(len(channel) + offsets[-1]) + for msg in msg_list: + offsets.append(len(msg) + offsets[-1]) + offset_number = len(offsets).to_bytes(8, byteorder="little") + offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets] + bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list) + return bin_msg + + +def deserialize_msg_from_ws_v1(ws_msg): + offset_number = int.from_bytes(ws_msg[:8], "little") + offsets = [ + int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number) + ] + channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8") + msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)] + return channel, msg_list + + +class BaseKernelWebsocketConnection(LoggingConfigurable): + """A configurable base class for connecting Kernel WebSockets to ZMQ sockets.""" + + @property + def kernel_manager(self): + return self.parent + + @property + def multi_kernel_manager(self): + return self.kernel_manager.parent + + @property + def kernel_id(self): + return self.kernel_manager.kernel_id + + @property + def session_id(self): + return self.session.session + + kernel_info_timeout = Float() + + @default("kernel_info_timeout") + def _default_kernel_info_timeout(self): + return self.multi_kernel_manager.kernel_info_timeout + + session = Instance(klass=Session, config=True) + + @default("session") + def _default_session(self): + return Session(config=self.config) + + websocket_handler = Instance(WebSocketHandler) + + async def connect(self): + raise NotImplementedError() + + async def disconnect(self): + raise NotImplementedError() + + def handle_incoming_message(self, incoming_msg: str) -> None: + raise NotImplementedError() + + def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None: + raise NotImplementedError() + + +KernelWebsocketConnectionABC.register(BaseKernelWebsocketConnection) diff --git a/jupyter_server/services/kernels/connection/channels.py b/jupyter_server/services/kernels/connection/channels.py new file mode 100644 index 0000000000..0b8ced3980 --- /dev/null +++ b/jupyter_server/services/kernels/connection/channels.py @@ -0,0 +1,809 @@ +import asyncio +import json +import time +import weakref +from concurrent.futures import Future +from textwrap import dedent +from typing import MutableSet + +from jupyter_client import protocol_version as client_protocol_version +from tornado import gen, web +from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketClosedError +from traitlets import Any, Bool, Dict, Float, Instance, Int, List, Unicode, default + +try: + from jupyter_client.jsonutil import json_default +except ImportError: + from jupyter_client.jsonutil import date_default as json_default + +from jupyter_client.utils import ensure_async + +from jupyter_server.transutils import _i18n + +from .abc import KernelWebsocketConnectionABC +from .base import ( + BaseKernelWebsocketConnection, + deserialize_binary_message, + deserialize_msg_from_ws_v1, + serialize_binary_message, + serialize_msg_to_ws_v1, +) + + +def _ensure_future(f): + """Wrap a concurrent future as an asyncio future if there is a running loop.""" + try: + asyncio.get_running_loop() + return asyncio.wrap_future(f) + except RuntimeError: + return f + + +class ZMQChannelsWebsocketConnection(BaseKernelWebsocketConnection): + """A Jupyter Server Websocket Connection""" + + limit_rate = Bool( + True, + config=True, + help=_i18n( + "Whether to limit the rate of IOPub messages (default: True). " + "If True, use iopub_msg_rate_limit, iopub_data_rate_limit and/or rate_limit_window " + "to tune the rate." + ), + ) + + iopub_msg_rate_limit = Float( + 1000, + config=True, + help=_i18n( + """(msgs/sec) + Maximum rate at which messages can be sent on iopub before they are + limited.""" + ), + ) + + iopub_data_rate_limit = Float( + 1000000, + config=True, + help=_i18n( + """(bytes/sec) + Maximum rate at which stream output can be sent on iopub before they are + limited.""" + ), + ) + + rate_limit_window = Float( + 3, + config=True, + help=_i18n( + """(sec) Time window used to + check the message and data rate limits.""" + ), + ) + + kernel_ws_protocol = Unicode( + None, + allow_none=True, + config=True, + help=_i18n( + "Preferred kernel message protocol over websocket to use (default: None). " + "If an empty string is passed, select the legacy protocol. If None, " + "the selected protocol will depend on what the front-end supports " + "(usually the most recent protocol supported by the back-end and the " + "front-end)." + ), + ) + + @property + def write_message(self): + """Alias to the websocket handler's write_message method.""" + return self.websocket_handler.write_message + + # class-level registry of open sessions + # allows checking for conflict on session-id, + # which is used as a zmq identity and must be unique. + _open_sessions: dict = {} + _open_sockets: MutableSet["ZMQChannelsWebsocketConnection"] = weakref.WeakSet() + + _kernel_info_future: Future + _close_future: Future + + channels = Dict({}) + kernel_info_channel = Any(allow_none=True) + + _kernel_info_future = Instance(klass=Future) + + @default("_kernel_info_future") + def _default_kernel_info_future(self): + return Future() + + _close_future = Instance(klass=Future) + + @default("_close_future") + def _default_close_future(self): + return Future() + + session_key = Unicode("") + + _iopub_window_msg_count = Int() + _iopub_window_byte_count = Int() + _iopub_msgs_exceeded = Bool(False) + _iopub_data_exceeded = Bool(False) + # Queue of (time stamp, byte count) + # Allows you to specify that the byte count should be lowered + # by a delta amount at some point in the future. + _iopub_window_byte_queue = List([]) + + @classmethod + async def close_all(cls): + """Tornado does not provide a way to close open sockets, so add one.""" + for connection in list(cls._open_sockets): + connection.disconnect() + await _ensure_future(connection._close_future) + + @property + def subprotocol(self): + try: + protocol = self.websocket_handler.selected_subprotocol + except Exception: + protocol = None + return protocol + + def create_stream(self): + identity = self.session.bsession + for channel in ("iopub", "shell", "control", "stdin"): + meth = getattr(self.kernel_manager, "connect_" + channel) + self.channels[channel] = stream = meth(identity=identity) + stream.channel = channel + + def nudge(self): + """Nudge the zmq connections with kernel_info_requests + Returns a Future that will resolve when we have received + a shell or control reply and at least one iopub message, + ensuring that zmq subscriptions are established, + sockets are fully connected, and kernel is responsive. + Keeps retrying kernel_info_request until these are both received. + """ + # Do not nudge busy kernels as kernel info requests sent to shell are + # queued behind execution requests. + # nudging in this case would cause a potentially very long wait + # before connections are opened, + # plus it is *very* unlikely that a busy kernel will not finish + # establishing its zmq subscriptions before processing the next request. + if getattr(self.kernel_manager, "execution_state", None) == "busy": + self.log.debug("Nudge: not nudging busy kernel %s", self.kernel_id) + f: Future = Future() + f.set_result(None) + return _ensure_future(f) + # Use a transient shell channel to prevent leaking + # shell responses to the front-end. + shell_channel = self.kernel_manager.connect_shell() + # Use a transient control channel to prevent leaking + # control responses to the front-end. + control_channel = self.kernel_manager.connect_control() + # The IOPub used by the client, whose subscriptions we are verifying. + iopub_channel = self.channels["iopub"] + + info_future: Future = Future() + iopub_future: Future = Future() + both_done = gen.multi([info_future, iopub_future]) + + def finish(_=None): + """Ensure all futures are resolved + which in turn triggers cleanup + """ + for f in (info_future, iopub_future): + if not f.done(): + f.set_result(None) + + def cleanup(_=None): + """Common cleanup""" + loop.remove_timeout(nudge_handle) + iopub_channel.stop_on_recv() + if not shell_channel.closed(): + shell_channel.close() + if not control_channel.closed(): + control_channel.close() + + # trigger cleanup when both message futures are resolved + both_done.add_done_callback(cleanup) + + def on_shell_reply(msg): + self.log.debug("Nudge: shell info reply received: %s", self.kernel_id) + if not info_future.done(): + self.log.debug("Nudge: resolving shell future: %s", self.kernel_id) + info_future.set_result(None) + + def on_control_reply(msg): + self.log.debug("Nudge: control info reply received: %s", self.kernel_id) + if not info_future.done(): + self.log.debug("Nudge: resolving control future: %s", self.kernel_id) + info_future.set_result(None) + + def on_iopub(msg): + self.log.debug("Nudge: IOPub received: %s", self.kernel_id) + if not iopub_future.done(): + iopub_channel.stop_on_recv() + self.log.debug("Nudge: resolving iopub future: %s", self.kernel_id) + iopub_future.set_result(None) + + iopub_channel.on_recv(on_iopub) + shell_channel.on_recv(on_shell_reply) + control_channel.on_recv(on_control_reply) + loop = IOLoop.current() + + # Nudge the kernel with kernel info requests until we get an IOPub message + def nudge(count): + count += 1 + # check for stopped kernel + if self.kernel_id not in self.multi_kernel_manager: + self.log.debug("Nudge: cancelling on stopped kernel: %s", self.kernel_id) + finish() + return + + # check for closed zmq socket + if shell_channel.closed(): + self.log.debug("Nudge: cancelling on closed zmq socket: %s", self.kernel_id) + finish() + return + + # check for closed zmq socket + if control_channel.closed(): + self.log.debug("Nudge: cancelling on closed zmq socket: %s", self.kernel_id) + finish() + return + + if not both_done.done(): + log = self.log.warning if count % 10 == 0 else self.log.debug + log(f"Nudge: attempt {count} on kernel {self.kernel_id}") + self.session.send(shell_channel, "kernel_info_request") + self.session.send(control_channel, "kernel_info_request") + nonlocal nudge_handle # type: ignore[misc] + nudge_handle = loop.call_later(0.5, nudge, count) + + nudge_handle = loop.call_later(0, nudge, count=0) + + # resolve with a timeout if we get no response + future = gen.with_timeout(loop.time() + self.kernel_info_timeout, both_done) + # ensure we have no dangling resources or unresolved Futures in case of timeout + future.add_done_callback(finish) + return _ensure_future(future) + + async def _register_session(self): + """Ensure we aren't creating a duplicate session. + + If a previous identical session is still open, close it to avoid collisions. + This is likely due to a client reconnecting from a lost network connection, + where the socket on our side has not been cleaned up yet. + """ + self.session_key = f"{self.kernel_id}:{self.session.session}" + stale_handler = self._open_sessions.get(self.session_key) + if stale_handler: + self.log.warning("Replacing stale connection: %s", self.session_key) + await stale_handler.close() + if ( + self.kernel_id in self.multi_kernel_manager + ): # only update open sessions if kernel is actively managed + self._open_sessions[self.session_key] = self.websocket_handler + + async def prepare(self): + # check session collision: + await self._register_session() + # then request kernel info, waiting up to a certain time before giving up. + # We don't want to wait forever, because browsers don't take it well when + # servers never respond to websocket connection requests. + + if hasattr(self.kernel_manager, "ready"): + ready = self.kernel_manager.ready + if not isinstance(ready, asyncio.Future): + ready = asyncio.wrap_future(ready) + try: + await ready + except Exception as e: + self.kernel_manager.execution_state = "dead" + self.kernel_manager.reason = str(e) + raise web.HTTPError(500, str(e)) from e + + t0 = time.time() + while not await ensure_async(self.kernel_manager.is_alive()): + await asyncio.sleep(0.1) + if (time.time() - t0) > self.multi_kernel_manager.kernel_info_timeout: + raise TimeoutError("Kernel never reached an 'alive' state.") + + self.session.key = self.kernel_manager.session.key + future = self.request_kernel_info() + + def give_up(): + """Don't wait forever for the kernel to reply""" + if future.done(): + return + self.log.warning("Timeout waiting for kernel_info reply from %s", self.kernel_id) + future.set_result({}) + + loop = IOLoop.current() + loop.add_timeout(loop.time() + self.kernel_info_timeout, give_up) + # actually wait for it + await asyncio.wrap_future(future) + + def connect(self): + self.multi_kernel_manager.notify_connect(self.kernel_id) + + # on new connections, flush the message buffer + buffer_info = self.multi_kernel_manager.get_buffer(self.kernel_id, self.session_key) + if buffer_info and buffer_info["session_key"] == self.session_key: + self.log.info("Restoring connection for %s", self.session_key) + if self.kernel_manager.ports_changed(self.kernel_id): + # If the kernel's ports have changed (some restarts trigger this) + # then reset the channels so nudge() is using the correct iopub channel + self.create_stream() + else: + # The kernel's ports have not changed; use the channels captured in the buffer + self.channels = buffer_info["channels"] + + connected = self.nudge() + + def replay(value): + replay_buffer = buffer_info["buffer"] + if replay_buffer: + self.log.info("Replaying %s buffered messages", len(replay_buffer)) + for channel, msg_list in replay_buffer: + stream = self.channels[channel] + self._on_zmq_reply(stream, msg_list) + + connected.add_done_callback(replay) + else: + try: + self.create_stream() + connected = self.nudge() + except web.HTTPError as e: + # Do not log error if the kernel is already shutdown, + # as it's normal that it's not responding + try: + self.multi_kernel_manager.get_kernel(self.kernel_id) + self.log.error("Error opening stream: %s", e) + except KeyError: + pass + # WebSockets don't respond to traditional error codes so we + # close the connection. + for _, stream in self.channels.items(): + if not stream.closed(): + stream.close() + self.disconnect() + return + + self.multi_kernel_manager.add_restart_callback(self.kernel_id, self.on_kernel_restarted) + self.multi_kernel_manager.add_restart_callback( + self.kernel_id, self.on_restart_failed, "dead" + ) + + def subscribe(value): + for _, stream in self.channels.items(): + stream.on_recv_stream(self._on_zmq_reply) + + connected.add_done_callback(subscribe) + ZMQChannelsWebsocketConnection._open_sockets.add(self) + return connected + + def close(self): + return self.disconnect() + + def disconnect(self): + self.log.debug("Websocket closed %s", self.session_key) + # unregister myself as an open session (only if it's really me) + if self._open_sessions.get(self.session_key) is self: + self._open_sessions.pop(self.session_key) + + if self.kernel_id in self.multi_kernel_manager: + self.multi_kernel_manager.notify_disconnect(self.kernel_id) + self.multi_kernel_manager.remove_restart_callback( + self.kernel_id, + self.on_kernel_restarted, + ) + self.multi_kernel_manager.remove_restart_callback( + self.kernel_id, + self.on_restart_failed, + "dead", + ) + + # start buffering instead of closing if this was the last connection + if self.multi_kernel_manager._kernel_connections[self.kernel_id] == 0: + self.multi_kernel_manager.start_buffering( + self.kernel_id, self.session_key, self.channels + ) + ZMQChannelsWebsocketConnection._open_sockets.remove(self) + self._close_future.set_result(None) + return + + # This method can be called twice, once by self.kernel_died and once + # from the WebSocket close event. If the WebSocket connection is + # closed before the ZMQ streams are setup, they could be None. + for _, stream in self.channels.items(): + if stream is not None and not stream.closed(): + stream.on_recv(None) + stream.close() + + self.channels = {} + try: + ZMQChannelsWebsocketConnection._open_sockets.remove(self) + self._close_future.set_result(None) + except Exception: + pass + + def handle_incoming_message(self, incoming_msg: str) -> None: + """Handle incoming messages from Websocket to ZMQ Sockets.""" + ws_msg = incoming_msg + if not self.channels: + # already closed, ignore the message + self.log.debug("Received message on closed websocket %r", ws_msg) + return + + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + channel, msg_list = deserialize_msg_from_ws_v1(ws_msg) + msg = { + "header": None, + } + else: + if isinstance(ws_msg, bytes): + msg = deserialize_binary_message(ws_msg) + else: + msg = json.loads(ws_msg) + msg_list = [] + channel = msg.pop("channel", None) + + if channel is None: + self.log.warning("No channel specified, assuming shell: %s", msg) + channel = "shell" + if channel not in self.channels: + self.log.warning("No such channel: %r", channel) + return + am = self.multi_kernel_manager.allowed_message_types + ignore_msg = False + if am: + msg["header"] = self.get_part("header", msg["header"], msg_list) + assert msg["header"] is not None + if msg["header"]["msg_type"] not in am: + self.log.warning( + 'Received message of type "%s", which is not allowed. Ignoring.' + % msg["header"]["msg_type"] + ) + ignore_msg = True + if not ignore_msg: + stream = self.channels[channel] + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + self.session.send_raw(stream, msg_list) + else: + self.session.send(stream, msg) + + def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None: + """Handle the outgoing messages from ZMQ sockets to Websocket.""" + msg_list = outgoing_msg + _, fed_msg_list = self.session.feed_identities(msg_list) + + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + msg = {"header": None, "parent_header": None, "content": None} + else: + msg = self.session.deserialize(fed_msg_list) + + channel = getattr(stream, "channel", None) + parts = fed_msg_list[1:] + + self._on_error(channel, msg, parts) + + if self._limit_rate(channel, msg, parts): + return + + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + self._on_zmq_reply(stream, parts) + else: + self._on_zmq_reply(stream, msg) + + def get_part(self, field, value, msg_list): + if value is None: + field2idx = { + "header": 0, + "parent_header": 1, + "content": 3, + } + value = self.session.unpack(msg_list[field2idx[field]]) + return value + + def _reserialize_reply(self, msg_or_list, channel=None): + """Reserialize a reply message using JSON. + + msg_or_list can be an already-deserialized msg dict or the zmq buffer list. + If it is the zmq list, it will be deserialized with self.session. + + This takes the msg list from the ZMQ socket and serializes the result for the websocket. + This method should be used by self._on_zmq_reply to build messages that can + be sent back to the browser. + + """ + if isinstance(msg_or_list, dict): + # already unpacked + msg = msg_or_list + else: + _, msg_list = self.session.feed_identities(msg_or_list) + msg = self.session.deserialize(msg_list) + if channel: + msg["channel"] = channel + if msg["buffers"]: + buf = serialize_binary_message(msg) + return buf + else: + return json.dumps(msg, default=json_default) + + def select_subprotocol(self, subprotocols): + preferred_protocol = self.settings.get("kernel_ws_protocol") + if preferred_protocol is None: + preferred_protocol = "v1.kernel.websocket.jupyter.org" + elif preferred_protocol == "": + preferred_protocol = None + selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None + # None is the default, "legacy" protocol + return selected_subprotocol + + def _on_zmq_reply(self, stream, msg_list): + # Sometimes this gets triggered when the on_close method is scheduled in the + # eventloop but hasn't been called. + if stream.closed(): + self.log.warning("zmq message arrived on closed channel") + self.disconnect() + return + channel = getattr(stream, "channel", None) + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + bin_msg = serialize_msg_to_ws_v1(msg_list, channel) + self.write_message(bin_msg, binary=True) + else: + try: + msg = self._reserialize_reply(msg_list, channel=channel) + except Exception: + self.log.critical("Malformed message: %r" % msg_list, exc_info=True) + else: + try: + self.write_message(msg, binary=isinstance(msg, bytes)) + except WebSocketClosedError as e: + self.log.warning(str(e)) + + def request_kernel_info(self): + """send a request for kernel_info""" + try: + # check for previous request + future = self.kernel_manager._kernel_info_future + except AttributeError: + self.log.debug("Requesting kernel info from %s", self.kernel_id) + # Create a kernel_info channel to query the kernel protocol version. + # This channel will be closed after the kernel_info reply is received. + if self.kernel_info_channel is None: + self.kernel_info_channel = self.multi_kernel_manager.connect_shell(self.kernel_id) + assert self.kernel_info_channel is not None + self.kernel_info_channel.on_recv(self._handle_kernel_info_reply) + self.session.send(self.kernel_info_channel, "kernel_info_request") + # store the future on the kernel, so only one request is sent + self.kernel_manager._kernel_info_future = self._kernel_info_future + else: + if not future.done(): + self.log.debug("Waiting for pending kernel_info request") + future.add_done_callback(lambda f: self._finish_kernel_info(f.result())) + return _ensure_future(self._kernel_info_future) + + def _handle_kernel_info_reply(self, msg): + """process the kernel_info_reply + + enabling msg spec adaptation, if necessary + """ + idents, msg = self.session.feed_identities(msg) + try: + msg = self.session.deserialize(msg) + except BaseException: + self.log.error("Bad kernel_info reply", exc_info=True) + self._kernel_info_future.set_result({}) + return + else: + info = msg["content"] + self.log.debug("Received kernel info: %s", info) + if msg["msg_type"] != "kernel_info_reply" or "protocol_version" not in info: + self.log.error("Kernel info request failed, assuming current %s", info) + info = {} + self._finish_kernel_info(info) + + # close the kernel_info channel, we don't need it anymore + if self.kernel_info_channel: + self.kernel_info_channel.close() + self.kernel_info_channel = None + + def _finish_kernel_info(self, info): + """Finish handling kernel_info reply + + Set up protocol adaptation, if needed, + and signal that connection can continue. + """ + protocol_version = info.get("protocol_version", client_protocol_version) + if protocol_version != client_protocol_version: + self.session.adapt_version = int(protocol_version.split(".")[0]) + self.log.info( + "Adapting from protocol version {protocol_version} (kernel {kernel_id}) to {client_protocol_version} (client).".format( + protocol_version=protocol_version, + kernel_id=self.kernel_id, + client_protocol_version=client_protocol_version, + ) + ) + if not self._kernel_info_future.done(): + self._kernel_info_future.set_result(info) + + def write_stderr(self, error_message, parent_header): + self.log.warning(error_message) + err_msg = self.session.msg( + "stream", + content={"text": error_message + "\n", "name": "stderr"}, + parent=parent_header, + ) + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + bin_msg = serialize_msg_to_ws_v1(err_msg, "iopub", self.session.pack) + self.write_message(bin_msg, binary=True) + else: + err_msg["channel"] = "iopub" + self.write_message(json.dumps(err_msg, default=json_default)) + + def _limit_rate(self, channel, msg, msg_list): + if not (self.limit_rate and channel == "iopub"): + return False + + msg["header"] = self.get_part("header", msg["header"], msg_list) + + msg_type = msg["header"]["msg_type"] + if msg_type == "status": + msg["content"] = self.get_part("content", msg["content"], msg_list) + if msg["content"].get("execution_state") == "idle": + # reset rate limit counter on status=idle, + # to avoid 'Run All' hitting limits prematurely. + self._iopub_window_byte_queue = [] + self._iopub_window_msg_count = 0 + self._iopub_window_byte_count = 0 + self._iopub_msgs_exceeded = False + self._iopub_data_exceeded = False + + if msg_type not in {"status", "comm_open", "execute_input"}: + # Remove the counts queued for removal. + now = IOLoop.current().time() + while len(self._iopub_window_byte_queue) > 0: + queued = self._iopub_window_byte_queue[0] + if now >= queued[0]: + self._iopub_window_byte_count -= queued[1] + self._iopub_window_msg_count -= 1 + del self._iopub_window_byte_queue[0] + else: + # This part of the queue hasn't be reached yet, so we can + # abort the loop. + break + + # Increment the bytes and message count + self._iopub_window_msg_count += 1 + if msg_type == "stream": + byte_count = sum(len(x) for x in msg_list) + else: + byte_count = 0 + self._iopub_window_byte_count += byte_count + + # Queue a removal of the byte and message count for a time in the + # future, when we are no longer interested in it. + self._iopub_window_byte_queue.append((now + self.rate_limit_window, byte_count)) + + # Check the limits, set the limit flags, and reset the + # message and data counts. + msg_rate = float(self._iopub_window_msg_count) / self.rate_limit_window + data_rate = float(self._iopub_window_byte_count) / self.rate_limit_window + + # Check the msg rate + if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit: + if not self._iopub_msgs_exceeded: + self._iopub_msgs_exceeded = True + msg["parent_header"] = self.get_part( + "parent_header", msg["parent_header"], msg_list + ) + self.write_stderr( + dedent( + """\ + IOPub message rate exceeded. + The Jupyter server will temporarily stop sending output + to the client in order to avoid crashing it. + To change this limit, set the config variable + `--ServerApp.iopub_msg_rate_limit`. + + Current values: + ServerApp.iopub_msg_rate_limit={} (msgs/sec) + ServerApp.rate_limit_window={} (secs) + """.format( + self.iopub_msg_rate_limit, self.rate_limit_window + ) + ), + msg["parent_header"], + ) + else: + # resume once we've got some headroom below the limit + if self._iopub_msgs_exceeded and msg_rate < (0.8 * self.iopub_msg_rate_limit): + self._iopub_msgs_exceeded = False + if not self._iopub_data_exceeded: + self.log.warning("iopub messages resumed") + + # Check the data rate + if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit: + if not self._iopub_data_exceeded: + self._iopub_data_exceeded = True + msg["parent_header"] = self.get_part( + "parent_header", msg["parent_header"], msg_list + ) + self.write_stderr( + dedent( + """\ + IOPub data rate exceeded. + The Jupyter server will temporarily stop sending output + to the client in order to avoid crashing it. + To change this limit, set the config variable + `--ServerApp.iopub_data_rate_limit`. + + Current values: + ServerApp.iopub_data_rate_limit={} (bytes/sec) + ServerApp.rate_limit_window={} (secs) + """.format( + self.iopub_data_rate_limit, self.rate_limit_window + ) + ), + msg["parent_header"], + ) + else: + # resume once we've got some headroom below the limit + if self._iopub_data_exceeded and data_rate < (0.8 * self.iopub_data_rate_limit): + self._iopub_data_exceeded = False + if not self._iopub_msgs_exceeded: + self.log.warning("iopub messages resumed") + + # If either of the limit flags are set, do not send the message. + if self._iopub_msgs_exceeded or self._iopub_data_exceeded: + # we didn't send it, remove the current message from the calculus + self._iopub_window_msg_count -= 1 + self._iopub_window_byte_count -= byte_count + self._iopub_window_byte_queue.pop(-1) + return True + + return False + + def _send_status_message(self, status): + iopub = self.channels.get("iopub", None) + if iopub and not iopub.closed(): + # flush IOPub before sending a restarting/dead status message + # ensures proper ordering on the IOPub channel + # that all messages from the stopped kernel have been delivered + iopub.flush() + msg = self.session.msg("status", {"execution_state": status}) + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + bin_msg = serialize_msg_to_ws_v1(msg, "iopub", self.session.pack) + self.write_message(bin_msg, binary=True) + else: + msg["channel"] = "iopub" + self.write_message(json.dumps(msg, default=json_default)) + + def on_kernel_restarted(self): + self.log.warning("kernel %s restarted", self.kernel_id) + self._send_status_message("restarting") + + def on_restart_failed(self): + self.log.error("kernel %s restarted failed!", self.kernel_id) + self._send_status_message("dead") + + def _on_error(self, channel, msg, msg_list): + if self.kernel_manager.allow_tracebacks: + return + + if channel == "iopub": + msg["header"] = self.get_part("header", msg["header"], msg_list) + if msg["header"]["msg_type"] == "error": + msg["content"] = self.get_part("content", msg["content"], msg_list) + msg["content"]["ename"] = "ExecutionError" + msg["content"]["evalue"] = "Execution error" + msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message] + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + msg_list[3] = self.session.pack(msg["content"]) + + +KernelWebsocketConnectionABC.register(ZMQChannelsWebsocketConnection) diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index 32c2e80eee..ef2f2e0c74 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -4,35 +4,20 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -import asyncio import json -import weakref -from textwrap import dedent from traceback import format_tb -from typing import MutableSet - -from jupyter_client import protocol_version as client_protocol_version try: from jupyter_client.jsonutil import json_default except ImportError: from jupyter_client.jsonutil import date_default as json_default -from concurrent.futures import Future - -from tornado import gen, web -from tornado.ioloop import IOLoop +from tornado import web from jupyter_server.auth import authorized from jupyter_server.utils import ensure_async, url_escape, url_path_join from ...base.handlers import APIHandler -from ...base.zmqhandlers import ( - AuthenticatedZMQStreamHandler, - deserialize_binary_message, - deserialize_msg_from_ws_v1, - serialize_msg_to_ws_v1, -) AUTH_RESOURCE = "kernels" @@ -110,707 +95,9 @@ async def post(self, kernel_id, action): self.finish() -def _ensure_future(f): - """Wrap a concurrent future as an asyncio future if there is a running loop.""" - try: - asyncio.get_running_loop() - return asyncio.wrap_future(f) - except RuntimeError: - return f - - -class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): - """There is one ZMQChannelsHandler per running kernel and it oversees all - the sessions. - """ - - auth_resource = AUTH_RESOURCE - - # class-level registry of open sessions - # allows checking for conflict on session-id, - # which is used as a zmq identity and must be unique. - _open_sessions: dict = {} - _open_sockets: MutableSet["ZMQChannelsHandler"] = weakref.WeakSet() - - _kernel_info_future: Future - _close_future: Future - - @classmethod - async def close_all(cls): - """Tornado does not provide a way to close open sockets, so add one.""" - for socket in list(cls._open_sockets): - await socket.close() - - @property - def kernel_info_timeout(self): - km_default = self.kernel_manager.kernel_info_timeout - return self.settings.get("kernel_info_timeout", km_default) - - @property - def limit_rate(self): - return self.settings.get("limit_rate", True) - - @property - def iopub_msg_rate_limit(self): - return self.settings.get("iopub_msg_rate_limit", 0) - - @property - def iopub_data_rate_limit(self): - return self.settings.get("iopub_data_rate_limit", 0) - - @property - def rate_limit_window(self): - return self.settings.get("rate_limit_window", 1.0) - - @property - def subprotocol(self): - try: - protocol = self.selected_subprotocol - except Exception: - protocol = None - return protocol - - def __repr__(self): - return "{}({})".format( - self.__class__.__name__, - getattr(self, "kernel_id", "uninitialized"), - ) - - def create_stream(self): - km = self.kernel_manager - identity = self.session.bsession - for channel in ("iopub", "shell", "control", "stdin"): - meth = getattr(km, "connect_" + channel) - self.channels[channel] = stream = meth(self.kernel_id, identity=identity) - stream.channel = channel - - def nudge(self): - """Nudge the zmq connections with kernel_info_requests - Returns a Future that will resolve when we have received - a shell or control reply and at least one iopub message, - ensuring that zmq subscriptions are established, - sockets are fully connected, and kernel is responsive. - Keeps retrying kernel_info_request until these are both received. - """ - kernel = self.kernel_manager.get_kernel(self.kernel_id) - - # Do not nudge busy kernels as kernel info requests sent to shell are - # queued behind execution requests. - # nudging in this case would cause a potentially very long wait - # before connections are opened, - # plus it is *very* unlikely that a busy kernel will not finish - # establishing its zmq subscriptions before processing the next request. - if getattr(kernel, "execution_state", None) == "busy": - self.log.debug("Nudge: not nudging busy kernel %s", self.kernel_id) - f: Future = Future() - f.set_result(None) - return _ensure_future(f) - # Use a transient shell channel to prevent leaking - # shell responses to the front-end. - shell_channel = kernel.connect_shell() - # Use a transient control channel to prevent leaking - # control responses to the front-end. - control_channel = kernel.connect_control() - # The IOPub used by the client, whose subscriptions we are verifying. - iopub_channel = self.channels["iopub"] - - info_future: Future = Future() - iopub_future: Future = Future() - both_done = gen.multi([info_future, iopub_future]) - - def finish(_=None): - """Ensure all futures are resolved - which in turn triggers cleanup - """ - for f in (info_future, iopub_future): - if not f.done(): - f.set_result(None) - - def cleanup(_=None): - """Common cleanup""" - loop.remove_timeout(nudge_handle) # type:ignore[has-type] - iopub_channel.stop_on_recv() - if not shell_channel.closed(): - shell_channel.close() - if not control_channel.closed(): - control_channel.close() - - # trigger cleanup when both message futures are resolved - both_done.add_done_callback(cleanup) - - def on_shell_reply(msg): - self.log.debug("Nudge: shell info reply received: %s", self.kernel_id) - if not info_future.done(): - self.log.debug("Nudge: resolving shell future: %s", self.kernel_id) - info_future.set_result(None) - - def on_control_reply(msg): - self.log.debug("Nudge: control info reply received: %s", self.kernel_id) - if not info_future.done(): - self.log.debug("Nudge: resolving control future: %s", self.kernel_id) - info_future.set_result(None) - - def on_iopub(msg): - self.log.debug("Nudge: IOPub received: %s", self.kernel_id) - if not iopub_future.done(): - iopub_channel.stop_on_recv() - self.log.debug("Nudge: resolving iopub future: %s", self.kernel_id) - iopub_future.set_result(None) - - iopub_channel.on_recv(on_iopub) - shell_channel.on_recv(on_shell_reply) - control_channel.on_recv(on_control_reply) - loop = IOLoop.current() - - # Nudge the kernel with kernel info requests until we get an IOPub message - def nudge(count): - count += 1 - - # NOTE: this close check appears to never be True during on_open, - # even when the peer has closed the connection - if self.ws_connection is None or self.ws_connection.is_closing(): - self.log.debug("Nudge: cancelling on closed websocket: %s", self.kernel_id) - finish() - return - - # check for stopped kernel - if self.kernel_id not in self.kernel_manager: - self.log.debug("Nudge: cancelling on stopped kernel: %s", self.kernel_id) - finish() - return - - # check for closed zmq socket - if shell_channel.closed(): - self.log.debug("Nudge: cancelling on closed zmq socket: %s", self.kernel_id) - finish() - return - - # check for closed zmq socket - if control_channel.closed(): - self.log.debug("Nudge: cancelling on closed zmq socket: %s", self.kernel_id) - finish() - return - - if not both_done.done(): - log = self.log.warning if count % 10 == 0 else self.log.debug - log(f"Nudge: attempt {count} on kernel {self.kernel_id}") - self.session.send(shell_channel, "kernel_info_request") - self.session.send(control_channel, "kernel_info_request") - nonlocal nudge_handle # type:ignore[misc] - nudge_handle = loop.call_later(0.5, nudge, count) - - nudge_handle = loop.call_later(0, nudge, count=0) - - # resolve with a timeout if we get no response - future = gen.with_timeout(loop.time() + self.kernel_info_timeout, both_done) - # ensure we have no dangling resources or unresolved Futures in case of timeout - future.add_done_callback(finish) - return _ensure_future(future) - - def request_kernel_info(self): - """send a request for kernel_info""" - km = self.kernel_manager - kernel = km.get_kernel(self.kernel_id) - try: - # check for previous request - future = kernel._kernel_info_future - except AttributeError: - self.log.debug("Requesting kernel info from %s", self.kernel_id) - # Create a kernel_info channel to query the kernel protocol version. - # This channel will be closed after the kernel_info reply is received. - if self.kernel_info_channel is None: # type:ignore[has-type] - self.kernel_info_channel = km.connect_shell(self.kernel_id) - assert self.kernel_info_channel is not None - self.kernel_info_channel.on_recv(self._handle_kernel_info_reply) - self.session.send(self.kernel_info_channel, "kernel_info_request") - # store the future on the kernel, so only one request is sent - kernel._kernel_info_future = self._kernel_info_future - else: - if not future.done(): - self.log.debug("Waiting for pending kernel_info request") - future.add_done_callback(lambda f: self._finish_kernel_info(f.result())) - return _ensure_future(self._kernel_info_future) - - def _handle_kernel_info_reply(self, msg): - """process the kernel_info_reply - - enabling msg spec adaptation, if necessary - """ - idents, msg = self.session.feed_identities(msg) - try: - msg = self.session.deserialize(msg) - except BaseException: - self.log.error("Bad kernel_info reply", exc_info=True) - self._kernel_info_future.set_result({}) - return - else: - info = msg["content"] - self.log.debug("Received kernel info: %s", info) - if msg["msg_type"] != "kernel_info_reply" or "protocol_version" not in info: - self.log.error("Kernel info request failed, assuming current %s", info) - info = {} - self._finish_kernel_info(info) - - # close the kernel_info channel, we don't need it anymore - if self.kernel_info_channel: - self.kernel_info_channel.close() - self.kernel_info_channel = None - - def _finish_kernel_info(self, info): - """Finish handling kernel_info reply - - Set up protocol adaptation, if needed, - and signal that connection can continue. - """ - protocol_version = info.get("protocol_version", client_protocol_version) - if protocol_version != client_protocol_version: - self.session.adapt_version = int(protocol_version.split(".")[0]) - self.log.info( - "Adapting from protocol version {protocol_version} (kernel {kernel_id}) to {client_protocol_version} (client).".format( - protocol_version=protocol_version, - kernel_id=self.kernel_id, - client_protocol_version=client_protocol_version, - ) - ) - if not self._kernel_info_future.done(): - self._kernel_info_future.set_result(info) - - def initialize(self): - super().initialize() - self.zmq_stream = None - self.channels = {} - self.kernel_id = None - self.kernel_info_channel = None - self._kernel_info_future = Future() - self._close_future = Future() - self.session_key = "" - - # Rate limiting code - self._iopub_window_msg_count = 0 - self._iopub_window_byte_count = 0 - self._iopub_msgs_exceeded = False - self._iopub_data_exceeded = False - # Queue of (time stamp, byte count) - # Allows you to specify that the byte count should be lowered - # by a delta amount at some point in the future. - self._iopub_window_byte_queue = [] - - async def pre_get(self): - # authenticate first - super().pre_get() - # check session collision: - await self._register_session() - # then request kernel info, waiting up to a certain time before giving up. - # We don't want to wait forever, because browsers don't take it well when - # servers never respond to websocket connection requests. - kernel = self.kernel_manager.get_kernel(self.kernel_id) - - if hasattr(kernel, "ready"): - ready = kernel.ready - if not isinstance(ready, asyncio.Future): - ready = asyncio.wrap_future(ready) - try: - await ready - except Exception as e: - kernel.execution_state = "dead" - kernel.reason = str(e) - raise web.HTTPError(500, str(e)) from e - - self.session.key = kernel.session.key - future = self.request_kernel_info() - - def give_up(): - """Don't wait forever for the kernel to reply""" - if future.done(): - return - self.log.warning("Timeout waiting for kernel_info reply from %s", self.kernel_id) - future.set_result({}) - - loop = IOLoop.current() - loop.add_timeout(loop.time() + self.kernel_info_timeout, give_up) - # actually wait for it - await asyncio.wrap_future(future) - - async def get(self, kernel_id): - self.kernel_id = kernel_id - await super().get(kernel_id=kernel_id) - - async def _register_session(self): - """Ensure we aren't creating a duplicate session. - - If a previous identical session is still open, close it to avoid collisions. - This is likely due to a client reconnecting from a lost network connection, - where the socket on our side has not been cleaned up yet. - """ - self.session_key = f"{self.kernel_id}:{self.session.session}" - stale_handler = self._open_sessions.get(self.session_key) - if stale_handler: - self.log.warning("Replacing stale connection: %s", self.session_key) - await stale_handler.close() - if ( - self.kernel_id in self.kernel_manager - ): # only update open sessions if kernel is actively managed - self._open_sessions[self.session_key] = self - - def open(self, kernel_id): - super().open() - km = self.kernel_manager - km.notify_connect(kernel_id) - - # on new connections, flush the message buffer - buffer_info = km.get_buffer(kernel_id, self.session_key) - if buffer_info and buffer_info["session_key"] == self.session_key: - self.log.info("Restoring connection for %s", self.session_key) - if km.ports_changed(kernel_id): - # If the kernel's ports have changed (some restarts trigger this) - # then reset the channels so nudge() is using the correct iopub channel - self.create_stream() - else: - # The kernel's ports have not changed; use the channels captured in the buffer - self.channels = buffer_info["channels"] - - connected = self.nudge() - - def replay(value): - replay_buffer = buffer_info["buffer"] - if replay_buffer: - self.log.info("Replaying %s buffered messages", len(replay_buffer)) - for channel, msg_list in replay_buffer: - stream = self.channels[channel] - self._on_zmq_reply(stream, msg_list) - - connected.add_done_callback(replay) - else: - try: - self.create_stream() - connected = self.nudge() - except web.HTTPError as e: - # Do not log error if the kernel is already shutdown, - # as it's normal that it's not responding - try: - self.kernel_manager.get_kernel(kernel_id) - - self.log.error("Error opening stream: %s", e) - except KeyError: - pass - # WebSockets don't respond to traditional error codes so we - # close the connection. - for _, stream in self.channels.items(): - if not stream.closed(): - stream.close() - self.close() - return - - km.add_restart_callback(self.kernel_id, self.on_kernel_restarted) - km.add_restart_callback(self.kernel_id, self.on_restart_failed, "dead") - - def subscribe(value): - for _, stream in self.channels.items(): - stream.on_recv_stream(self._on_zmq_reply) - - connected.add_done_callback(subscribe) - ZMQChannelsHandler._open_sockets.add(self) - return connected - - def on_message(self, ws_msg): - if not self.channels: - # already closed, ignore the message - self.log.debug("Received message on closed websocket %r", ws_msg) - return - - if self.subprotocol == "v1.kernel.websocket.jupyter.org": - channel, msg_list = deserialize_msg_from_ws_v1(ws_msg) - msg = { - "header": None, - } - else: - if isinstance(ws_msg, bytes): - msg = deserialize_binary_message(ws_msg) - else: - msg = json.loads(ws_msg) - msg_list = [] - channel = msg.pop("channel", None) - - if channel is None: - self.log.warning("No channel specified, assuming shell: %s", msg) - channel = "shell" - if channel not in self.channels: - self.log.warning("No such channel: %r", channel) - return - am = self.kernel_manager.allowed_message_types - ignore_msg = False - if am: - msg["header"] = self.get_part("header", msg["header"], msg_list) - assert msg["header"] is not None - if msg["header"]["msg_type"] not in am: - self.log.warning( - 'Received message of type "%s", which is not allowed. Ignoring.' - % msg["header"]["msg_type"] - ) - ignore_msg = True - if not ignore_msg: - stream = self.channels[channel] - if self.subprotocol == "v1.kernel.websocket.jupyter.org": - self.session.send_raw(stream, msg_list) - else: - self.session.send(stream, msg) - - def get_part(self, field, value, msg_list): - if value is None: - field2idx = { - "header": 0, - "parent_header": 1, - "content": 3, - } - value = self.session.unpack(msg_list[field2idx[field]]) - return value - - def _on_zmq_reply(self, stream, msg_list): - idents, fed_msg_list = self.session.feed_identities(msg_list) - - if self.subprotocol == "v1.kernel.websocket.jupyter.org": - msg = {"header": None, "parent_header": None, "content": None} - else: - msg = self.session.deserialize(fed_msg_list) - - channel = getattr(stream, "channel", None) - parts = fed_msg_list[1:] - - self._on_error(channel, msg, parts) - - if self._limit_rate(channel, msg, parts): - return - - if self.subprotocol == "v1.kernel.websocket.jupyter.org": - super()._on_zmq_reply(stream, parts) - else: - super()._on_zmq_reply(stream, msg) - - def write_stderr(self, error_message, parent_header): - self.log.warning(error_message) - err_msg = self.session.msg( - "stream", - content={"text": error_message + "\n", "name": "stderr"}, - parent=parent_header, - ) - if self.subprotocol == "v1.kernel.websocket.jupyter.org": - bin_msg = serialize_msg_to_ws_v1(err_msg, "iopub", self.session.pack) - self.write_message(bin_msg, binary=True) - else: - err_msg["channel"] = "iopub" - self.write_message(json.dumps(err_msg, default=json_default)) - - def _limit_rate(self, channel, msg, msg_list): - if not (self.limit_rate and channel == "iopub"): - return False - - msg["header"] = self.get_part("header", msg["header"], msg_list) - - msg_type = msg["header"]["msg_type"] - if msg_type == "status": - msg["content"] = self.get_part("content", msg["content"], msg_list) - if msg["content"].get("execution_state") == "idle": - # reset rate limit counter on status=idle, - # to avoid 'Run All' hitting limits prematurely. - self._iopub_window_byte_queue = [] - self._iopub_window_msg_count = 0 - self._iopub_window_byte_count = 0 - self._iopub_msgs_exceeded = False - self._iopub_data_exceeded = False - - if msg_type not in {"status", "comm_open", "execute_input"}: - # Remove the counts queued for removal. - now = IOLoop.current().time() - while len(self._iopub_window_byte_queue) > 0: - queued = self._iopub_window_byte_queue[0] - if now >= queued[0]: - self._iopub_window_byte_count -= queued[1] - self._iopub_window_msg_count -= 1 - del self._iopub_window_byte_queue[0] - else: - # This part of the queue hasn't be reached yet, so we can - # abort the loop. - break - - # Increment the bytes and message count - self._iopub_window_msg_count += 1 - if msg_type == "stream": - byte_count = sum(len(x) for x in msg_list) - else: - byte_count = 0 - self._iopub_window_byte_count += byte_count - - # Queue a removal of the byte and message count for a time in the - # future, when we are no longer interested in it. - self._iopub_window_byte_queue.append((now + self.rate_limit_window, byte_count)) - - # Check the limits, set the limit flags, and reset the - # message and data counts. - msg_rate = float(self._iopub_window_msg_count) / self.rate_limit_window - data_rate = float(self._iopub_window_byte_count) / self.rate_limit_window - - # Check the msg rate - if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit: - if not self._iopub_msgs_exceeded: - self._iopub_msgs_exceeded = True - msg["parent_header"] = self.get_part( - "parent_header", msg["parent_header"], msg_list - ) - self.write_stderr( - dedent( - """\ - IOPub message rate exceeded. - The Jupyter server will temporarily stop sending output - to the client in order to avoid crashing it. - To change this limit, set the config variable - `--ServerApp.iopub_msg_rate_limit`. - - Current values: - ServerApp.iopub_msg_rate_limit={} (msgs/sec) - ServerApp.rate_limit_window={} (secs) - """.format( - self.iopub_msg_rate_limit, self.rate_limit_window - ) - ), - msg["parent_header"], - ) - else: - # resume once we've got some headroom below the limit - if self._iopub_msgs_exceeded and msg_rate < (0.8 * self.iopub_msg_rate_limit): - self._iopub_msgs_exceeded = False - if not self._iopub_data_exceeded: - self.log.warning("iopub messages resumed") - - # Check the data rate - if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit: - if not self._iopub_data_exceeded: - self._iopub_data_exceeded = True - msg["parent_header"] = self.get_part( - "parent_header", msg["parent_header"], msg_list - ) - self.write_stderr( - dedent( - """\ - IOPub data rate exceeded. - The Jupyter server will temporarily stop sending output - to the client in order to avoid crashing it. - To change this limit, set the config variable - `--ServerApp.iopub_data_rate_limit`. - - Current values: - ServerApp.iopub_data_rate_limit={} (bytes/sec) - ServerApp.rate_limit_window={} (secs) - """.format( - self.iopub_data_rate_limit, self.rate_limit_window - ) - ), - msg["parent_header"], - ) - else: - # resume once we've got some headroom below the limit - if self._iopub_data_exceeded and data_rate < (0.8 * self.iopub_data_rate_limit): - self._iopub_data_exceeded = False - if not self._iopub_msgs_exceeded: - self.log.warning("iopub messages resumed") - - # If either of the limit flags are set, do not send the message. - if self._iopub_msgs_exceeded or self._iopub_data_exceeded: - # we didn't send it, remove the current message from the calculus - self._iopub_window_msg_count -= 1 - self._iopub_window_byte_count -= byte_count - self._iopub_window_byte_queue.pop(-1) - return True - - return False - - def close(self): - super().close() - return _ensure_future(self._close_future) - - def on_close(self): - self.log.debug("Websocket closed %s", self.session_key) - # unregister myself as an open session (only if it's really me) - if self._open_sessions.get(self.session_key) is self: - self._open_sessions.pop(self.session_key) - - km = self.kernel_manager - if self.kernel_id in km: - km.notify_disconnect(self.kernel_id) - km.remove_restart_callback( - self.kernel_id, - self.on_kernel_restarted, - ) - km.remove_restart_callback( - self.kernel_id, - self.on_restart_failed, - "dead", - ) - - # start buffering instead of closing if this was the last connection - if km._kernel_connections[self.kernel_id] == 0: - km.start_buffering(self.kernel_id, self.session_key, self.channels) - ZMQChannelsHandler._open_sockets.remove(self) - self._close_future.set_result(None) - return - - # This method can be called twice, once by self.kernel_died and once - # from the WebSocket close event. If the WebSocket connection is - # closed before the ZMQ streams are setup, they could be None. - for _, stream in self.channels.items(): - if stream is not None and not stream.closed(): - stream.on_recv(None) - stream.close() - - self.channels = {} - try: - ZMQChannelsHandler._open_sockets.remove(self) - self._close_future.set_result(None) - except Exception: - pass - - def _send_status_message(self, status): - iopub = self.channels.get("iopub", None) - if iopub and not iopub.closed(): - # flush IOPub before sending a restarting/dead status message - # ensures proper ordering on the IOPub channel - # that all messages from the stopped kernel have been delivered - iopub.flush() - msg = self.session.msg("status", {"execution_state": status}) - if self.subprotocol == "v1.kernel.websocket.jupyter.org": - bin_msg = serialize_msg_to_ws_v1(msg, "iopub", self.session.pack) - self.write_message(bin_msg, binary=True) - else: - msg["channel"] = "iopub" - self.write_message(json.dumps(msg, default=json_default)) - - def on_kernel_restarted(self): - self.log.warning("kernel %s restarted", self.kernel_id) - self._send_status_message("restarting") - - def on_restart_failed(self): - self.log.error("kernel %s restarted failed!", self.kernel_id) - self._send_status_message("dead") - - def _on_error(self, channel, msg, msg_list): - if self.kernel_manager.allow_tracebacks: - return - - if channel == "iopub": - msg["header"] = self.get_part("header", msg["header"], msg_list) - if msg["header"]["msg_type"] == "error": - msg["content"] = self.get_part("content", msg["content"], msg_list) - msg["content"]["ename"] = "ExecutionError" - msg["content"]["evalue"] = "Execution error" - msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message] - if self.subprotocol == "v1.kernel.websocket.jupyter.org": - msg_list[3] = self.session.pack(msg["content"]) - - # ----------------------------------------------------------------------------- # URL to handler mappings # ----------------------------------------------------------------------------- - - _kernel_id_regex = r"(?P\w+-\w+-\w+-\w+-\w+)" _kernel_action_regex = r"(?Prestart|interrupt)" @@ -821,5 +108,4 @@ def _on_error(self, channel, msg, msg_list): rf"/api/kernels/{_kernel_id_regex}/{_kernel_action_regex}", KernelActionHandler, ), - (r"/api/kernels/%s/channels" % _kernel_id_regex, ZMQChannelsHandler), ] diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py new file mode 100644 index 0000000000..2806053a98 --- /dev/null +++ b/jupyter_server/services/kernels/websocket.py @@ -0,0 +1,81 @@ +"""Tornado handlers for WebSocket <-> ZMQ sockets.""" +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +from tornado import web +from tornado.websocket import WebSocketHandler + +from jupyter_server.base.handlers import JupyterHandler +from jupyter_server.base.websocket import WebSocketMixin + +from .handlers import _kernel_id_regex + +AUTH_RESOURCE = "kernels" + + +class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler): + """The kernels websocket should connecte""" + + auth_resource = AUTH_RESOURCE + + @property + def kernel_websocket_connection_class(self): + return self.settings.get("kernel_websocket_connection_class") + + def set_default_headers(self): + """Undo the set_default_headers in JupyterHandler + + which doesn't make sense for websockets + """ + pass + + def get_compression_options(self): + return self.settings.get("websocket_compression_options", None) + + async def pre_get(self): + # authenticate first + user = self.current_user + if user is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + # authorize the user. + if not self.authorizer.is_authorized(self, user, "execute", "kernels"): + raise web.HTTPError(403) + + kernel = self.kernel_manager.get_kernel(self.kernel_id) + self.connection = self.kernel_websocket_connection_class( + parent=kernel, websocket_handler=self, config=self.config + ) + + if self.get_argument("session_id", None): + self.connection.session.session = self.get_argument("session_id") + else: + self.log.warning("No session ID specified") + # For backwards compatibility with older versions + # of the websocket connection, call a prepare method if found. + if hasattr(self.connection, "prepare"): + await self.connection.prepare() + + async def get(self, kernel_id): + self.kernel_id = kernel_id + await self.pre_get() + await super().get(kernel_id=kernel_id) + + async def open(self, kernel_id): + # Wait for the kernel to emit an idle status. + self.log.info(f"Connecting to kernel {self.kernel_id}.") + await self.connection.connect() + + def on_message(self, ws_message): + """Get a kernel message from the websocket and turn it into a ZMQ message.""" + self.connection.handle_incoming_message(ws_message) + + def on_close(self): + self.connection.disconnect() + self.connection = None + + +default_handlers = [ + (r"/api/kernels/%s/channels" % _kernel_id_regex, KernelWebsocketHandler), +] diff --git a/pyproject.toml b/pyproject.toml index 4949e8569a..6716b42894 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,6 +143,7 @@ filterwarnings = [ "ignore:run_pre_save_hook is deprecated:DeprecationWarning", "always:unclosed