Skip to content

Commit

Permalink
feat: update auth cookie logic for both backend and frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
spwoodcock committed Dec 11, 2024
1 parent 4af490f commit 8e797f7
Show file tree
Hide file tree
Showing 19 changed files with 542 additions and 552 deletions.
243 changes: 243 additions & 0 deletions src/backend/app/auth/auth_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
# Copyright (c) Humanitarian OpenStreetMap Team
#
# This file is part of FMTM.
#
# FMTM is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# FMTM is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with FMTM. If not, see <https:#www.gnu.org/licenses/>.
#

"""Auth dependencies, for restricted routes and cookie handling."""

import time
from typing import Optional

import jwt
from fastapi import Header, HTTPException, Request, Response
from fastapi.responses import JSONResponse
from loguru import logger as log

from app.auth.auth_schemas import AuthUser
from app.config import settings
from app.db.enums import HTTPStatus, UserRole

### Cookie / Token Handling


def get_cookie_value(request: Request, *cookie_names: str) -> Optional[str]:
"""Get the first available value from a list of cookie names."""
for name in cookie_names:
value = request.cookies.get(name)
if value:
return value
return None


def set_cookie(
response: Response,
key: str,
value: str,
max_age: int,
secure: bool,
domain: str,
) -> None:
"""Helper function to set a cookie on a response.
For now, samesite is set lax, max_age equals expiry.
"""
response.set_cookie(
key=key,
value=value,
max_age=max_age,
expires=max_age,
path="/",
domain=domain,
secure=secure,
httponly=True,
samesite="lax",
)


def set_cookies(
access_token: str,
refresh_token: str,
cookie_name: str = settings.cookie_name,
refresh_cookie_name: str = f"{settings.cookie_name}_refresh",
) -> JSONResponse:
"""Set cookies for the access and refresh tokens.
Args:
access_token (str): The access token to be stored in the cookie.
refresh_token (str): The refresh token to be stored in the cookie.
cookie_name (str, optional): The name of the cookie to store the access token.
refresh_cookie_name (str, optional): The name of the cookie to store the
refresh token.
Returns:
JSONResponse: A response with attached cookies (set-cookie headers).
"""
# NOTE if needed we can return the token in the JSON response, but we don't for now
# response = JSONResponse(status_code=HTTPStatus.OK,
# content={"token": access_token})
response = JSONResponse(status_code=HTTPStatus.OK, content={})

secure = not settings.DEBUG
domain = settings.FMTM_DOMAIN

set_cookie(
response,
cookie_name,
access_token,
max_age=86400, # 1 day
secure=secure,
domain=domain,
)
set_cookie(
response,
refresh_cookie_name,
refresh_token,
max_age=86400 * 7, # 1 week
secure=secure,
domain=domain,
)

return response


def create_jwt_tokens(input_data: dict) -> tuple[str, str]:
"""Generate access and refresh tokens.
Args:
input_data (dict): user data for which the access token is being generated.
Returns:
tuple[str]: The generated access tokens.
"""
access_token_data = input_data.copy()
# Set refresh token expiry to 7 days
refresh_token_data = {**input_data, "exp": int(time.time()) + 86400 * 7}

encryption_key = settings.ENCRYPTION_KEY.get_secret_value()
algorithm = settings.JWT_ENCRYPTION_ALGORITHM

access_token = jwt.encode(access_token_data, encryption_key, algorithm=algorithm)
refresh_token = jwt.encode(refresh_token_data, encryption_key, algorithm=algorithm)

return access_token, refresh_token


def refresh_jwt_token(
payload: dict,
# Default expiry 1 day
expiry_seconds: int = 86400,
) -> str:
"""Generate a new JTW token with expiry."""
payload["exp"] = int(time.time()) + expiry_seconds
return jwt.encode(
payload,
settings.ENCRYPTION_KEY.get_secret_value(),
algorithm=settings.JWT_ENCRYPTION_ALGORITHM,
)


def verify_jwt_token(token: str, ignore_expiry: bool = False) -> dict:
"""Verify the access token and return its payload.
Args:
token (str): The access token to be verified.
ignore_expiry (bool): Do not throw an error if the token is expired
upon deserialisation.
Returns:
dict: The payload of the access token if verification is successful.
Raises:
HTTPException: If the token has expired or credentials could not be validated.
"""
try:
return jwt.decode(
token,
settings.ENCRYPTION_KEY.get_secret_value(),
algorithms=[settings.JWT_ENCRYPTION_ALGORITHM],
audience=settings.FMTM_DOMAIN,
options={"verify_exp": False if ignore_expiry else True},
)
except jwt.ExpiredSignatureError as e:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Refresh token has expired",
) from e
except jwt.PyJWTError as e:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Could not validate refresh token",
) from e
except Exception as e:
log.exception(f"Unknown cookie/jwt error: {e}")
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Could not validate refresh token",
) from e


### Endpoint Dependencies ###


async def login_required(
request: Request, access_token: str = Header(None)
) -> AuthUser:
"""Dependency for endpoints requiring login."""
if settings.DEBUG:
return AuthUser(sub="fmtm|1", username="localadmin", role=UserRole.ADMIN)

# Extract access token only from the OSM cookie
extracted_token = access_token or get_cookie_value(
request,
settings.cookie_name, # OSM cookie
)
return await _authenticate_user(extracted_token)


async def mapper_login_required(
request: Request, access_token: str = Header(None)
) -> AuthUser:
"""Dependency for mapper frontend login."""
if settings.DEBUG:
return AuthUser(sub="fmtm|1", username="localadmin", role=UserRole.ADMIN)

# Extract access token from OSM cookie, fallback to temp auth cookie
extracted_token = access_token or get_cookie_value(
request,
settings.cookie_name, # OSM cookie
f"{settings.cookie_name}_temp", # Temp cookie
)
return await _authenticate_user(extracted_token)


async def _authenticate_user(access_token: Optional[str]) -> AuthUser:
"""Authenticate user by verifying the access token."""
if not access_token:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="No access token provided",
)

try:
token_data = verify_jwt_token(access_token)
except ValueError as e:
log.exception(f"Failed to verify access token: {e}", stack_info=True)
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Access token not valid",
) from e

return AuthUser(**token_data)
Loading

0 comments on commit 8e797f7

Please sign in to comment.