diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 90dd80eeb3..fa81b16555 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -110,7 +110,6 @@ from jupyter_server.services.kernels.connection.channels import ( ZMQChannelsWebsocketConnection, ) -from jupyter_server.services.kernels.kernel_broker import KernelWebsocketBroker from jupyter_server.services.kernels.kernelmanager import ( AsyncMappingKernelManager, MappingKernelManager, diff --git a/jupyter_server/services/kernels/connection/abc.py b/jupyter_server/services/kernels/connection/abc.py index 8fc4706b18..f918466c90 100644 --- a/jupyter_server/services/kernels/connection/abc.py +++ b/jupyter_server/services/kernels/connection/abc.py @@ -14,16 +14,20 @@ class KernelWebsocketConnectionABC(ABC): @abstractmethod async def connect(self): + """Connect the kernel websocket to the kernel ZMQ connections""" ... @abstractmethod async def disconnect(self): + """Disconnect the kernel websocket from the kernel ZMQ connections""" ... @abstractmethod def handle_incoming_message(self, incoming_msg: str) -> None: + """Broker the incoming websocket message to the appropriate ZMQ channel.""" ... @abstractmethod def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None: + """Broker outgoing ZMQ messages to the kernel websocket.""" ... diff --git a/jupyter_server/services/kernels/connection/base.py b/jupyter_server/services/kernels/connection/base.py index 84cbdb5f37..12b062e788 100644 --- a/jupyter_server/services/kernels/connection/base.py +++ b/jupyter_server/services/kernels/connection/base.py @@ -3,7 +3,7 @@ import sys from jupyter_client.session import Session -from traitlets import Float, Instance, default +from traitlets import Callable, Float, Instance, default from traitlets.config import LoggingConfigurable try: @@ -135,6 +135,8 @@ def _default_kernel_info_timeout(self): def _default_session(self): return Session(config=self.config) + write_message = Callable() + async def connect(self): raise NotImplementedError() diff --git a/jupyter_server/services/kernels/connection/channels.py b/jupyter_server/services/kernels/connection/channels.py index 0cedb1becf..a1233477de 100644 --- a/jupyter_server/services/kernels/connection/channels.py +++ b/jupyter_server/services/kernels/connection/channels.py @@ -9,7 +9,7 @@ from tornado import gen, web from tornado.ioloop import IOLoop from tornado.websocket import WebSocketClosedError -from traitlets import Bool, Dict, Float, Instance, Int, List, Unicode, default +from traitlets import Any, Bool, Dict, Float, Instance, Int, List, Unicode, default try: from jupyter_client.jsonutil import json_default @@ -89,7 +89,7 @@ class ZMQChannelsWebsocketConnection(BaseKernelWebsocketConnection): _close_future: Future channels = Dict({}) - kernel_info_channel = Unicode(allow_none=True) + kernel_info_channel = Any(allow_none=True) _kernel_info_future = Instance(klass=Future) @@ -132,7 +132,7 @@ def create_stream(self): identity = self.session.bsession for channel in ("iopub", "shell", "control", "stdin"): meth = getattr(self.kernel_manager, "connect_" + channel) - self.channels[channel] = stream = meth(self.kernel_id, identity=identity) + self.channels[channel] = stream = meth(identity=identity) stream.channel = channel def nudge(self): @@ -214,14 +214,6 @@ def on_iopub(msg): # Nudge the kernel with kernel info requests until we get an IOPub message def nudge(count): count += 1 - - # NOTE: this close check appears to never be True during on_open, - # even when the peer has closed the connection - if self.ws_connection is None or self.ws_connection.is_closing(): - self.log.debug("Nudge: cancelling on closed websocket: %s", self.kernel_id) - finish() - return - # check for stopped kernel if self.kernel_id not in self.multi_kernel_manager: self.log.debug("Nudge: cancelling on stopped kernel: %s", self.kernel_id) @@ -274,8 +266,6 @@ async def _register_session(self): self._open_sessions[self.session_key] = self async def prepare(self): - # authenticate first - super().pre_get() # check session collision: await self._register_session() # then request kernel info, waiting up to a certain time before giving up. @@ -525,7 +515,7 @@ def select_subprotocol(self, subprotocols): def _on_zmq_reply(self, stream, msg_list): # Sometimes this gets triggered when the on_close method is scheduled in the # eventloop but hasn't been called. - if self.ws_connection is None or stream.closed(): + if stream.closed(): self.log.warning("zmq message arrived on closed channel") self.close() return @@ -554,7 +544,7 @@ def request_kernel_info(self): # Create a kernel_info channel to query the kernel protocol version. # This channel will be closed after the kernel_info reply is received. if self.kernel_info_channel is None: - self.kernel_info_channel = self.kernel_manager.connect_shell(self.kernel_id) + self.kernel_info_channel = self.multi_kernel_manager.connect_shell(self.kernel_id) assert self.kernel_info_channel is not None self.kernel_info_channel.on_recv(self._handle_kernel_info_reply) self.session.send(self.kernel_info_channel, "kernel_info_request") diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 98497e4918..70ac886ea8 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -187,18 +187,19 @@ async def pre_get(self): else: self.log.warning("No session ID specified") # For backwards compatibility with older versions - # of the message broker, call a prepare method if found. + # of the websocket connection, call a prepare method if found. if hasattr(self.connection, "prepare"): await self.connection.prepare() async def get(self, kernel_id): self.kernel_id = kernel_id + await self.pre_get() await super().get(kernel_id=kernel_id) async def open(self, kernel_id): # Wait for the kernel to emit an idle status. self.log.info(f"Connecting to kernel {self.kernel_id}.") - await self.connection.connect(session_id=self.session_id) + await self.connection.connect() def on_message(self, ws_message): """Get a kernel message from the websocket and turn it into a ZMQ message."""