diff --git a/src/djangooidc/tests/__init__.py b/src/djangooidc/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/djangooidc/tests/common.py b/src/djangooidc/tests/common.py new file mode 100644 index 000000000..731a4efb9 --- /dev/null +++ b/src/djangooidc/tests/common.py @@ -0,0 +1,47 @@ +import os +import logging + +def get_handlers(): + """Obtain pointers to all StreamHandlers.""" + handlers = {} + + rootlogger = logging.getLogger() + for h in rootlogger.handlers: + if isinstance(h, logging.StreamHandler): + handlers[h.name] = h + + for logger in logging.Logger.manager.loggerDict.values(): + if not isinstance(logger, logging.PlaceHolder): + for h in logger.handlers: + if isinstance(h, logging.StreamHandler): + handlers[h.name] = h + + return handlers + +def dont_print_garbage(f): + """ + Decorator to place on tests to silence console logging. + + This is helpful on tests which trigger console messages + (such as errors) which are normal and expected. + + It can easily be removed to debug a failing test. + """ + def wrapper(*args, **kwargs): + restore = {} + handlers = get_handlers() + devnull = open(os.devnull, 'w') + + # redirect all the streams + for handler in handlers.values(): + prior = handler.setStream(devnull) + restore[handler.name] = prior + # run the test + result = f(*args, **kwargs) + # restore the streams + for handler in handlers.values(): + handler.setStream(restore[handler.name]) + + return result + + return wrapper diff --git a/src/djangooidc/tests/test_views.py b/src/djangooidc/tests/test_views.py new file mode 100644 index 000000000..9b9605ebb --- /dev/null +++ b/src/djangooidc/tests/test_views.py @@ -0,0 +1,121 @@ +from unittest.mock import patch, Mock + +from django.http import HttpResponse +from django.test import Client, TestCase +from django.urls import reverse + +from .common import dont_print_garbage + +@patch("djangooidc.views.CLIENT", autospec=True) +class ViewsTest(TestCase): + def setUp(self): + self.client = Client() + + def say_hi(*args): + return HttpResponse("Hi") + + def user_info(*args): + return { + "sub": "51234512345123", + "email": "test@example.com", + "first_name": "Testy", + "last_name": "Tester", + "phone": "814564000" + } + + def test_error_page(self, mock_client): + pass + + def test_openid_sets_next(self, mock_client): + # setup + callback_url = reverse("openid_login_callback") + # mock + mock_client.create_authn_request.side_effect = self.say_hi + # test + response = self.client.get(reverse("openid"), {"next": callback_url}) + # assert + session = mock_client.create_authn_request.call_args[0][0] + self.assertEqual(session["next"], callback_url) + self.assertEqual(response.status_code, 200) + self.assertContains(response, "Hi") + + @dont_print_garbage + def test_openid_raises(self, mock_client): + # mock + mock_client.create_authn_request.side_effect = Exception("Test") + # test + response = self.client.get(reverse("openid")) + # assert + self.assertEqual(response.status_code, 500) + self.assertTemplateUsed(response, "500.html") + self.assertIn("Server Error", response.content.decode("utf-8")) + + @dont_print_garbage + def test_login_callback_reads_next(self, mock_client): + # setup + session = self.client.session + session["next"] = reverse("logout") + session.save() + # mock + mock_client.callback.side_effect = self.user_info + # test + response = self.client.get(reverse("openid_login_callback")) + # assert + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, reverse("logout")) + + @patch("djangooidc.views.authenticate") + @dont_print_garbage + def test_login_callback_raises(self, mock_auth, mock_client): + # mock + mock_client.callback.side_effect = self.user_info + mock_auth.return_value = None + # test + response = self.client.get(reverse("openid_login_callback")) + # assert + self.assertEqual(response.status_code, 401) + self.assertTemplateUsed(response, "401.html") + self.assertIn("Unauthorized", response.content.decode("utf-8")) + + @dont_print_garbage + def test_logout_redirect_url(self, mock_client): + # setup + session = self.client.session + session["id_token_raw"] = "83450234852349" + session["state"] = "7534298229506" + session.save() + # mock + mock_client.callback.side_effect = self.user_info + mock_client.registration_response = { + "post_logout_redirect_uris": ["http://example.com/back"] + } + mock_client.provider_info = { + "end_session_endpoint": "http://example.com/log_me_out" + } + # test + response = self.client.get(reverse("logout")) + # assert + expected = "http://example.com/log_me_out?id_token_hint=83450234852349&state" \ + "=7534298229506&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2Fback" + actual = response.url + self.assertEqual(response.status_code, 302) + self.assertEqual(actual, expected) + + @patch("djangooidc.views.auth_logout") + @dont_print_garbage + def test_logout_always_logs_out(self, mock_logout, _): + # Without additional mocking, logout will always fail. + # Here we test that auth_logout is called regardless + self.client.get(reverse("logout")) + self.assertTrue(mock_logout.called) + + def test_logout_callback_redirects(self, _): + # setup + session = self.client.session + session["next"] = reverse("logout") + session.save() + # test + response = self.client.get(reverse("openid_logout_callback")) + # assert + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, reverse("logout")) \ No newline at end of file diff --git a/src/djangooidc/views.py b/src/djangooidc/views.py index eeadd4f31..04a50916b 100644 --- a/src/djangooidc/views.py +++ b/src/djangooidc/views.py @@ -116,5 +116,5 @@ def logout(request, next_page=None): def logout_callback(request): """Simple redirection view: after logout, redirect to `next`.""" - next = request.session["next"] if "next" in request.session.keys() else "/" + next = request.session.get("next", "/") return redirect(next) diff --git a/src/registrar/config/settings.py b/src/registrar/config/settings.py index 93293073c..d0de27def 100644 --- a/src/registrar/config/settings.py +++ b/src/registrar/config/settings.py @@ -17,6 +17,7 @@ $ docker-compose exec app python manage.py shell """ import environs +from sys import argv as sys_argv from base64 import b64decode from cfenv import AppEnv # type: ignore from pathlib import Path @@ -49,6 +50,8 @@ env_base_url = env.str("DJANGO_BASE_URL") secret_login_key = b64decode(secret("DJANGO_SECRET_LOGIN_KEY", "")) secret_key = secret("DJANGO_SECRET_KEY") +cli_testing_mode = True if "test" in sys_argv else False + # region: Basic Django Config-----------------------------------------------### # Build paths inside the project like this: BASE_DIR / "subdir". @@ -350,6 +353,12 @@ LOGGING = { "level": "INFO", "propagate": False, }, + # Django's runserver requests + "django.request": { + "handlers": ["django.server"], + "level": "INFO", + "propagate": False, + }, # OpenID Connect logger "oic": { "handlers": ["console"], @@ -366,6 +375,12 @@ LOGGING = { "level": "DEBUG", }, }, + # root logger catches anything, unless + # defined by a more specific logger + "root": { + "handlers": ["console"], + "level": "INFO" + } } # endregion