diff --git a/braces/views/_other.py b/braces/views/_other.py index 1d82988..77db488 100644 --- a/braces/views/_other.py +++ b/braces/views/_other.py @@ -1,5 +1,6 @@ from django.core.exceptions import ImproperlyConfigured from django.shortcuts import redirect +from django.template.response import TemplateResponse from django.views.decorators.cache import cache_control, never_cache from django.urls import resolve from django.utils.encoding import force_str @@ -133,6 +134,13 @@ def dispatch(self, request, *args, **kwargs): return handler(request, *args, **kwargs) +class CustomHeadersTemplateResponse(TemplateResponse): + def __init__(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + headers.update(kwargs.pop("extra_headers", {})) + super().__init__(*args, headers=headers, **kwargs) + + class HeaderMixin: """ Add extra HTTP headers to a response by specifying them in the @@ -140,10 +148,20 @@ class HeaderMixin: """ headers = {} + response_class = CustomHeadersTemplateResponse def get_headers(self, request): return self.headers + def render_to_response(self, context, **response_kwargs): + """ + Return a response, using the `response_class` for this view, with a + template rendered with the given context. + Pass response_kwargs to the constructor of the response class. + """ + response_kwargs.setdefault("extra_headers", self.get_headers(self.request)) + return super().render_to_response(context, **response_kwargs) + def dispatch(self, request, *args, **kwargs): """ Override this method to customize the way additional headers are @@ -151,9 +169,11 @@ def dispatch(self, request, *args, **kwargs): ``.items()`` method. """ response = super().dispatch(request, *args, **kwargs) - for key, value in self.get_headers(request).items(): - if key not in response: - response[key] = value + if not getattr(self, "template_name", None): + # No template so probably no `render_to_response` call + for key, value in self.get_headers(request).items(): + if key not in response: + response[key] = value return response diff --git a/tests/test_access_mixins.py b/tests/test_access_mixins.py index e8c477d..746efcf 100644 --- a/tests/test_access_mixins.py +++ b/tests/test_access_mixins.py @@ -744,7 +744,7 @@ def test_https_does_not_redirect(self): @pytest.mark.django_db class TestRecentLoginRequiredMixin(test.TestCase): - """ Scenarios requiring a recent login""" + """ Scenarios requiring a recent login """ view_class = RecentLoginRequiredView recent_view_url = "/recent_login/" diff --git a/tests/test_other_mixins.py b/tests/test_other_mixins.py index 79c5ca3..8f47ecf 100644 --- a/tests/test_other_mixins.py +++ b/tests/test_other_mixins.py @@ -588,6 +588,10 @@ def test_existing(self): response = self.client.get('/headers/existing/') self.assertEqual(response['X-DJANGO-BRACES-EXISTING'], 'value') + def test_template(self): + response = self.client.get('/headers/template/') + self.assertEqual(response['X-DJANGO-BRACES-1'], '1') + class TestCacheControlMixin(test.TestCase): """Scenarios around controlling cache""" def test_cachecontrol_public(self): diff --git a/tests/urls.py b/tests/urls.py index bbccef0..6bef3f3 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -118,6 +118,7 @@ re_path(r'^headers/attribute/$', views.AttributeHeaderView.as_view()), re_path(r'^headers/method/$', views.MethodHeaderView.as_view()), re_path(r'^headers/existing/$', views.ExistingHeaderView.as_view()), + re_path(r'^headers/template/$', views.HeadersWithTemplate.as_view()), # CacheControlMixin tests re_path(r'^cachecontrol/public/$', views.CacheControlPublicView.as_view()), # NeverCacheMixin tests diff --git a/tests/views.py b/tests/views.py index 8a8e054..11c3810 100644 --- a/tests/views.py +++ b/tests/views.py @@ -451,6 +451,18 @@ class ExistingHeaderView(views.HeaderMixin, AuxiliaryHeaderView): 'X-DJANGO-BRACES-EXISTING': 'other value' } +class HeadersWithTemplate(views.SetHeadlineMixin, views.HeaderMixin, TemplateView): + """ + View for testing HeaderMixin with a custom TemplateResponse. + """ + + template_name = "blank.html" + headline = "Test headline" + + headers = { + "X-DJANGO-BRACES-1": 1 + } + class CacheControlPublicView(views.CacheControlMixin, OkView): """A public-cached page with a 60 second timeout"""