Address PR feedback

This commit is contained in:
Seamus Johnston 2022-09-30 14:14:43 -05:00
parent b5ea6c8e39
commit 828051854e
No known key found for this signature in database
GPG key ID: 2F21225985069105
4 changed files with 26 additions and 28 deletions

View file

@ -1,6 +1,7 @@
import os
import logging
from contextlib import contextmanager
def get_handlers():
"""Obtain pointers to all StreamHandlers."""
@ -20,31 +21,29 @@ def get_handlers():
return handlers
def dont_print_garbage(f):
@contextmanager
def less_console_noise():
"""
Decorator to place on tests to silence console logging.
Context manager to use in tests to silence console logging.
This is helpful on tests which trigger console messages
(such as errors) which are normal and expected.
It can easily be removed to debug a failing test.
"""
restore = {}
handlers = get_handlers()
devnull = open(os.devnull, "w")
def wrapper(*args, **kwargs):
restore = {}
handlers = get_handlers()
devnull = open(os.devnull, "w")
# redirect all the streams
for handler in handlers.values():
prior = handler.setStream(devnull)
restore[handler.name] = prior
# redirect all the streams
for handler in handlers.values():
prior = handler.setStream(devnull)
restore[handler.name] = prior
try:
# run the test
result = f(*args, **kwargs)
yield
finally:
# restore the streams
for handler in handlers.values():
handler.setStream(restore[handler.name])
return result
return wrapper

View file

@ -4,7 +4,7 @@ from django.http import HttpResponse
from django.test import Client, TestCase
from django.urls import reverse
from .common import dont_print_garbage
from .common import less_console_noise
@patch("djangooidc.views.CLIENT", autospec=True)
@ -40,18 +40,17 @@ class ViewsTest(TestCase):
self.assertEqual(response.status_code, 200)
self.assertContains(response, "Hi")
@dont_print_garbage
def test_openid_raises(self, mock_client):
# mock
mock_client.create_authn_request.side_effect = Exception("Test")
# test
response = self.client.get(reverse("openid"))
with less_console_noise():
response = self.client.get(reverse("openid"))
# assert
self.assertEqual(response.status_code, 500)
self.assertTemplateUsed(response, "500.html")
self.assertIn("Server Error", response.content.decode("utf-8"))
@dont_print_garbage
def test_login_callback_reads_next(self, mock_client):
# setup
session = self.client.session
@ -60,25 +59,25 @@ class ViewsTest(TestCase):
# mock
mock_client.callback.side_effect = self.user_info
# test
response = self.client.get(reverse("openid_login_callback"))
with less_console_noise():
response = self.client.get(reverse("openid_login_callback"))
# assert
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, reverse("logout"))
@patch("djangooidc.views.authenticate")
@dont_print_garbage
def test_login_callback_raises(self, mock_auth, mock_client):
# mock
mock_client.callback.side_effect = self.user_info
mock_auth.return_value = None
# test
response = self.client.get(reverse("openid_login_callback"))
with less_console_noise():
response = self.client.get(reverse("openid_login_callback"))
# assert
self.assertEqual(response.status_code, 401)
self.assertTemplateUsed(response, "401.html")
self.assertIn("Unauthorized", response.content.decode("utf-8"))
@dont_print_garbage
def test_logout_redirect_url(self, mock_client):
# setup
session = self.client.session
@ -94,7 +93,8 @@ class ViewsTest(TestCase):
"end_session_endpoint": "http://example.com/log_me_out"
}
# test
response = self.client.get(reverse("logout"))
with less_console_noise():
response = self.client.get(reverse("logout"))
# assert
expected = (
"http://example.com/log_me_out?id_token_hint=TEST&state"
@ -105,11 +105,11 @@ class ViewsTest(TestCase):
self.assertEqual(actual, expected)
@patch("djangooidc.views.auth_logout")
@dont_print_garbage
def test_logout_always_logs_out(self, mock_logout, _):
# Without additional mocking, logout will always fail.
# Here we test that auth_logout is called regardless
self.client.get(reverse("logout"))
with less_console_noise():
self.client.get(reverse("logout"))
self.assertTrue(mock_logout.called)
def test_logout_callback_redirects(self, _):