This commit is contained in:
zandercymatics 2024-03-18 12:28:21 -06:00
parent d046ee8315
commit b4829d650a
No known key found for this signature in database
GPG key ID: FF4636ABEC9682B7
6 changed files with 35 additions and 64 deletions

View file

@ -6,12 +6,13 @@ from django.conf import settings
from django.contrib.auth import logout as auth_logout from django.contrib.auth import logout as auth_logout
from django.contrib.auth import authenticate, login from django.contrib.auth import authenticate, login
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
from django.shortcuts import redirect, render from django.shortcuts import redirect
from urllib.parse import parse_qs, urlencode from urllib.parse import parse_qs, urlencode
from djangooidc.oidc import Client from djangooidc.oidc import Client
from djangooidc import exceptions as o_e from djangooidc import exceptions as o_e
from registrar.models import User from registrar.models import User
from registrar.views.utility.error_views import custom_500_error_view, custom_401_error_view
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,27 +50,19 @@ def error_page(request, error):
"""Display a sensible message and log the error.""" """Display a sensible message and log the error."""
logger.error(error) logger.error(error)
if isinstance(error, o_e.AuthenticationFailed): if isinstance(error, o_e.AuthenticationFailed):
return render(
request,
"401.html",
context={ context={
"friendly_message": error.friendly_message, "friendly_message": error.friendly_message,
"log_identifier": error.locator, "log_identifier": error.locator,
}, }
status=401, return custom_401_error_view(request, context)
)
if isinstance(error, o_e.InternalError): if isinstance(error, o_e.InternalError):
return render(
request,
"500.html",
context={ context={
"friendly_message": error.friendly_message, "friendly_message": error.friendly_message,
"log_identifier": error.locator, "log_identifier": error.locator,
}, }
status=500, return custom_500_error_view(request, context)
)
if isinstance(error, Exception): if isinstance(error, Exception):
return render(request, "500.html", status=500) return custom_500_error_view(request)
def openid(request): def openid(request):

View file

@ -3,7 +3,7 @@
For more information see: For more information see:
https://docs.djangoproject.com/en/4.0/topics/http/urls/ https://docs.djangoproject.com/en/4.0/topics/http/urls/
""" """
from django.conf.urls import handler500
from django.contrib import admin from django.contrib import admin
from django.urls import include, path from django.urls import include, path
from django.views.generic import RedirectView from django.views.generic import RedirectView
@ -149,6 +149,10 @@ urlpatterns = [
), ),
] ]
# Djangooidc strips out context data from that context, so we define a custom error
# view through this method.
handler500 = "registrar.views.utility.error_views.custom_500_error_view"
# we normally would guard these with `if settings.DEBUG` but tests run with # we normally would guard these with `if settings.DEBUG` but tests run with
# DEBUG = False even when these apps have been loaded because settings.DEBUG # DEBUG = False even when these apps have been loaded because settings.DEBUG
# was actually True. Instead, let's add these URLs any time we are able to # was actually True. Instead, let's add these URLs any time we are able to

View file

@ -325,6 +325,7 @@ class Domain(TimeStampedModel, DomainHelper):
Subordinate hosts (something.your-domain.gov) MUST have IP addresses, Subordinate hosts (something.your-domain.gov) MUST have IP addresses,
while non-subordinate hosts MUST NOT. while non-subordinate hosts MUST NOT.
""" """
raise ValueError("test")
try: try:
# attempt to retrieve hosts from registry and store in cache and db # attempt to retrieve hosts from registry and store in cache and db
hosts = self._get_property("hosts") hosts = self._get_property("hosts")

View file

@ -136,36 +136,3 @@ class TestEnvironmentVariablesEffects(TestCase):
self.assertNotContains(contact_page_500, "You are on a test site.") self.assertNotContains(contact_page_500, "You are on a test site.")
@less_console_noise_decorator
@override_settings(IS_PRODUCTION=False)
def test_non_production_environment_raises_403_and_shows_banner(self):
"""Test if the non-prod banner is shown when a 403 is raised"""
fake_domain, _ = Domain.objects.get_or_create(name="igorville.gov")
# Test navigating to the contact page. Should return a 403,
# but the banner should still appear.
contact_page_403 = self.client.get(
reverse("domain-dns-nameservers", kwargs={"pk": fake_domain.id}),
)
self.assertEqual(contact_page_403.status_code, 403)
self.assertContains(contact_page_403, "You are on a test site.", status_code=403)
@less_console_noise_decorator
@override_settings(IS_PRODUCTION=True)
def test_production_environment_raises_403_and_doesnt_show_banner(self):
"""Test if the non-prod banner is not shown on production when a 403 is raised"""
fake_domain, _ = Domain.objects.get_or_create(name="igorville.gov")
# Test navigating to the contact page. Should return a 403,
# but the banner should still appear.
contact_page_403 = self.client.get(
reverse("domain-dns-nameservers", kwargs={"pk": fake_domain.id}),
)
self.assertEqual(contact_page_403.status_code, 403)
self.assertNotContains(contact_page_403, "You are on a test site.", status_code=403)

View file

@ -0,0 +1,16 @@
"""Custom views that allow for error view customization"""
from django.shortcuts import render
def custom_500_error_view(request, context=None):
"""Used to redirect 500 errors to a custom view"""
if context is None:
return render(request, "500.html", status=500)
else:
return render(request, "500.html", context=context, status=500)
def custom_401_error_view(request, context=None):
"""Used to redirect 401 errors to a custom view"""
if context is None:
return render(request, "401.html", status=401)
else:
return render(request, "401.html", context=context, status=401)

View file

@ -32,16 +32,6 @@ class DomainPermissionView(DomainPermission, DetailView, abc.ABC):
# variable name in template context for the model object # variable name in template context for the model object
context_object_name = "domain" context_object_name = "domain"
def dispatch(self, request, *args, **kwargs):
"""
Custom implementation of dispatch to ensure that 500 error pages (and others)
have access to the IS_PRODUCTION flag
"""
if "IS_PRODUCTION" not in request.session:
# Pass the production flag to the context
request.session["IS_PRODUCTION"] = settings.IS_PRODUCTION
return super().dispatch(request, *args, **kwargs)
# Adds context information for user permissions # Adds context information for user permissions
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs) context = super().get_context_data(**kwargs)