Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proposal: add subprotocol for token-authenticated websockets #1407

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import uuid
from dataclasses import asdict, dataclass
from http.cookies import Morsel
from urllib.parse import unquote

from tornado import escape, httputil, web
from tornado.websocket import WebSocketHandler
from traitlets import Bool, Dict, Type, Unicode, default
from traitlets.config import LoggingConfigurable

Expand Down Expand Up @@ -106,6 +108,9 @@ def _backward_compat_user(got_user: t.Any) -> User:
raise ValueError(msg)


_TOKEN_SUBPROTOCOL = "v1.token.websocket.jupyter.org"


class IdentityProvider(LoggingConfigurable):
"""
Interface for providing identity management and authentication.
Expand Down Expand Up @@ -424,6 +429,21 @@ def get_token(self, handler: web.RequestHandler) -> str | None:
m = self.auth_header_pat.match(handler.request.headers.get("Authorization", ""))
if m:
user_token = m.group(2)
if not user_token and isinstance(handler, WebSocketHandler):
subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
subprotocols = [s.strip() for s in subprotocol_header.split(",")]
for subprotocol in subprotocols:
if subprotocol.startswith(_TOKEN_SUBPROTOCOL + "."):
user_token = subprotocol[len(_TOKEN_SUBPROTOCOL) + 1 :]
try:
user_token = unquote(user_token)
except ValueError:
# leave tokens that fail to decode
# these won't be accepted, but proceed with validation
pass
break

return user_token

async def get_user_token(self, handler: web.RequestHandler) -> User | None:
Expand Down
14 changes: 14 additions & 0 deletions jupyter_server/base/websocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Base websocket classes."""

from __future__ import annotations

import re
import warnings
from typing import Optional, no_type_check
Expand Down Expand Up @@ -164,3 +167,14 @@ def send_ping(self):
def on_pong(self, data):
"""Handle a pong message."""
self.last_pong = ioloop.IOLoop.current().time()

def select_subprotocol(self, subprotocols: list[str]) -> str | None:
# default subprotocol
# some clients (Chrome)
# require selected subprotocol to match one of the requested subprotocols
# otherwise connection is rejected
token_subprotocol = "v1.token.websocket.jupyter.org"
if token_subprotocol in subprotocols:
return token_subprotocol
else:
return None
2 changes: 2 additions & 0 deletions jupyter_server/services/events/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from jupyter_server.auth.decorator import authorized, ws_authenticated
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.base.websocket import WebSocketMixin

from ...base.handlers import APIHandler

AUTH_RESOURCE = "events"


class SubscribeWebsocket(
WebSocketMixin,
JupyterHandler,
websocket.WebSocketHandler,
):
Expand Down
8 changes: 7 additions & 1 deletion jupyter_server/services/kernels/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def select_subprotocol(self, subprotocols):
preferred_protocol = "v1.kernel.websocket.jupyter.org"
elif preferred_protocol == "":
preferred_protocol = None
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None

# super() subprotocol enables token authentication via subprotocol
selected_subprotocol = (
preferred_protocol
if preferred_protocol in subprotocols
else super().select_subprotocol(subprotocols)
)
# None is the default, "legacy" protocol
return selected_subprotocol
33 changes: 32 additions & 1 deletion tests/base/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tornado.websocket import WebSocketClosedError, WebSocketHandler

from jupyter_server.auth import IdentityProvider, User
from jupyter_server.auth.decorator import allow_unauthenticated
from jupyter_server.auth.decorator import allow_unauthenticated, ws_authenticated
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.base.websocket import WebSocketMixin
from jupyter_server.serverapp import ServerApp
Expand Down Expand Up @@ -75,6 +75,12 @@ class NoAuthRulesWebsocketHandler(MockJupyterHandler):
pass


class AuthenticatedWebsocketHandler(MockJupyterHandler):
@ws_authenticated
def get(self, *args, **kwargs) -> None:
return super().get(*args, **kwargs)


class PermissiveWebsocketHandler(MockJupyterHandler):
@allow_unauthenticated
def get(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -126,6 +132,31 @@ async def test_websocket_auth_required(jp_serverapp, jp_ws_fetch):
assert exception.value.code == 403


async def test_websocket_token_subprotocol_auth(jp_serverapp, jp_ws_fetch):
app: ServerApp = jp_serverapp
app.web_app.add_handlers(
".*$",
[
(url_path_join(app.base_url, "ws"), AuthenticatedWebsocketHandler),
],
)

with pytest.raises(HTTPClientError) as exception:
ws = await jp_ws_fetch("ws", headers={"Authorization": ""})
assert exception.value.code == 403
token = jp_serverapp.identity_provider.token
ws = await jp_ws_fetch(
"ws",
headers={
"Authorization": "",
"Sec-WebSocket-Protocol": "v1.kernel.websocket.jupyter.org, v1.token.websocket.jupyter.org, v1.token.websocket.jupyter.org."
+ token,
},
)
assert ws.protocol.selected_subprotocol == "v1.token.websocket.jupyter.org"
ws.close()


class IndiscriminateIdentityProvider(IdentityProvider):
async def get_user(self, handler):
return User(username="test")
Expand Down