diff --git a/ninja/openapi/docs.py b/ninja/openapi/docs.py index c70fadd13..6b32c9bc9 100644 --- a/ninja/openapi/docs.py +++ b/ninja/openapi/docs.py @@ -1,14 +1,14 @@ import json from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Union from django.conf import settings from django.http import HttpRequest, HttpResponse from django.shortcuts import render from django.urls import reverse -from ninja.constants import NOT_SET +from ninja.router import Router from ninja.types import DictStrAny if TYPE_CHECKING: @@ -102,10 +102,25 @@ def _render_cdn_template( return HttpResponse(html) -def _csrf_needed(api: "NinjaAPI") -> bool: - if api.csrf: - return True - if not api.auth or api.auth == NOT_SET: - return False +def _iter_auth_callbacks( + api_or_router: Union["NinjaAPI", Router], +) -> Iterator[Callable[..., Any]]: + """this is helper to iterate over all operations in api or router""" + if isinstance(api_or_router, Router): + for _, path_view in api_or_router.path_operations.items(): + for operation in path_view.operations: + yield from operation.auth_callbacks + for _, router in api_or_router._routers: # noqa + yield from _iter_auth_callbacks(router) + - return any(getattr(a, "csrf", False) for a in api.auth) # type: ignore +def _csrf_needed(api: "NinjaAPI") -> bool: + add_csrf: Optional[bool] = getattr(api, "_add_csrf", None) + if add_csrf is not None: + return add_csrf + for auth_callback in _iter_auth_callbacks(api): + if getattr(auth_callback, "csrf", False): + api._add_csrf = True # type: ignore[attr-defined] + return True + api._add_csrf = False # type: ignore[attr-defined] + return False diff --git a/tests/test_csrf.py b/tests/test_csrf.py index eeab4884f..17d5198a6 100644 --- a/tests/test_csrf.py +++ b/tests/test_csrf.py @@ -3,7 +3,7 @@ from django.conf import settings from django.views.decorators.csrf import csrf_exempt -from ninja import NinjaAPI +from ninja import NinjaAPI, Router from ninja.security import APIKeyCookie, APIKeyHeader from ninja.testing import TestClient as BaseTestClient @@ -98,40 +98,97 @@ def test_view(request): assert response.status_code == 200, response.content -def test_docs(): +def test_docs_add_csrf(): "Testing that docs are initializing csrf headers correctly" - api = NinjaAPI(csrf=True) + class CookieAuth(APIKeyCookie): + def authenticate(self, request, key): + return key == "test" + + api = NinjaAPI(csrf=False, auth=CookieAuth()) # `csrf=False` should be ignored + + @api.get("/test") + def test_view(request): + return {"success": True} client = TestClient(api) + resp = client.get("/docs") assert resp.status_code == 200 csrf_token = re.findall(r'data-csrf-token="(.*?)"', resp.content.decode("utf8"))[0] assert len(csrf_token) > 0 - api.csrf = False + assert hasattr(api, "_add_csrf") # `api._add_csrf` should be set as cache + resp = client.get("/docs") assert resp.status_code == 200 csrf_token = re.findall(r'data-csrf-token="(.*?)"', resp.content.decode("utf8"))[0] - assert len(csrf_token) == 0 + assert len(csrf_token) > 0 -def test_docs_cookie_auth(): +def test_docs_add_csrf_by_operation(): + "Testing that docs are initializing csrf headers correctly" + class CookieAuth(APIKeyCookie): def authenticate(self, request, key): return key == "test" - class HeaderAuth(APIKeyHeader): + api = NinjaAPI(csrf=False) # `csrf=False` should be ignored + + @api.get("/test1", auth=CookieAuth()) + def test_view1(request): + return {"success": True} + + @api.get("/test2") + def test_view2(request): + return {"success": True} + + client = TestClient(api) + resp = client.get("/docs") + assert resp.status_code == 200 + csrf_token = re.findall(r'data-csrf-token="(.*?)"', resp.content.decode("utf8"))[0] + assert len(csrf_token) > 0 + + +def test_docs_add_csrf_by_sub_router(): + "Testing that docs are initializing csrf headers correctly" + + class CookieAuth(APIKeyCookie): def authenticate(self, request, key): return key == "test" - api = NinjaAPI(csrf=False, auth=CookieAuth()) + api = NinjaAPI(csrf=False) # `csrf=False` should be ignored + + @api.get("/test1", auth=CookieAuth()) + def test_view1(request): + return {"success": True} + + router = Router() + + @router.get("/test2") + def test_view2(request): + return {"success": True} + + api.add_router("/router", router) + client = TestClient(api) resp = client.get("/docs") + assert resp.status_code == 200 csrf_token = re.findall(r'data-csrf-token="(.*?)"', resp.content.decode("utf8"))[0] assert len(csrf_token) > 0 - api = NinjaAPI(csrf=False, auth=HeaderAuth()) + +def test_docs_do_not_add_csrf(): + class HeaderAuth(APIKeyHeader): + def authenticate(self, request, key): + return key == "test" + + api = NinjaAPI(csrf=True, auth=HeaderAuth()) # `csrf=True` should be ignored + + @api.get("/test") + def test_view(request): + return {"success": True} + client = TestClient(api) resp = client.get("/docs") csrf_token = re.findall(r'data-csrf-token="(.*?)"', resp.content.decode("utf8"))[0]