Skip to content

Commit

Permalink
Add an option to have authentication enabled for all endpoints by def…
Browse files Browse the repository at this point in the history
…ault (#1392)
  • Loading branch information
krassowski authored Mar 4, 2024
1 parent b3caa3c commit da50f2a
Show file tree
Hide file tree
Showing 17 changed files with 632 additions and 26 deletions.
55 changes: 55 additions & 0 deletions jupyter_server/auth/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,58 @@ async def inner(self, *args, **kwargs):
return cast(FuncT, wrapper(method))

return cast(FuncT, wrapper)


def allow_unauthenticated(method: FuncT) -> FuncT:
"""A decorator for tornado.web.RequestHandler methods
that allows any user to make the following request.
Selectively disables the 'authentication' layer of REST API which
is active when `ServerApp.allow_unauthenticated_access = False`.
To be used exclusively on endpoints which may be considered public,
for example the login page handler.
.. versionadded:: 2.13
Parameters
----------
method : bound callable
the endpoint method to remove authentication from.
"""

@wraps(method)
def wrapper(self, *args, **kwargs):
return method(self, *args, **kwargs)

setattr(wrapper, "__allow_unauthenticated", True)

return cast(FuncT, wrapper)


def ws_authenticated(method: FuncT) -> FuncT:
"""A decorator for websockets derived from `WebSocketHandler`
that authenticates user before allowing to proceed.
Differently from tornado.web.authenticated, does not redirect
to the login page, which would be meaningless for websockets.
.. versionadded:: 2.13
Parameters
----------
method : bound callable
the endpoint method to add authentication for.
"""

@wraps(method)
def wrapper(self, *args, **kwargs):
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise HTTPError(403)
return method(self, *args, **kwargs)

setattr(wrapper, "__allow_unauthenticated", False)

return cast(FuncT, wrapper)
4 changes: 4 additions & 0 deletions jupyter_server/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tornado.escape import url_escape

from ..base.handlers import JupyterHandler
from .decorator import allow_unauthenticated
from .security import passwd_check, set_password


Expand Down Expand Up @@ -73,6 +74,7 @@ def _redirect_safe(self, url, default=None):
url = default
self.redirect(url)

@allow_unauthenticated
def get(self):
"""Get the login form."""
if self.current_user:
Expand All @@ -81,6 +83,7 @@ def get(self):
else:
self._render()

@allow_unauthenticated
def post(self):
"""Post a login."""
user = self.current_user = self.identity_provider.process_login_form(self)
Expand Down Expand Up @@ -110,6 +113,7 @@ def passwd_check(self, a, b):
"""Check a passwd."""
return passwd_check(a, b)

@allow_unauthenticated
def post(self):
"""Post a login form."""
typed_password = self.get_argument("password", default="")
Expand Down
2 changes: 2 additions & 0 deletions jupyter_server/auth/logout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from ..base.handlers import JupyterHandler
from .decorator import allow_unauthenticated


class LogoutHandler(JupyterHandler):
"""An auth logout handler."""

@allow_unauthenticated
def get(self):
"""Handle a logout."""
self.identity_provider.clear_login_cookie(self)
Expand Down
64 changes: 57 additions & 7 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import warnings
from http.client import responses
from logging import Logger
from typing import TYPE_CHECKING, Any, Awaitable, Sequence, cast
from typing import TYPE_CHECKING, Any, Awaitable, Coroutine, Sequence, cast
from urllib.parse import urlparse

import prometheus_client
Expand All @@ -29,7 +29,7 @@
from jupyter_server import CallContext
from jupyter_server._sysinfo import get_sys_info
from jupyter_server._tz import utcnow
from jupyter_server.auth.decorator import authorized
from jupyter_server.auth.decorator import allow_unauthenticated, authorized
from jupyter_server.auth.identity import User
from jupyter_server.i18n import combine_translations
from jupyter_server.services.security import csp_report_uri
Expand Down Expand Up @@ -589,7 +589,7 @@ def check_host(self) -> bool:
)
return allow

async def prepare(self) -> Awaitable[None] | None: # type:ignore[override]
async def prepare(self, *, _redirect_to_login=True) -> Awaitable[None] | None: # type:ignore[override]
"""Prepare a response."""
# Set the current Jupyter Handler context variable.
CallContext.set(CallContext.JUPYTER_HANDLER, self)
Expand Down Expand Up @@ -630,6 +630,25 @@ async def prepare(self) -> Awaitable[None] | None: # type:ignore[override]
self.set_cors_headers()
if self.request.method not in {"GET", "HEAD", "OPTIONS"}:
self.check_xsrf_cookie()

if not self.settings.get("allow_unauthenticated_access", False):
if not self.request.method:
raise HTTPError(403)
method = getattr(self, self.request.method.lower())
if not getattr(method, "__allow_unauthenticated", False):
if _redirect_to_login:
# reuse `web.authenticated` logic, which redirects to the login
# page on GET and HEAD and otherwise raises 403
return web.authenticated(lambda _: super().prepare())(self)
else:
# raise 403 if user is not known without redirecting to login page
user = self.current_user
if user is None:
self.log.warning(
f"Couldn't authenticate {self.__class__.__name__} connection"
)
raise web.HTTPError(403)

return super().prepare()

# ---------------------------------------------------------------
Expand Down Expand Up @@ -726,7 +745,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None:
class APIHandler(JupyterHandler):
"""Base class for API handlers"""

async def prepare(self) -> None:
async def prepare(self) -> None: # type:ignore[override]
"""Prepare an API response."""
await super().prepare()
if not self.check_origin():
Expand Down Expand Up @@ -794,6 +813,7 @@ def finish(self, *args: Any, **kwargs: Any) -> Future[Any]:
self.set_header("Content-Type", set_content_type)
return super().finish(*args, **kwargs)

@allow_unauthenticated
def options(self, *args: Any, **kwargs: Any) -> None:
"""Get the options."""
if "Access-Control-Allow-Headers" in self.settings.get("headers", {}):
Expand Down Expand Up @@ -837,7 +857,7 @@ def options(self, *args: Any, **kwargs: Any) -> None:
class Template404(JupyterHandler):
"""Render our 404 template"""

async def prepare(self) -> None:
async def prepare(self) -> None: # type:ignore[override]
"""Prepare a 404 response."""
await super().prepare()
raise web.HTTPError(404)
Expand Down Expand Up @@ -1002,6 +1022,18 @@ def compute_etag(self) -> str | None:
"""Compute the etag."""
return None

# access is allowed as this class is used to serve static assets on login page
# TODO: create an allow-list of files used on login page and remove this decorator
@allow_unauthenticated
def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]:
return super().get(path, include_body)

# access is allowed as this class is used to serve static assets on login page
# TODO: create an allow-list of files used on login page and remove this decorator
@allow_unauthenticated
def head(self, path: str) -> Awaitable[None]:
return super().head(path)

@classmethod
def get_absolute_path(cls, roots: Sequence[str], path: str) -> str:
"""locate a file to serve on our static file search path"""
Expand Down Expand Up @@ -1036,6 +1068,7 @@ class APIVersionHandler(APIHandler):

_track_activity = False

@allow_unauthenticated
def get(self) -> None:
"""Get the server version info."""
# not authenticated, so give as few info as possible
Expand All @@ -1048,6 +1081,7 @@ class TrailingSlashHandler(web.RequestHandler):
This should be the first, highest priority handler.
"""

@allow_unauthenticated
def get(self) -> None:
"""Handle trailing slashes in a get."""
assert self.request.uri is not None
Expand All @@ -1064,6 +1098,7 @@ def get(self) -> None:
class MainHandler(JupyterHandler):
"""Simple handler for base_url."""

@allow_unauthenticated
def get(self) -> None:
"""Get the main template."""
html = self.render_template("main.html")
Expand Down Expand Up @@ -1104,18 +1139,20 @@ async def redirect_to_files(self: Any, path: str) -> None:
self.log.debug("Redirecting %s to %s", self.request.path, url)
self.redirect(url)

@allow_unauthenticated
async def get(self, path: str = "") -> None:
return await self.redirect_to_files(self, path)


class RedirectWithParams(web.RequestHandler):
"""Sam as web.RedirectHandler, but preserves URL parameters"""
"""Same as web.RedirectHandler, but preserves URL parameters"""

def initialize(self, url: str, permanent: bool = True) -> None:
"""Initialize a redirect handler."""
self._url = url
self._permanent = permanent

@allow_unauthenticated
def get(self) -> None:
"""Get a redirect."""
sep = "&" if "?" in self._url else "?"
Expand All @@ -1128,6 +1165,7 @@ class PrometheusMetricsHandler(JupyterHandler):
Return prometheus metrics for this server
"""

@allow_unauthenticated
def get(self) -> None:
"""Get prometheus metrics."""
if self.settings["authenticate_prometheus"] and not self.logged_in:
Expand All @@ -1137,6 +1175,18 @@ def get(self) -> None:
self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY))


class PublicStaticFileHandler(web.StaticFileHandler):
"""Same as web.StaticFileHandler, but decorated to acknowledge that auth is not required."""

@allow_unauthenticated
def head(self, path: str) -> Awaitable[None]:
return super().head(path)

@allow_unauthenticated
def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]:
return super().get(path, include_body)


# -----------------------------------------------------------------------------
# URL pattern fragments for reuse
# -----------------------------------------------------------------------------
Expand All @@ -1152,6 +1202,6 @@ def get(self) -> None:
default_handlers = [
(r".*/", TrailingSlashHandler),
(r"api", APIVersionHandler),
(r"/(robots\.txt|favicon\.ico)", web.StaticFileHandler),
(r"/(robots\.txt|favicon\.ico)", PublicStaticFileHandler),
(r"/metrics", PrometheusMetricsHandler),
]
40 changes: 39 additions & 1 deletion jupyter_server/base/websocket.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Base websocket classes."""
import re
import warnings
from typing import Optional, no_type_check
from urllib.parse import urlparse

from tornado import ioloop
from tornado import ioloop, web
from tornado.iostream import IOStream

from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.utils import JupyterServerAuthWarning

# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

Expand Down Expand Up @@ -82,6 +86,40 @@ def check_origin(self, origin: Optional[str] = None) -> bool:
def clear_cookie(self, *args, **kwargs):
"""meaningless for websockets"""

@no_type_check
def _maybe_auth(self):
"""Verify authentication if required.
Only used when the websocket class does not inherit from JupyterHandler.
"""
if not self.settings.get("allow_unauthenticated_access", False):
if not self.request.method:
raise web.HTTPError(403)
method = getattr(self, self.request.method.lower())
if not getattr(method, "__allow_unauthenticated", False):
# rather than re-using `web.authenticated` which also redirects
# to login page on GET, just raise 403 if user is not known
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)

@no_type_check
def prepare(self, *args, **kwargs):
"""Handle a get request."""
if not isinstance(self, JupyterHandler):
should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
if "identity_provider" in self.settings and should_authenticate:
warnings.warn(
"WebSocketMixin sub-class does not inherit from JupyterHandler"
" preventing proper authentication using custom identity provider.",
JupyterServerAuthWarning,
stacklevel=2,
)
self._maybe_auth()
return super().prepare(*args, **kwargs)
return super().prepare(*args, **kwargs, _redirect_to_login=False)

@no_type_check
def open(self, *args, **kwargs):
"""Open the websocket."""
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/extension/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _prepare_handlers(self):
)
new_handlers.append(handler)

webapp.add_handlers(".*$", new_handlers) # type:ignore[arg-type]
webapp.add_handlers(".*$", new_handlers)

def _prepare_templates(self):
"""Add templates to web app settings if extension has templates."""
Expand Down
Loading

0 comments on commit da50f2a

Please sign in to comment.