From 332c27a69f391ce18743e4465a72540cb2a8ef2b Mon Sep 17 00:00:00 2001 From: Kenneth Love Date: Wed, 17 Nov 2021 12:32:11 -0800 Subject: [PATCH 1/2] Docstring updates to pass interrogate checks. Progress checkin --- braces/views/_other.py | 26 ++++++++-------- conftest.py | 1 + setup.py | 1 + tests/factories.py | 18 ++++++----- tests/forms.py | 2 ++ tests/models.py | 11 +++++-- tests/test_access_mixins.py | 59 +++++++++++++++++++++++++++++-------- tests/urls_namespaced.py | 4 +-- tests/views.py | 52 +++++++++++++++++++++++++------- 9 files changed, 129 insertions(+), 45 deletions(-) diff --git a/braces/views/_other.py b/braces/views/_other.py index 4bc35f6..caa99d1 100644 --- a/braces/views/_other.py +++ b/braces/views/_other.py @@ -68,13 +68,15 @@ class CanonicalSlugDetailMixin: """ def dispatch(self, request, *args, **kwargs): - # Set up since we need to super() later instead of earlier. + """ + Redirect to the appropriate URL if necessary. + Otherwise, trigger HTTP-method-appropriate handler. + """ self.request = request self.args = args self.kwargs = kwargs - # Get the current object, url slug, and - # urlpattern name (namespace aware). + # Get the current object, url slug, and url name. obj = self.get_object() slug = self.kwargs.get(self.slug_url_kwarg, None) match = resolve(request.path_info) @@ -82,14 +84,13 @@ def dispatch(self, request, *args, **kwargs): url_parts.append(match.url_name) current_urlpattern = ":".join(url_parts) - # Figure out what the slug is supposed to be. + # Find the canonical slug for the object if hasattr(obj, "get_canonical_slug"): canonical_slug = obj.get_canonical_slug() else: canonical_slug = self.get_canonical_slug() - # If there's a discrepancy between the slug in the url and the - # canonical slug, redirect to the canonical slug. + # Redirect if current slug is not the canonical one if canonical_slug != slug: params = { self.pk_url_kwarg: obj.pk, @@ -102,17 +103,14 @@ def dispatch(self, request, *args, **kwargs): def get_canonical_slug(self): """ - Override this method to customize what slug should be considered - canonical. - - Alternatively, define the get_canonical_slug method on this view's - object class. In that case, this method will never be called. + Provide a method to return the correct slug for this object. """ return self.get_object().slug class AllVerbsMixin: - """Call a single method for all HTTP verbs. + """ + Call a single method for all HTTP verbs. The name of the method should be specified using the class attribute `all_handler`. The default value of this attribute is 'all'. @@ -192,6 +190,7 @@ def get_cachecontrol_options(cls): @classmethod def as_view(cls, *args, **kwargs): + """Wrap the view with appropriate cache controls""" view_func = super().as_view(*args, **kwargs) options = cls.get_cachecontrol_options() return cache_control(**options)(view_func) @@ -204,5 +203,8 @@ class NeverCacheMixin: """ @classmethod def as_view(cls, *args, **kwargs): + """ + Wrap the view with the `never_cache` decorator. + """ view_func = super().as_view(*args, **kwargs) return never_cache(view_func) diff --git a/conftest.py b/conftest.py index a0f582b..a3ce09d 100644 --- a/conftest.py +++ b/conftest.py @@ -4,5 +4,6 @@ def pytest_configure(): + """Setup Django settings""" os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings") settings.configure(default_settings=test_settings) diff --git a/setup.py b/setup.py index 93f4e3a..7b2901f 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ def _add_default(m): + """Add on a default""" attr_name, attr_value = m.groups() return ((attr_name, attr_value.strip("\"'")),) diff --git a/tests/factories.py b/tests/factories.py index cf6de5a..dd4a4ea 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -18,8 +18,9 @@ def _get_perm(perm_name): class ArticleFactory(factory.django.DjangoModelFactory): - title = factory.Sequence(lambda n: "Article number {0}".format(n)) - body = factory.Sequence(lambda n: "Body of article {0}".format(n)) + """Generates Articles""" + title = factory.Sequence(lambda n: f"Article number {n}") + body = factory.Sequence(lambda n: "Body of article {n}") class Meta: model = Article @@ -27,7 +28,8 @@ class Meta: class GroupFactory(factory.django.DjangoModelFactory): - name = factory.Sequence(lambda n: "group{0}".format(n)) + """Artificial divides as a service""" + name = factory.Sequence(lambda n: f"group{n}") class Meta: model = Group @@ -35,10 +37,11 @@ class Meta: class UserFactory(factory.django.DjangoModelFactory): - username = factory.Sequence(lambda n: "user{0}".format(n)) - first_name = factory.Sequence(lambda n: "John {0}".format(n)) - last_name = factory.Sequence(lambda n: "Doe {0}".format(n)) - email = factory.Sequence(lambda n: "user{0}@example.com".format(n)) + """The people who make it all possible""" + username = factory.Sequence(lambda n: f"user{n}") + first_name = factory.Sequence(lambda n: f"John {n}") + last_name = factory.Sequence(lambda n: f"Doe {n}") + email = factory.Sequence(lambda n: f"user{n}@example.com") password = factory.PostGenerationMethodCall("set_password", "asdf1234") class Meta: @@ -47,6 +50,7 @@ class Meta: @factory.post_generation def permissions(self, create, extracted, **kwargs): + """Give the user some permissions""" if create and extracted: # We have a saved object and a list of permission names self.user_permissions.add(*[_get_perm(pn) for pn in extracted]) diff --git a/tests/forms.py b/tests/forms.py index d57bd0e..c1928ba 100644 --- a/tests/forms.py +++ b/tests/forms.py @@ -6,10 +6,12 @@ class FormWithUserKwarg(UserKwargModelFormMixin, forms.Form): + """This form will get a `user` kwarg""" field1 = forms.CharField() class ArticleForm(forms.ModelForm): + """This form represents an Article""" class Meta: model = Article fields = ["author", "title", "body", "slug"] diff --git a/tests/models.py b/tests/models.py index c712ec4..4b51165 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2,6 +2,9 @@ class Article(models.Model): + """ + A small but useful model for testing most features + """ author = models.ForeignKey( "auth.User", null=True, blank=True, on_delete=models.CASCADE ) @@ -11,6 +14,9 @@ class Article(models.Model): class CanonicalArticle(models.Model): + """ + Model specifically for testing the canonical slug mixins + """ author = models.ForeignKey( "auth.User", null=True, blank=True, on_delete=models.CASCADE ) @@ -19,6 +25,7 @@ class CanonicalArticle(models.Model): slug = models.SlugField(blank=True) def get_canonical_slug(self): + """Required by mixin to use the model as the source of truth""" if self.author: - return "{0.author.username}-{0.slug}".format(self) - return "unauthored-{0.slug}".format(self) + return f"{self.author.username}-{self.slug}" + return f"unauthored-{self.slug}" diff --git a/tests/test_access_mixins.py b/tests/test_access_mixins.py index fa55352..89f526a 100644 --- a/tests/test_access_mixins.py +++ b/tests/test_access_mixins.py @@ -235,24 +235,25 @@ def test_redirect_unauthenticated_false(self): @pytest.mark.django_db class TestLoginRequiredMixin(TestViewHelper, test.TestCase): - """ - Tests for LoginRequiredMixin. - """ + """Scenarios around requiring an authenticated session""" view_class = LoginRequiredView view_url = "/login_required/" def test_anonymous(self): + """Anonymous users should be redirected""" resp = self.client.get(self.view_url) self.assertRedirects(resp, "/accounts/login/?next=/login_required/") def test_anonymous_raises_exception(self): + """Anonymous users should raise an exception""" with self.assertRaises(PermissionDenied): self.dispatch_view( self.build_request(path=self.view_url), raise_exception=True ) def test_authenticated(self): + """Authenticated users should get 'OK'""" user = UserFactory() self.client.login(username=user.username, password="asdf1234") resp = self.client.get(self.view_url) @@ -260,6 +261,7 @@ def test_authenticated(self): assert force_str(resp.content) == "OK" def test_anonymous_redirects(self): + """Anonymous users are redirected with a 302""" resp = self.dispatch_view( self.build_request(path=self.view_url), raise_exception=True, @@ -361,7 +363,7 @@ def test_anonymous(self): def test_authenticated(self): """ Check that the authenticated user has been successfully directed - to the approparite view. + to the appropriate view. """ user = UserFactory() self.client.login(username=user.username, password="asdf1234") @@ -372,6 +374,7 @@ def test_authenticated(self): self.assertRedirects(resp, "/authenticated_view/") def test_no_url(self): + """View should raise an exception if no URL is provided""" self.view_class.authenticated_redirect_url = None user = UserFactory() self.client.login(username=user.username, password="asdf1234") @@ -379,6 +382,7 @@ def test_no_url(self): self.client.get(self.view_url) def test_bad_url(self): + """Redirection can be misconfigured""" self.view_class.authenticated_redirect_url = "/epicfailurl/" user = UserFactory() self.client.login(username=user.username, password="asdf1234") @@ -388,17 +392,17 @@ def test_bad_url(self): @pytest.mark.django_db class TestPermissionRequiredMixin(_TestAccessBasicsMixin, test.TestCase): - """ - Tests for PermissionRequiredMixin. - """ + """Scenarios around requiring a permission""" view_class = PermissionRequiredView view_url = "/permission_required/" def build_authorized_user(self): + """Create a user with permissions""" return UserFactory(permissions=["auth.add_user"]) def build_unauthorized_user(self): + """Create a user without permissions""" return UserFactory() def test_invalid_permission(self): @@ -414,10 +418,12 @@ def test_invalid_permission(self): class TestMultiplePermissionsRequiredMixin( _TestAccessBasicsMixin, test.TestCase ): + """Scenarios around requiring multiple permissions""" view_class = MultiplePermissionsRequiredView view_url = "/multiple_permissions_required/" def build_authorized_user(self): + """Get a user with permissions""" return UserFactory( permissions=[ "tests.add_article", @@ -427,6 +433,7 @@ def build_authorized_user(self): ) def build_unauthorized_user(self): + """Get a user without the important permissions""" return UserFactory(permissions=["tests.add_article"]) def test_redirects_to_login(self): @@ -530,49 +537,61 @@ def test_any_permissions_key(self): @pytest.mark.django_db class TestSuperuserRequiredMixin(_TestAccessBasicsMixin, test.TestCase): + """Scenarios requiring a superuser""" view_class = SuperuserRequiredView view_url = "/superuser_required/" def build_authorized_user(self): + """Make a superuser""" return UserFactory(is_superuser=True, is_staff=True) def build_unauthorized_user(self): + """Make a non-superusers""" return UserFactory() @pytest.mark.django_db class TestStaffuserRequiredMixin(_TestAccessBasicsMixin, test.TestCase): + """Scenarios requiring a staff user""" view_class = StaffuserRequiredView view_url = "/staffuser_required/" def build_authorized_user(self): + """Hire a user""" return UserFactory(is_staff=True) def build_unauthorized_user(self): + """Get a customer""" return UserFactory() @pytest.mark.django_db class TestGroupRequiredMixin(_TestAccessBasicsMixin, test.TestCase): + """Scenarios requiring membership in a certain group""" + view_class = GroupRequiredView view_url = "/group_required/" def build_authorized_user(self): + """Get a user with the right group""" user = UserFactory() group = GroupFactory(name="test_group") user.groups.add(group) return user def build_superuser(self): + """Get a superuser""" user = UserFactory() user.is_superuser = True user.save() return user def build_unauthorized_user(self): + """Just a normal users, not super and no groups""" return UserFactory() def test_with_string(self): + """A group name as a string should restrict access""" self.assertEqual("test_group", self.view_class.group_required) user = self.build_authorized_user() self.client.login(username=user.username, password="asdf1234") @@ -581,6 +600,7 @@ def test_with_string(self): self.assertEqual("OK", force_str(resp.content)) def test_with_group_list(self): + """A list of group names should restrict access""" group_list = ["test_group", "editors"] # the test client will instantiate a new view on request, so we have to # modify the class variable (and restore it when the test finished) @@ -595,6 +615,7 @@ def test_with_group_list(self): self.assertEqual("test_group", self.view_class.group_required) def test_superuser_allowed(self): + """Superusers should always be allowed, regardless of group rules""" user = self.build_superuser() self.client.login(username=user.username, password="asdf1234") resp = self.client.get(self.view_url) @@ -602,6 +623,7 @@ def test_superuser_allowed(self): self.assertEqual("OK", force_str(resp.content)) def test_improperly_configured(self): + """No group(s) specified should raise ImproperlyConfigured""" view = self.view_class() view.group_required = None with self.assertRaises(ImproperlyConfigured): @@ -612,6 +634,7 @@ def test_improperly_configured(self): view.get_group_required() def test_with_unicode(self): + """Unicode in group names should restrict access""" self.view_class.group_required = "niño" self.assertEqual("niño", self.view_class.group_required) @@ -631,6 +654,7 @@ def test_with_unicode(self): @pytest.mark.django_db class TestUserPassesTestMixin(_TestAccessBasicsMixin, test.TestCase): + """Scenarios requiring a user to pass a test""" view_class = UserPassesTestView view_url = "/user_passes_test/" view_not_implemented_class = UserPassesTestNotImplementedView @@ -638,14 +662,17 @@ class TestUserPassesTestMixin(_TestAccessBasicsMixin, test.TestCase): # for testing with passing and not passsing func_test def build_authorized_user(self, is_superuser=False): + """Get a test-passing user""" return UserFactory( is_superuser=is_superuser, is_staff=True, email="user@mydomain.com" ) def build_unauthorized_user(self): + """Get a blank user""" return UserFactory() def test_with_user_pass(self): + """Valid username and password should pass the test""" user = self.build_authorized_user() self.client.login(username=user.username, password="asdf1234") resp = self.client.get(self.view_url) @@ -654,6 +681,7 @@ def test_with_user_pass(self): self.assertEqual("OK", force_str(resp.content)) def test_with_user_not_pass(self): + """A failing user should be redirected""" user = self.build_authorized_user(is_superuser=True) self.client.login(username=user.username, password="asdf1234") resp = self.client.get(self.view_url) @@ -661,12 +689,14 @@ def test_with_user_not_pass(self): self.assertRedirects(resp, "/accounts/login/?next=/user_passes_test/") def test_with_user_raise_exception(self): + """PermissionDenied should be raised""" with self.assertRaises(PermissionDenied): self.dispatch_view( self.build_request(path=self.view_url), raise_exception=True ) def test_not_implemented(self): + """NotImplemented should be raised""" view = self.view_not_implemented_class() with self.assertRaises(NotImplementedError): view.dispatch( @@ -677,11 +707,13 @@ def test_not_implemented(self): @pytest.mark.django_db class TestSSLRequiredMixin(test.TestCase): + """Scenarios around requiring SSL""" view_class = SSLRequiredView view_url = "/sslrequired/" def test_ssl_redirection(self): - self.view_url = "https://testserver" + self.view_url + """Should redirect if not SSL""" + self.view_url = f"https://testserver{self.view_url}" self.view_class.raise_exception = False resp = self.client.get(self.view_url) self.assertRedirects(resp, self.view_url, status_code=301) @@ -690,17 +722,20 @@ def test_ssl_redirection(self): self.assertEqual("https", resp.request.get("wsgi.url_scheme")) def test_raises_exception(self): + """Should return 404""" self.view_class.raise_exception = True resp = self.client.get(self.view_url) self.assertEqual(404, resp.status_code) @override_settings(DEBUG=True) def test_debug_bypasses_redirect(self): + """Debug mode should not require SSL""" self.view_class.raise_exception = False resp = self.client.get(self.view_url) self.assertEqual(200, resp.status_code) def test_https_does_not_redirect(self): + """SSL requests should not redirect""" self.view_class.raise_exception = False resp = self.client.get(self.view_url, secure=True) self.assertEqual(200, resp.status_code) @@ -709,15 +744,14 @@ def test_https_does_not_redirect(self): @pytest.mark.django_db class TestRecentLoginRequiredMixin(test.TestCase): - """ - Tests for RecentLoginRequiredMixin. - """ + """ Scenarios requiring a recent login """ view_class = RecentLoginRequiredView recent_view_url = "/recent_login/" outdated_view_url = "/outdated_login/" def test_recent_login(self): + """A recent login should get a 200""" self.view_class.max_last_login_delta = 1800 last_login = datetime.datetime.now() last_login = make_aware(last_login, get_current_timezone()) @@ -728,6 +762,7 @@ def test_recent_login(self): assert force_str(resp.content) == "OK" def test_outdated_login(self): + """An outdated login should get a 302""" self.view_class.max_last_login_delta = 0 last_login = datetime.datetime.now() - datetime.timedelta(hours=2) last_login = make_aware(last_login, get_current_timezone()) @@ -737,8 +772,8 @@ def test_outdated_login(self): assert resp.status_code == 302 def test_not_logged_in(self): + """Anonymous requests should be handled appropriately""" last_login = datetime.datetime.now() last_login = make_aware(last_login, get_current_timezone()) - user = UserFactory(last_login=last_login) resp = self.client.get(self.recent_view_url) assert resp.status_code != 200 diff --git a/tests/urls_namespaced.py b/tests/urls_namespaced.py index 34b9d7a..d4f39ea 100644 --- a/tests/urls_namespaced.py +++ b/tests/urls_namespaced.py @@ -1,9 +1,9 @@ -from django.urls import include, re_path +from django.urls import re_path from . import views + urlpatterns = [ - # CanonicalSlugDetailMixin namespace tests re_path( r"^article/(?P\d+)-(?P[\w-]+)/$", views.CanonicalSlugDetailView.as_view(), diff --git a/tests/views.py b/tests/views.py index 46b1160..8a8e054 100644 --- a/tests/views.py +++ b/tests/views.py @@ -26,15 +26,19 @@ class OkView(View): """ def get(self, request): + """Everything is going to be OK""" return HttpResponse("OK") def post(self, request): + """Get it?""" return self.get(request) def put(self, request): + """Get it?""" return self.get(request) def delete(self, request): + """Get it?""" return self.get(request) @@ -67,15 +71,19 @@ class AjaxResponseView(views.AjaxResponseMixin, OkView): """ def get_ajax(self, request): + """Everything will eventually be OK""" return HttpResponse("AJAX_OK") def post_ajax(self, request): + """Get it?""" return self.get_ajax(request) def put_ajax(self, request): + """Get it?""" return self.get_ajax(request) def delete_ajax(self, request): + """Get it?""" return self.get_ajax(request) @@ -85,6 +93,7 @@ class SimpleJsonView(views.JSONResponseMixin, View): """ def get(self, request): + """Send back some JSON""" object = {"username": request.user.username} return self.render_json_response(object) @@ -98,6 +107,7 @@ class CustomJsonEncoderView(views.JSONResponseMixin, View): json_encoder_class = SetJSONEncoder def get(self, request): + """Send back some JSON""" object = {"numbers": set([1, 2, 3])} return self.render_json_response(object) @@ -109,6 +119,7 @@ class SimpleJsonBadRequestView(views.JSONResponseMixin, View): """ def get(self, request): + """Send back some JSON""" object = {"username": request.user.username} return self.render_json_response(object, status=400) @@ -120,6 +131,7 @@ class ArticleListJsonView(views.JSONResponseMixin, View): """ def get(self, request): + """Send back some JSON""" queryset = Article.objects.all() return self.render_json_object_response(queryset, fields=("title",)) @@ -130,6 +142,7 @@ class JsonRequestResponseView(views.JsonRequestResponseMixin, View): """ def post(self, request): + """Send back some JSON""" return self.render_json_response(self.request_json) @@ -142,6 +155,7 @@ class JsonBadRequestView(views.JsonRequestResponseMixin, View): require_json = True def post(self, request, *args, **kwargs): + """Send back some JSON""" return self.render_json_response(self.request_json) @@ -152,6 +166,7 @@ class JsonCustomBadRequestView(views.JsonRequestResponseMixin, View): """ def post(self, request, *args, **kwargs): + """Handle the POST request""" if not self.request_json: return self.render_bad_request_response({"error": "you messed up"}) return self.render_json_response(self.request_json) @@ -229,7 +244,8 @@ class FormWithUserKwargView(views.UserFormKwargsMixin, FormView): template_name = "form.html" def form_valid(self, form): - return HttpResponse("username: %s" % form.user.username) + """A simple response to watch for""" + return HttpResponse(f"username: {form.user.username}") class HeadlineView(views.SetHeadlineMixin, TemplateView): @@ -265,6 +281,7 @@ class DynamicHeadlineView(views.SetHeadlineMixin, TemplateView): template_name = "blank.html" def get_headline(self): + """Return the headline passed in via kwargs""" return self.kwargs["s"] @@ -286,24 +303,26 @@ class MultiplePermissionsRequiredView( class SuperuserRequiredView(views.SuperuserRequiredMixin, OkView): - pass + """Require a superuser""" class StaffuserRequiredView(views.StaffuserRequiredMixin, OkView): - pass + """Require a user marked as `is_staff`""" class CsrfExemptView(views.CsrfExemptMixin, OkView): - pass + """Ignore CSRF""" class AuthorDetailView(views.PrefetchRelatedMixin, ListView): + """A basic detail view to test prefetching""" model = User prefetch_related = ["article_set"] template_name = "blank.html" class OrderableListView(views.OrderableListMixin, ListView): + """A basic list view to test ordering the output""" model = Article orderable_columns = ( "id", @@ -313,23 +332,23 @@ class OrderableListView(views.OrderableListMixin, ListView): class CanonicalSlugDetailView(views.CanonicalSlugDetailMixin, DetailView): + """A basic detail view to test a canonical slug""" model = Article template_name = "blank.html" -class OverriddenCanonicalSlugDetailView( - views.CanonicalSlugDetailMixin, DetailView -): +class OverriddenCanonicalSlugDetailView(views.CanonicalSlugDetailMixin, DetailView): + """A basic detail view to test an overridden slug""" model = Article template_name = "blank.html" def get_canonical_slug(self): + """Give back a different, encoded slug. My slug secrets are safe""" return codecs.encode(self.get_object().slug, "rot_13") -class CanonicalSlugDetailCustomUrlKwargsView( - views.CanonicalSlugDetailMixin, DetailView -): +class CanonicalSlugDetailCustomUrlKwargsView(views.CanonicalSlugDetailMixin, DetailView): + """A basic detail view to test a slug with custom URL stuff""" model = Article template_name = "blank.html" pk_url_kwarg = "my_pk" @@ -337,11 +356,13 @@ class CanonicalSlugDetailCustomUrlKwargsView( class ModelCanonicalSlugDetailView(views.CanonicalSlugDetailMixin, DetailView): + """A basic detail view to test a model with a canonical slug""" model = CanonicalArticle template_name = "blank.html" class FormMessagesView(views.FormMessagesMixin, CreateView): + """A basic form view to test valid/invalid messages""" form_class = ArticleForm form_invalid_message = _("Invalid") form_valid_message = _("Valid") @@ -351,10 +372,12 @@ class FormMessagesView(views.FormMessagesMixin, CreateView): class GroupRequiredView(views.GroupRequiredMixin, OkView): + """Is everything OK in this group?""" group_required = "test_group" class UserPassesTestView(views.UserPassesTestMixin, OkView): + """Did I pass a test?""" def test_func(self, user): return ( user.is_staff @@ -366,6 +389,7 @@ def test_func(self, user): class UserPassesTestLoginRequiredView( views.LoginRequiredMixin, views.UserPassesTestMixin, OkView ): + """Am I logged in _and_ passing a test?""" def test_func(self, user): return ( user.is_staff @@ -375,15 +399,18 @@ def test_func(self, user): class UserPassesTestNotImplementedView(views.UserPassesTestMixin, OkView): + """The test went missing?""" pass class AllVerbsView(views.AllVerbsMixin, View): + """I know, like, all the verbs""" def all(self, request, *args, **kwargs): return HttpResponse("All verbs return this!") class SSLRequiredView(views.SSLRequiredMixin, OkView): + """Speak friend and enter""" pass @@ -394,6 +421,7 @@ class RecentLoginRequiredView(views.RecentLoginRequiredMixin, OkView): class AttributeHeaderView(views.HeaderMixin, OkView): + """Set headers in an attribute w/o a template render class""" headers = { "X-DJANGO-BRACES-1": 1, "X-DJANGO-BRACES-2": 2, @@ -401,6 +429,7 @@ class AttributeHeaderView(views.HeaderMixin, OkView): class MethodHeaderView(views.HeaderMixin, OkView): + """Set headers in a method w/o a template render class""" def get_headers(self, request): return { "X-DJANGO-BRACES-1": 1, @@ -409,6 +438,7 @@ def get_headers(self, request): class AuxiliaryHeaderView(View): + """A view with a header already set""" def dispatch(self, request, *args, **kwargs): response = HttpResponse("OK with headers") response["X-DJANGO-BRACES-EXISTING"] = "value" @@ -416,12 +446,14 @@ def dispatch(self, request, *args, **kwargs): class ExistingHeaderView(views.HeaderMixin, AuxiliaryHeaderView): + """A view trying to override a parent's header""" headers = { 'X-DJANGO-BRACES-EXISTING': 'other value' } class CacheControlPublicView(views.CacheControlMixin, OkView): + """A public-cached page with a 60 second timeout""" cachecontrol_public = True cachecontrol_max_age = 60 From 82c65d66ac3e64ae898e559a0e600780a098c262 Mon Sep 17 00:00:00 2001 From: Kenneth Love Date: Wed, 17 Nov 2021 12:39:09 -0800 Subject: [PATCH 2/2] An alternate way to handle custom headers This uses a new TemplateResponse which will let us more cleanly inject headers into the response. --- braces/views/_other.py | 26 +++++++++++++++++++++++--- tests/test_other_mixins.py | 4 ++++ tests/urls.py | 1 + tests/views.py | 12 ++++++++++++ 4 files changed, 40 insertions(+), 3 deletions(-) diff --git a/braces/views/_other.py b/braces/views/_other.py index caa99d1..abd9ad4 100644 --- a/braces/views/_other.py +++ b/braces/views/_other.py @@ -1,5 +1,6 @@ from django.core.exceptions import ImproperlyConfigured from django.shortcuts import redirect +from django.template.response import TemplateResponse from django.views.decorators.cache import cache_control, never_cache from django.urls import resolve from django.utils.encoding import force_str @@ -128,6 +129,13 @@ def dispatch(self, request, *args, **kwargs): return handler(request, *args, **kwargs) +class CustomHeadersTemplateResponse(TemplateResponse): + def __init__(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + headers.update(kwargs.pop("extra_headers", {})) + super().__init__(*args, headers=headers, **kwargs) + + class HeaderMixin: """ Add extra HTTP headers to a response by specifying them in the @@ -135,10 +143,20 @@ class HeaderMixin: """ headers = {} + response_class = CustomHeadersTemplateResponse def get_headers(self, request): return self.headers + def render_to_response(self, context, **response_kwargs): + """ + Return a response, using the `response_class` for this view, with a + template rendered with the given context. + Pass response_kwargs to the constructor of the response class. + """ + response_kwargs.setdefault("extra_headers", self.get_headers(self.request)) + return super().render_to_response(context, **response_kwargs) + def dispatch(self, request, *args, **kwargs): """ Override this method to customize the way additional headers are @@ -146,9 +164,11 @@ def dispatch(self, request, *args, **kwargs): ``.items()`` method. """ response = super().dispatch(request, *args, **kwargs) - for key, value in self.get_headers(request).items(): - if key not in response: - response[key] = value + if not getattr(self, "template_name", None): + # No template so probably no `render_to_response` call + for key, value in self.get_headers(request).items(): + if key not in response: + response[key] = value return response diff --git a/tests/test_other_mixins.py b/tests/test_other_mixins.py index df5ab81..67c380f 100644 --- a/tests/test_other_mixins.py +++ b/tests/test_other_mixins.py @@ -824,6 +824,10 @@ def test_existing(self): response = self.client.get('/headers/existing/') self.assertEqual(response['X-DJANGO-BRACES-EXISTING'], 'value') + def test_template(self): + response = self.client.get('/headers/template/') + self.assertEqual(response['X-DJANGO-BRACES-1'], '1') + class TestCacheControlMixin(test.TestCase): def test_cachecontrol_public(self): response = self.client.get('/cachecontrol/public/') diff --git a/tests/urls.py b/tests/urls.py index bbccef0..6bef3f3 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -118,6 +118,7 @@ re_path(r'^headers/attribute/$', views.AttributeHeaderView.as_view()), re_path(r'^headers/method/$', views.MethodHeaderView.as_view()), re_path(r'^headers/existing/$', views.ExistingHeaderView.as_view()), + re_path(r'^headers/template/$', views.HeadersWithTemplate.as_view()), # CacheControlMixin tests re_path(r'^cachecontrol/public/$', views.CacheControlPublicView.as_view()), # NeverCacheMixin tests diff --git a/tests/views.py b/tests/views.py index 8a8e054..11c3810 100644 --- a/tests/views.py +++ b/tests/views.py @@ -451,6 +451,18 @@ class ExistingHeaderView(views.HeaderMixin, AuxiliaryHeaderView): 'X-DJANGO-BRACES-EXISTING': 'other value' } +class HeadersWithTemplate(views.SetHeadlineMixin, views.HeaderMixin, TemplateView): + """ + View for testing HeaderMixin with a custom TemplateResponse. + """ + + template_name = "blank.html" + headline = "Test headline" + + headers = { + "X-DJANGO-BRACES-1": 1 + } + class CacheControlPublicView(views.CacheControlMixin, OkView): """A public-cached page with a 60 second timeout"""