diff --git a/src/api/tests/common.py b/src/api/tests/common.py index 122965ae8..1a8c32526 100644 --- a/src/api/tests/common.py +++ b/src/api/tests/common.py @@ -49,3 +49,17 @@ def less_console_noise(): handler.setStream(restore[handler.name]) # close the file we opened devnull.close() + + +def less_console_noise_decorator(func): + """ + Decorator to silence console logging using the less_console_noise() function. + """ + + # "Wrap" the original function in the less_console_noise with clause, + # then just return this wrapper. + def wrapper(*args, **kwargs): + with less_console_noise(): + return func(*args, **kwargs) + + return wrapper diff --git a/src/djangooidc/views.py b/src/djangooidc/views.py index 2d3c842d2..8e112769b 100644 --- a/src/djangooidc/views.py +++ b/src/djangooidc/views.py @@ -6,12 +6,13 @@ from django.conf import settings from django.contrib.auth import logout as auth_logout from django.contrib.auth import authenticate, login from django.http import HttpResponseRedirect -from django.shortcuts import redirect, render +from django.shortcuts import redirect from urllib.parse import parse_qs, urlencode from djangooidc.oidc import Client from djangooidc import exceptions as o_e from registrar.models import User +from registrar.views.utility.error_views import custom_500_error_view, custom_401_error_view logger = logging.getLogger(__name__) @@ -49,27 +50,19 @@ def error_page(request, error): """Display a sensible message and log the error.""" logger.error(error) if isinstance(error, o_e.AuthenticationFailed): - return render( - request, - "401.html", - context={ - "friendly_message": error.friendly_message, - "log_identifier": error.locator, - }, - status=401, - ) + context = { + "friendly_message": error.friendly_message, + "log_identifier": error.locator, + } + return custom_401_error_view(request, context) if isinstance(error, o_e.InternalError): - return render( - request, - "500.html", - context={ - "friendly_message": error.friendly_message, - "log_identifier": error.locator, - }, - status=500, - ) + context = { + "friendly_message": error.friendly_message, + "log_identifier": error.locator, + } + return custom_500_error_view(request, context) if isinstance(error, Exception): - return render(request, "500.html", status=500) + return custom_500_error_view(request) def openid(request): diff --git a/src/registrar/config/urls.py b/src/registrar/config/urls.py index 9049d718c..c743aed0c 100644 --- a/src/registrar/config/urls.py +++ b/src/registrar/config/urls.py @@ -149,6 +149,18 @@ urlpatterns = [ ), ] +# Djangooidc strips out context data from that context, so we define a custom error +# view through this method. +# If Djangooidc is left to its own devices and uses reverse directly, +# then both context and session information will be obliterated due to: + +# a) Djangooidc being out of scope for context_processors +# b) Potential cyclical import errors restricting what kind of data is passable. + +# Rather than dealing with that, we keep everything centralized in one location. +# This way, we can share a view for djangooidc, and other pages as we see fit. +handler500 = "registrar.views.utility.error_views.custom_500_error_view" + # we normally would guard these with `if settings.DEBUG` but tests run with # DEBUG = False even when these apps have been loaded because settings.DEBUG # was actually True. Instead, let's add these URLs any time we are able to diff --git a/src/registrar/tests/test_views.py b/src/registrar/tests/test_views.py index eec12e463..b8055f288 100644 --- a/src/registrar/tests/test_views.py +++ b/src/registrar/tests/test_views.py @@ -1,8 +1,14 @@ from django.test import Client, TestCase, override_settings from django.contrib.auth import get_user_model -from .common import MockEppLib # type: ignore +from api.tests.common import less_console_noise_decorator +from registrar.models.domain import Domain +from registrar.models.user_domain_role import UserDomainRole +from registrar.views.domain import DomainNameserversView +from .common import MockEppLib # type: ignore +from unittest.mock import patch +from django.urls import reverse from registrar.models import ( DomainRequest, @@ -66,6 +72,7 @@ class TestEnvironmentVariablesEffects(TestCase): def tearDown(self): super().tearDown() + Domain.objects.all().delete() self.user.delete() @override_settings(IS_PRODUCTION=True) @@ -79,3 +86,52 @@ class TestEnvironmentVariablesEffects(TestCase): """Banner on non-prod.""" home_page = self.client.get("/") self.assertContains(home_page, "You are on a test site.") + + def side_effect_raise_value_error(self): + """Side effect that raises a 500 error""" + raise ValueError("Some error") + + @less_console_noise_decorator + @override_settings(IS_PRODUCTION=False) + def test_non_production_environment_raises_500_and_shows_banner(self): + """Tests if the non-prod banner is still shown on a 500""" + fake_domain, _ = Domain.objects.get_or_create(name="igorville.gov") + + # Add a role + fake_role, _ = UserDomainRole.objects.get_or_create( + user=self.user, domain=fake_domain, role=UserDomainRole.Roles.MANAGER + ) + + with patch.object(DomainNameserversView, "get_initial", side_effect=self.side_effect_raise_value_error): + with self.assertRaises(ValueError): + contact_page_500 = self.client.get( + reverse("domain-dns-nameservers", kwargs={"pk": fake_domain.id}), + ) + + # Check that a 500 response is returned + self.assertEqual(contact_page_500.status_code, 500) + + self.assertContains(contact_page_500, "You are on a test site.") + + @less_console_noise_decorator + @override_settings(IS_PRODUCTION=True) + def test_production_environment_raises_500_and_doesnt_show_banner(self): + """Test if the non-prod banner is not shown on production when a 500 is raised""" + + fake_domain, _ = Domain.objects.get_or_create(name="igorville.gov") + + # Add a role + fake_role, _ = UserDomainRole.objects.get_or_create( + user=self.user, domain=fake_domain, role=UserDomainRole.Roles.MANAGER + ) + + with patch.object(DomainNameserversView, "get_initial", side_effect=self.side_effect_raise_value_error): + with self.assertRaises(ValueError): + contact_page_500 = self.client.get( + reverse("domain-dns-nameservers", kwargs={"pk": fake_domain.id}), + ) + + # Check that a 500 response is returned + self.assertEqual(contact_page_500.status_code, 500) + + self.assertNotContains(contact_page_500, "You are on a test site.") diff --git a/src/registrar/views/utility/error_views.py b/src/registrar/views/utility/error_views.py new file mode 100644 index 000000000..48ae628a4 --- /dev/null +++ b/src/registrar/views/utility/error_views.py @@ -0,0 +1,32 @@ +""" +Custom views that allow for error view customization. + +Used as a general handler for 500 errors both coming from the registrar app, but +also the djangooidc app. + +If Djangooidc is left to its own devices and uses reverse directly, +then both context and session information will be obliterated due to: + +a) Djangooidc being out of scope for context_processors +b) Potential cyclical import errors restricting what kind of data is passable. + +Rather than dealing with that, we keep everything centralized in one location. +""" + +from django.shortcuts import render + + +def custom_500_error_view(request, context=None): + """Used to redirect 500 errors to a custom view""" + if context is None: + return render(request, "500.html", status=500) + else: + return render(request, "500.html", context=context, status=500) + + +def custom_401_error_view(request, context=None): + """Used to redirect 401 errors to a custom view""" + if context is None: + return render(request, "401.html", status=401) + else: + return render(request, "401.html", context=context, status=401)