Merge pull request #1759 from cisagov/dk/1751-oidc-outages

Issue #1751: Handling of OIDC outages (during startup and during user logins)
This commit is contained in:
dave-kennedy-ecs 2024-02-13 10:42:50 -05:00 committed by GitHub
commit 7ec4b32f88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 236 additions and 61 deletions

View file

@ -4,13 +4,13 @@ from django.http import HttpResponse
from django.test import Client, TestCase, RequestFactory from django.test import Client, TestCase, RequestFactory
from django.urls import reverse from django.urls import reverse
from djangooidc.exceptions import NoStateDefined from djangooidc.exceptions import NoStateDefined, InternalError
from ..views import login_callback from ..views import login_callback
from .common import less_console_noise from .common import less_console_noise
@patch("djangooidc.views.CLIENT", autospec=True) @patch("djangooidc.views.CLIENT", new_callable=MagicMock)
class ViewsTest(TestCase): class ViewsTest(TestCase):
def setUp(self): def setUp(self):
self.client = Client() self.client = Client()
@ -35,113 +35,252 @@ class ViewsTest(TestCase):
pass pass
def test_openid_sets_next(self, mock_client): def test_openid_sets_next(self, mock_client):
"""Test that the openid method properly sets next in the session."""
with less_console_noise(): with less_console_noise():
# setup # SETUP
# set up the callback url that will be tested in assertions against
# session[next]
callback_url = reverse("openid_login_callback") callback_url = reverse("openid_login_callback")
# mock # MOCK
# when login is called, response from create_authn_request should
# be returned to user, so let's mock it and test it
mock_client.create_authn_request.side_effect = self.say_hi mock_client.create_authn_request.side_effect = self.say_hi
# in this case, we need to mock the get_default_acr_value so that
# openid method will execute properly, but the acr_value itself
# is not important for this test
mock_client.get_default_acr_value.side_effect = self.create_acr mock_client.get_default_acr_value.side_effect = self.create_acr
# test # TEST
# test the login url, passing a callback url
response = self.client.get(reverse("login"), {"next": callback_url}) response = self.client.get(reverse("login"), {"next": callback_url})
# assert # ASSERTIONS
session = mock_client.create_authn_request.call_args[0][0] session = mock_client.create_authn_request.call_args[0][0]
# assert the session[next] is set to the callback_url
self.assertEqual(session["next"], callback_url) self.assertEqual(session["next"], callback_url)
# assert that openid returned properly the response from
# create_authn_request
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, "Hi") self.assertContains(response, "Hi")
def test_openid_raises(self, mock_client): def test_openid_raises(self, mock_client):
"""Test that errors in openid raise 500 error for the user.
This test specifically tests for any exceptions that might be raised from
create_authn_request. This includes scenarios where CLIENT exists, but
is no longer functioning properly."""
with less_console_noise(): with less_console_noise():
# mock # MOCK
# when login is called, exception thrown from create_authn_request
# should present 500 error page to user
mock_client.create_authn_request.side_effect = Exception("Test") mock_client.create_authn_request.side_effect = Exception("Test")
# test # TEST
# test when login url is called
response = self.client.get(reverse("login")) response = self.client.get(reverse("login"))
# assert # ASSERTIONS
# assert that the 500 error page is raised
self.assertEqual(response.status_code, 500) self.assertEqual(response.status_code, 500)
self.assertTemplateUsed(response, "500.html") self.assertTemplateUsed(response, "500.html")
self.assertIn("Server error", response.content.decode("utf-8")) self.assertIn("Server error", response.content.decode("utf-8"))
def test_callback_with_no_session_state(self, mock_client): def test_openid_raises_when_client_is_none_and_cant_init(self, mock_client):
"""Test that errors in openid raise 500 error for the user.
This test specifically tests for the condition where the CLIENT
is None and the client initialization attempt raises an exception."""
with less_console_noise():
# MOCK
# mock that CLIENT is None
# mock that Client() raises an exception (by mocking _initialize_client)
# Patch CLIENT to None for this specific test
with patch("djangooidc.views.CLIENT", None):
# Patch _initialize_client() to raise an exception
with patch("djangooidc.views._initialize_client") as mock_init:
mock_init.side_effect = InternalError
# TEST
# test when login url is called
response = self.client.get(reverse("login"))
# ASSERTIONS
# assert that the 500 error page is raised
self.assertEqual(response.status_code, 500)
self.assertTemplateUsed(response, "500.html")
self.assertIn("Server error", response.content.decode("utf-8"))
def test_openid_initializes_client_and_calls_create_authn_request(self, mock_client):
"""Test that openid re-initializes the client when the client had not
been previously initiated."""
with less_console_noise():
# MOCK
# response from create_authn_request should
# be returned to user, so let's mock it and test it
mock_client.create_authn_request.side_effect = self.say_hi
# in this case, we need to mock the get_default_acr_value so that
# openid method will execute properly, but the acr_value itself
# is not important for this test
mock_client.get_default_acr_value.side_effect = self.create_acr
with patch("djangooidc.views._initialize_client") as mock_init_client:
with patch("djangooidc.views._client_is_none") as mock_client_is_none:
# mock the client to initially be None
mock_client_is_none.return_value = True
# TEST
# test when login url is called
response = self.client.get(reverse("login"))
# ASSERTIONS
# assert that _initialize_client was called
mock_init_client.assert_called_once()
# assert that the response is the mocked response from create_authn_request
self.assertEqual(response.status_code, 200)
self.assertContains(response, "Hi")
def test_login_callback_with_no_session_state(self, mock_client):
"""If the local session is None (ie the server restarted while user was logged out), """If the local session is None (ie the server restarted while user was logged out),
we do not throw an exception. Rather, we attempt to login again.""" we do not throw an exception. Rather, we attempt to login again."""
with less_console_noise(): with less_console_noise():
# mock # MOCK
# mock the acr_value to some string
# mock the callback function to raise the NoStateDefined Exception
mock_client.get_default_acr_value.side_effect = self.create_acr mock_client.get_default_acr_value.side_effect = self.create_acr
mock_client.callback.side_effect = NoStateDefined() mock_client.callback.side_effect = NoStateDefined()
# test # TEST
# test the login callback
response = self.client.get(reverse("openid_login_callback")) response = self.client.get(reverse("openid_login_callback"))
# assert # ASSERTIONS
# assert that the user is redirected to the start of the login process
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/") self.assertEqual(response.url, "/")
def test_login_callback_reads_next(self, mock_client): def test_login_callback_reads_next(self, mock_client):
"""If the next value is set in the session, test that login_callback returns
a redirect to the 'next' url."""
with less_console_noise(): with less_console_noise():
# setup # SETUP
session = self.client.session session = self.client.session
# set 'next' to the logout url
session["next"] = reverse("logout") session["next"] = reverse("logout")
session.save() session.save()
# mock # MOCK
# mock that callback returns user_info; this is the expected behavior
mock_client.callback.side_effect = self.user_info mock_client.callback.side_effect = self.user_info
# test # patch that the request does not require step up auth
with patch("djangooidc.views.requires_step_up_auth", return_value=False), less_console_noise(): # TEST
# test the login callback url
with patch("djangooidc.views._requires_step_up_auth", return_value=False):
response = self.client.get(reverse("openid_login_callback")) response = self.client.get(reverse("openid_login_callback"))
# assert # ASSERTIONS
# assert the redirect url is the same as the 'next' value set in session
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, reverse("logout")) self.assertEqual(response.url, reverse("logout"))
def test_login_callback_no_step_up_auth(self, mock_client): def test_login_callback_raises_when_client_is_none_and_cant_init(self, mock_client):
"""Walk through login_callback when requires_step_up_auth returns False """Test that errors in login_callback raise 500 error for the user.
and assert that we have a redirect to /""" This test specifically tests for the condition where the CLIENT
is None and the client initialization attempt raises an exception."""
with less_console_noise(): with less_console_noise():
# setup # MOCK
# mock that CLIENT is None
# mock that Client() raises an exception (by mocking _initialize_client)
# Patch CLIENT to None for this specific test
with patch("djangooidc.views.CLIENT", None):
# Patch _initialize_client() to raise an exception
with patch("djangooidc.views._initialize_client") as mock_init:
mock_init.side_effect = InternalError
# TEST
# test the login callback url
response = self.client.get(reverse("openid_login_callback"))
# ASSERTIONS
# assert that the 500 error page is raised
self.assertEqual(response.status_code, 500)
self.assertTemplateUsed(response, "500.html")
self.assertIn("Server error", response.content.decode("utf-8"))
def test_login_callback_initializes_client_and_succeeds(self, mock_client):
"""Test that openid re-initializes the client when the client had not
been previously initiated."""
with less_console_noise():
# SETUP
session = self.client.session session = self.client.session
session.save() session.save()
# mock # MOCK
# mock that callback returns user_info; this is the expected behavior
mock_client.callback.side_effect = self.user_info mock_client.callback.side_effect = self.user_info
# test # patch that the request does not require step up auth
with patch("djangooidc.views.requires_step_up_auth", return_value=False), less_console_noise(): with patch("djangooidc.views._requires_step_up_auth", return_value=False):
with patch("djangooidc.views._initialize_client") as mock_init_client:
with patch("djangooidc.views._client_is_none") as mock_client_is_none:
# mock the client to initially be None
mock_client_is_none.return_value = True
# TEST
# test the login callback url
response = self.client.get(reverse("openid_login_callback"))
# ASSERTIONS
# assert that _initialize_client was called
mock_init_client.assert_called_once()
# assert that redirect is to / when no 'next' is set
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/")
def test_login_callback_no_step_up_auth(self, mock_client):
"""Walk through login_callback when _requires_step_up_auth returns False
and assert that we have a redirect to /"""
with less_console_noise():
# SETUP
session = self.client.session
session.save()
# MOCK
# mock that callback returns user_info; this is the expected behavior
mock_client.callback.side_effect = self.user_info
# patch that the request does not require step up auth
# TEST
# test the login callback url
with patch("djangooidc.views._requires_step_up_auth", return_value=False):
response = self.client.get(reverse("openid_login_callback")) response = self.client.get(reverse("openid_login_callback"))
# assert # ASSERTIONS
# assert that redirect is to / when no 'next' is set
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/") self.assertEqual(response.url, "/")
def test_requires_step_up_auth(self, mock_client): def test_login_callback_requires_step_up_auth(self, mock_client):
"""Invoke login_callback passing it a request when requires_step_up_auth returns True """Invoke login_callback passing it a request when _requires_step_up_auth returns True
and assert that session is updated and create_authn_request (mock) is called.""" and assert that session is updated and create_authn_request (mock) is called."""
with less_console_noise(): with less_console_noise():
# MOCK
# Configure the mock to return an expected value for get_step_up_acr_value # Configure the mock to return an expected value for get_step_up_acr_value
mock_client.return_value.get_step_up_acr_value.return_value = "step_up_acr_value" mock_client.return_value.get_step_up_acr_value.return_value = "step_up_acr_value"
# Create a mock request # Create a mock request
request = self.factory.get("/some-url") request = self.factory.get("/some-url")
request.session = {"acr_value": ""} request.session = {"acr_value": ""}
# Ensure that the CLIENT instance used in login_callback is the mock # Ensure that the CLIENT instance used in login_callback is the mock
# patch requires_step_up_auth to return True # patch _requires_step_up_auth to return True
with patch("djangooidc.views.requires_step_up_auth", return_value=True), patch( with patch("djangooidc.views._requires_step_up_auth", return_value=True), patch(
"djangooidc.views.CLIENT.create_authn_request", return_value=MagicMock() "djangooidc.views.CLIENT.create_authn_request", return_value=MagicMock()
) as mock_create_authn_request: ) as mock_create_authn_request:
# TEST
# test the login callback
login_callback(request) login_callback(request)
# create_authn_request only gets called when requires_step_up_auth is True # ASSERTIONS
# create_authn_request only gets called when _requires_step_up_auth is True
# and it changes this acr_value in request.session # and it changes this acr_value in request.session
# Assert that acr_value is no longer empty string # Assert that acr_value is no longer empty string
self.assertNotEqual(request.session["acr_value"], "") self.assertNotEqual(request.session["acr_value"], "")
# And create_authn_request was called again # And create_authn_request was called again
mock_create_authn_request.assert_called_once() mock_create_authn_request.assert_called_once()
def test_does_not_requires_step_up_auth(self, mock_client): def test_login_callback_does_not_requires_step_up_auth(self, mock_client):
"""Invoke login_callback passing it a request when requires_step_up_auth returns False """Invoke login_callback passing it a request when _requires_step_up_auth returns False
and assert that session is not updated and create_authn_request (mock) is not called. and assert that session is not updated and create_authn_request (mock) is not called.
Possibly redundant with test_login_callback_requires_step_up_auth""" Possibly redundant with test_login_callback_requires_step_up_auth"""
with less_console_noise(): with less_console_noise():
# MOCK
# Create a mock request # Create a mock request
request = self.factory.get("/some-url") request = self.factory.get("/some-url")
request.session = {"acr_value": ""} request.session = {"acr_value": ""}
# Ensure that the CLIENT instance used in login_callback is the mock # Ensure that the CLIENT instance used in login_callback is the mock
# patch requires_step_up_auth to return False # patch _requires_step_up_auth to return False
with patch("djangooidc.views.requires_step_up_auth", return_value=False), patch( with patch("djangooidc.views._requires_step_up_auth", return_value=False), patch(
"djangooidc.views.CLIENT.create_authn_request", return_value=MagicMock() "djangooidc.views.CLIENT.create_authn_request", return_value=MagicMock()
) as mock_create_authn_request: ) as mock_create_authn_request:
# TEST
# test the login callback
login_callback(request) login_callback(request)
# create_authn_request only gets called when requires_step_up_auth is True # ASSERTIONS
# create_authn_request only gets called when _requires_step_up_auth is True
# and it changes this acr_value in request.session # and it changes this acr_value in request.session
# Assert that acr_value is NOT updated by testing that it is still an empty string # Assert that acr_value is NOT updated by testing that it is still an empty string
self.assertEqual(request.session["acr_value"], "") self.assertEqual(request.session["acr_value"], "")
@ -150,33 +289,36 @@ class ViewsTest(TestCase):
@patch("djangooidc.views.authenticate") @patch("djangooidc.views.authenticate")
def test_login_callback_raises(self, mock_auth, mock_client): def test_login_callback_raises(self, mock_auth, mock_client):
"""Test that login callback raises a 401 when user is unauthorized"""
with less_console_noise(): with less_console_noise():
# mock # MOCK
# mock that callback returns user_info; this is the expected behavior
mock_client.callback.side_effect = self.user_info mock_client.callback.side_effect = self.user_info
mock_auth.return_value = None mock_auth.return_value = None
# test # TEST
with patch("djangooidc.views.requires_step_up_auth", return_value=False), less_console_noise(): with patch("djangooidc.views._requires_step_up_auth", return_value=False):
response = self.client.get(reverse("openid_login_callback")) response = self.client.get(reverse("openid_login_callback"))
# assert # ASSERTIONS
self.assertEqual(response.status_code, 401) self.assertEqual(response.status_code, 401)
self.assertTemplateUsed(response, "401.html") self.assertTemplateUsed(response, "401.html")
self.assertIn("Unauthorized", response.content.decode("utf-8")) self.assertIn("Unauthorized", response.content.decode("utf-8"))
def test_logout_redirect_url(self, mock_client): def test_logout_redirect_url(self, mock_client):
"""Test that logout redirects to the configured post_logout_redirect_uris."""
with less_console_noise(): with less_console_noise():
# setup # SETUP
session = self.client.session session = self.client.session
session["state"] = "TEST" # nosec B105 session["state"] = "TEST" # nosec B105
session.save() session.save()
# mock # MOCK
mock_client.callback.side_effect = self.user_info mock_client.callback.side_effect = self.user_info
mock_client.registration_response = {"post_logout_redirect_uris": ["http://example.com/back"]} mock_client.registration_response = {"post_logout_redirect_uris": ["http://example.com/back"]}
mock_client.provider_info = {"end_session_endpoint": "http://example.com/log_me_out"} mock_client.provider_info = {"end_session_endpoint": "http://example.com/log_me_out"}
mock_client.client_id = "TEST" mock_client.client_id = "TEST"
# test # TEST
with less_console_noise(): with less_console_noise():
response = self.client.get(reverse("logout")) response = self.client.get(reverse("logout"))
# assert # ASSERTIONS
expected = ( expected = (
"http://example.com/log_me_out?client_id=TEST&state" "http://example.com/log_me_out?client_id=TEST&state"
"=TEST&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2Fback" "=TEST&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2Fback"
@ -187,20 +329,23 @@ class ViewsTest(TestCase):
@patch("djangooidc.views.auth_logout") @patch("djangooidc.views.auth_logout")
def test_logout_always_logs_out(self, mock_logout, _): def test_logout_always_logs_out(self, mock_logout, _):
# Without additional mocking, logout will always fail. """Without additional mocking, logout will always fail.
# Here we test that auth_logout is called regardless Here we test that auth_logout is called regardless"""
# TEST
with less_console_noise(): with less_console_noise():
self.client.get(reverse("logout")) self.client.get(reverse("logout"))
# ASSERTIONS
self.assertTrue(mock_logout.called) self.assertTrue(mock_logout.called)
def test_logout_callback_redirects(self, _): def test_logout_callback_redirects(self, _):
"""Test that the logout_callback redirects properly"""
with less_console_noise(): with less_console_noise():
# setup # SETUP
session = self.client.session session = self.client.session
session["next"] = reverse("logout") session["next"] = reverse("logout")
session.save() session.save()
# test # TEST
response = self.client.get(reverse("openid_logout_callback")) response = self.client.get(reverse("openid_logout_callback"))
# assert # ASSERTIONS
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, reverse("logout")) self.assertEqual(response.url, reverse("logout"))

View file

@ -15,15 +15,34 @@ from registrar.models import User
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: CLIENT = None
def _initialize_client():
"""Initialize the OIDC client. Exceptions are allowed to raise
and will need to be caught."""
global CLIENT
# Initialize provider using pyOICD # Initialize provider using pyOICD
OP = getattr(settings, "OIDC_ACTIVE_PROVIDER") OP = getattr(settings, "OIDC_ACTIVE_PROVIDER")
CLIENT = Client(OP) CLIENT = Client(OP)
logger.debug("client initialized %s" % CLIENT) logger.debug("Client initialized: %s" % CLIENT)
def _client_is_none():
"""Return if the CLIENT is currently None."""
global CLIENT
return CLIENT is None
# Initialize CLIENT
try:
_initialize_client()
except Exception as err: except Exception as err:
CLIENT = None # type: ignore # In the event of an exception, log the error and allow the app load to continue
logger.warning(err) # without the OIDC Client. Subsequent login attempts will attempt to initialize
logger.warning("Unable to configure OpenID Connect provider. Users cannot log in.") # again if Client is None
logger.error(err)
logger.error("Unable to configure OpenID Connect provider. Users cannot log in.")
def error_page(request, error): def error_page(request, error):
@ -55,12 +74,15 @@ def error_page(request, error):
def openid(request): def openid(request):
"""Redirect the user to an authentication provider (OP).""" """Redirect the user to an authentication provider (OP)."""
# If the session reset because of a server restart, attempt to login again global CLIENT
request.session["acr_value"] = CLIENT.get_default_acr_value()
request.session["next"] = request.GET.get("next", "/")
try: try:
# If the CLIENT is none, attempt to reinitialize before handling the request
if _client_is_none():
logger.debug("OIDC client is None, attempting to initialize")
_initialize_client()
request.session["acr_value"] = CLIENT.get_default_acr_value()
request.session["next"] = request.GET.get("next", "/")
# Create the authentication request
return CLIENT.create_authn_request(request.session) return CLIENT.create_authn_request(request.session)
except Exception as err: except Exception as err:
return error_page(request, err) return error_page(request, err)
@ -68,12 +90,17 @@ def openid(request):
def login_callback(request): def login_callback(request):
"""Analyze the token returned by the authentication provider (OP).""" """Analyze the token returned by the authentication provider (OP)."""
global CLIENT
try: try:
# If the CLIENT is none, attempt to reinitialize before handling the request
if _client_is_none():
logger.debug("OIDC client is None, attempting to initialize")
_initialize_client()
query = parse_qs(request.GET.urlencode()) query = parse_qs(request.GET.urlencode())
userinfo = CLIENT.callback(query, request.session) userinfo = CLIENT.callback(query, request.session)
# test for need for identity verification and if it is satisfied # test for need for identity verification and if it is satisfied
# if not satisfied, redirect user to login with stepped up acr_value # if not satisfied, redirect user to login with stepped up acr_value
if requires_step_up_auth(userinfo): if _requires_step_up_auth(userinfo):
# add acr_value to request.session # add acr_value to request.session
request.session["acr_value"] = CLIENT.get_step_up_acr_value() request.session["acr_value"] = CLIENT.get_step_up_acr_value()
return CLIENT.create_authn_request(request.session) return CLIENT.create_authn_request(request.session)
@ -86,13 +113,16 @@ def login_callback(request):
else: else:
raise o_e.BannedUser() raise o_e.BannedUser()
except o_e.NoStateDefined as nsd_err: except o_e.NoStateDefined as nsd_err:
# In the event that a user is in the middle of a login when the app is restarted,
# their session state will no longer be available, so redirect the user to the
# beginning of login process without raising an error to the user.
logger.warning(f"No State Defined: {nsd_err}") logger.warning(f"No State Defined: {nsd_err}")
return redirect(request.session.get("next", "/")) return redirect(request.session.get("next", "/"))
except Exception as err: except Exception as err:
return error_page(request, err) return error_page(request, err)
def requires_step_up_auth(userinfo): def _requires_step_up_auth(userinfo):
"""if User.needs_identity_verification and step_up_acr_value not in """if User.needs_identity_verification and step_up_acr_value not in
ial returned from callback, return True""" ial returned from callback, return True"""
step_up_acr_value = CLIENT.get_step_up_acr_value() step_up_acr_value = CLIENT.get_step_up_acr_value()