Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support for configure magic on Spark Python Kubernetes Kernels (WIP) #1105

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
177 changes: 176 additions & 1 deletion enterprise_gateway/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@
"""Tornado handlers for kernel CRUD and communication."""
import json
import os
from datetime import datetime, timezone
from functools import partial

import jupyter_server.services.kernels.handlers as jupyter_server_handlers
import tornado
from jupyter_client.jsonutil import date_default
from jupyter_server.base.handlers import APIHandler
from tornado import web

try:
from jupyter_client.jsonutil import json_default
except ImportError:
from jupyter_client.jsonutil import date_default as json_default

from ...mixins import CORSMixin, JSONErrorsMixin, TokenAuthorizationMixin


Expand Down Expand Up @@ -146,11 +153,179 @@ def get(self, kernel_id):
self.finish(json.dumps(model, default=date_default))


default_handlers = []
class ConfigureMagicHandler(CORSMixin, JSONErrorsMixin, APIHandler):
    """Handle requests to re-configure (refresh) an existing kernel.

    A POST carrying a JSON object with an ``env`` field stores the payload on
    the kernel via ``set_user_extra_overrides``, restarts the kernel and then
    fires the ``kernel_refresh`` event callbacks. If any of those steps fail,
    the kernel is shut down and the ``kernel_refresh_failure`` callbacks are
    fired instead.
    """

    @web.authenticated
    async def post(self, kernel_id):
        """Re-configure the kernel identified by ``kernel_id``.

        Args:
            kernel_id: Id of the kernel to re-configure; validated through the
                kernel manager's ``check_kernel_id``.

        Raises:
            web.HTTPError: 400 when the payload is not a JSON object, when the
                ``env`` field is missing, or when a restart of the kernel is
                already in progress; 500 when an unexpected error occurs while
                re-configuring. An ``HTTPError`` raised during the
                re-configuration is re-raised unchanged.
        """
        self.log.info(f"Update request received for kernel: {kernel_id}")
        km = self.kernel_manager
        km.check_kernel_id(kernel_id)
        payload = self.get_json_body()
        self.log.debug(f"Request payload: {payload}")
        if payload is None:
            # An empty body is treated as a no-op rather than an error.
            self.finish(
                json.dumps(
                    {
                        "message": f"Empty payload received. No operation performed on kernel: {kernel_id}"
                    },
                    default=date_default,
                )
            )
            return
        if not isinstance(payload, dict):
            raise web.HTTPError(400, f"Invalid JSON payload received for kernel: {kernel_id}.")
        if payload.get("env") is None:  # We only allow env field for now.
            raise web.HTTPError(
                400, f"Missing required field `env` in payload for kernel: {kernel_id}."
            )
        kernel = km.get_kernel(kernel_id)
        if kernel.restarting:  # handle duplicate request.
            self.log.info(
                "An existing restart request is still in progress. Skipping this request."
            )
            raise web.HTTPError(
                400, f"Duplicate configure kernel request received for kernel: {kernel_id}."
            )

        # Messages the websocket handlers relay back to the client once the
        # refresh has completed (or failed).
        zmq_messages = payload.get("zmq_messages", {})
        try:
            # update Kernel metadata
            kernel.set_user_extra_overrides(payload)
            await km.restart_kernel(kernel_id)
            kernel.fire_kernel_event_callbacks(event="kernel_refresh", zmq_messages=zmq_messages)
        except web.HTTPError as he:
            self.log.exception(
                f"HTTPError exception occurred while re-configuring kernel: {kernel_id}: {he}"
            )
            await self._abort_configure(kernel, kernel_id, zmq_messages)
            raise
        except Exception as e:
            self.log.exception(
                f"An exception occurred while re-configuring kernel: {kernel_id}: {e}"
            )
            await self._abort_configure(kernel, kernel_id, zmq_messages)
            raise web.HTTPError(
                500,
                f"Error occurred while re-configuring kernel: {kernel_id}",
                reason=f"{e}",
            ) from e

        response_body = {"message": f"Successfully re-configured kernel: {kernel_id}."}
        self.finish(json.dumps(response_body, default=date_default))

    async def _abort_configure(self, kernel, kernel_id, zmq_messages):
        """Shut the kernel down and fire the ``kernel_refresh_failure`` callbacks."""
        await self.kernel_manager.shutdown_kernel(kernel_id)
        kernel.fire_kernel_event_callbacks(
            event="kernel_refresh_failure", zmq_messages=zmq_messages
        )


class RemoteZMQChannelsHandler(
    TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, jupyter_server_handlers.ZMQChannelsHandler
):
    """Websocket channels handler that reacts to kernel refresh events.

    When the websocket opens, callbacks for the ``kernel_refresh`` and
    ``kernel_refresh_failure`` events are registered with the kernel manager;
    they are removed again when the websocket closes. The ``zmq_messages``
    payload handed to those callbacks is treated as read-only: every message
    that needs altering is copied first, because the same payload object is
    passed to each registered callback.
    """

    def open(self, kernel_id):
        """Open the websocket and register the refresh callbacks for ``kernel_id``."""
        self.log.debug(f"Websocket open request received for kernel: {kernel_id}")
        super().open(kernel_id)
        km = self.kernel_manager
        km.add_kernel_event_callbacks(kernel_id, self.on_kernel_refresh, "kernel_refresh")
        km.add_kernel_event_callbacks(
            kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
        )

    def on_kernel_refresh(self, **kwargs):
        """Re-create the ZMQ streams and relay the success messages to the client.

        Args:
            **kwargs: May carry ``zmq_messages``, a mapping that can contain
                ``stream_reply``, ``exec_reply`` and ``idle_reply`` messages.
        """
        self.log.info("Refreshing the client websocket to kernel connection.")
        self.refresh_zmq_sockets()
        zmq_messages = kwargs.get("zmq_messages") or {}
        if "stream_reply" in zmq_messages:
            self.log.debug("Sending stream_reply success message.")
            # Build a new message instead of overwriting the shared one.
            success_message = {
                **zmq_messages["stream_reply"],
                "content": {
                    "name": "stdout",
                    "text": "The kernel is successfully refreshed.",
                },
            }
            self._send_ws_message(success_message)
        if "exec_reply" in zmq_messages:
            self.log.debug("Sending exec_reply message.")
            self._send_ws_message(zmq_messages["exec_reply"])
        if "idle_reply" in zmq_messages:
            self.log.debug("Sending idle_reply message.")
            self._send_ws_message(zmq_messages["idle_reply"])
        self._send_status_message(
            "kernel_refreshed"
        )  # In the future, UI clients might start to consume this.

    def on_kernel_refresh_failure(self, **kwargs):
        """Relay the failure messages to the client and report the kernel as dead.

        Args:
            **kwargs: May carry ``zmq_messages``, a mapping that can contain
                ``error_reply``, ``exec_reply`` and ``idle_reply`` messages.
        """
        self.log.error("kernel %s refresh failed!", self.kernel_id)
        zmq_messages = kwargs.get("zmq_messages") or {}
        failure_text = "The kernel refresh operation failed."
        if "error_reply" in zmq_messages:
            self.log.debug("Sending stream_reply error message.")
            error_message = {
                **zmq_messages["error_reply"],
                "content": {
                    "ename": "KernelRefreshFailed",
                    "evalue": failure_text,
                    "traceback": [failure_text],
                },
            }
            self._send_ws_message(error_message)
        if "exec_reply" in zmq_messages:
            self.log.debug("Sending exec_reply message.")
            # Copy the nested dicts too: a shallow copy alone would still
            # write the error fields into the shared payload.
            exec_reply = dict(zmq_messages["exec_reply"])
            if "metadata" in exec_reply:
                exec_reply["metadata"] = {**exec_reply["metadata"], "status": "error"}
            exec_reply["content"] = {
                **(exec_reply.get("content") or {}),
                "status": "error",
                "ename": "KernelRefreshFailed",
                "evalue": failure_text,
                "traceback": [failure_text],
            }
            self._send_ws_message(exec_reply)
        if "idle_reply" in zmq_messages:
            self.log.info("Sending idle reply message.")
            self._send_ws_message(zmq_messages["idle_reply"])
        self.log.debug("sending kernel dead message.")
        self._send_status_message("dead")

    def refresh_zmq_sockets(self):
        """Close the current streams and create new ones using the kernel's session key."""
        self.close_existing_streams()
        kernel = self.kernel_manager.get_kernel(self.kernel_id)
        self.session.key = kernel.session.key  # refresh the session key
        self.log.debug("Creating new ZMQ Socket streams.")
        self.create_stream()
        for channel, stream in self.channels.items():
            self.log.debug(f"Updating channel: {channel}")
            stream.on_recv_stream(self._on_zmq_reply)

    def close_existing_streams(self):
        """Detach and close every open stream, then reset ``self.channels``."""
        self.log.debug(f"Closing existing channels for kernel: {self.kernel_id}")
        for channel, stream in self.channels.items():
            if stream is not None and not stream.closed():
                self.log.debug(f"Close channel : {channel}")
                stream.on_recv(None)
                stream.close()
        self.channels = {}

    def _send_ws_message(self, kernel_msg):
        """Serialize ``kernel_msg`` to JSON and write it to the websocket.

        When the message has a dict ``header``, a copy of the message is sent
        with ``header["date"]`` set to the current UTC time; the message that
        was passed in is left untouched.
        """
        self.log.debug(f"Sending websocket message: {kernel_msg}")
        if "header" in kernel_msg and isinstance(kernel_msg["header"], dict):
            stamped_header = {**kernel_msg["header"], "date": datetime.now(timezone.utc)}
            kernel_msg = {**kernel_msg, "header": stamped_header}
        self.write_message(json.dumps(kernel_msg, default=json_default))

    def on_close(self):
        """Close the websocket and unregister the refresh callbacks."""
        self.log.info(f"Websocket close request received for kernel: {self.kernel_id}")
        try:
            super().on_close()
        finally:
            # Always unregister, so a failing close cannot leave callbacks
            # bound to this handler behind.
            km = self.kernel_manager
            km.remove_kernel_event_callbacks(
                self.kernel_id, self.on_kernel_refresh, "kernel_refresh"
            )
            km.remove_kernel_event_callbacks(
                self.kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
            )


# Named group capturing the kernel id: five hyphen-separated runs of word characters.
_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
# The handler table starts with the configure endpoint.
default_handlers = [
    (rf"/api/kernels/configure/{_kernel_id_regex}", ConfigureMagicHandler),
]
for path, cls in jupyter_server_handlers.default_handlers:
if cls.__name__ in globals():
# Use the same named class from here if it exists
default_handlers.append((path, globals()[cls.__name__]))
elif cls.__name__ == jupyter_server_handlers.ZMQChannelsHandler.__name__:
default_handlers.append((path, RemoteZMQChannelsHandler))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this meant to replace ZMQChannelsHandler? I guess I don't understand why ZMQChannelsHandler isn't satisfied by the first condition - but I'm not that familiar with globals() (sorry).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to discuss this further.
What I am trying to do here is replace ZMQChannelsHandler with RemoteZMQChannelsHandler for handling the channels requests.

I tried to reuse the same class name in EG, but ran into an issue where the websocket connection was failing.

else:
# Gen a new type with CORS and token auth
bases = (TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, cls)
Expand Down
Loading