mirror of
https://github.com/cisagov/manage.get.gov.git
synced 2025-07-23 03:06:01 +02:00
Merge branch 'main' of https://github.com/cisagov/manage.get.gov into rh/3363-uat-1-bug-fixes
This commit is contained in:
commit
854c5ef169
4 changed files with 108 additions and 36 deletions
|
@ -21,49 +21,66 @@ class OpenIdConnectBackend(ModelBackend):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def authenticate(self, request, **kwargs):
|
def authenticate(self, request, **kwargs):
|
||||||
logger.debug("kwargs %s" % kwargs)
|
logger.debug("kwargs %s", kwargs)
|
||||||
user = None
|
|
||||||
if not kwargs or "sub" not in kwargs.keys():
|
if not kwargs or "sub" not in kwargs:
|
||||||
return user
|
return None
|
||||||
|
|
||||||
UserModel = get_user_model()
|
UserModel = get_user_model()
|
||||||
username = self.clean_username(kwargs["sub"])
|
username = self.clean_username(kwargs["sub"])
|
||||||
|
openid_data = self.extract_openid_data(kwargs)
|
||||||
|
|
||||||
# Some OP may actually choose to withhold some information, so we must
|
|
||||||
# test if it is present
|
|
||||||
openid_data = {"last_login": timezone.now()}
|
|
||||||
openid_data["first_name"] = kwargs.get("given_name", "")
|
|
||||||
openid_data["last_name"] = kwargs.get("family_name", "")
|
|
||||||
openid_data["email"] = kwargs.get("email", "")
|
|
||||||
openid_data["phone"] = kwargs.get("phone", "")
|
|
||||||
|
|
||||||
# Note that this could be accomplished in one try-except clause, but
|
|
||||||
# instead we use get_or_create when creating unknown users since it has
|
|
||||||
# built-in safeguards for multiple threads.
|
|
||||||
if getattr(settings, "OIDC_CREATE_UNKNOWN_USER", True):
|
if getattr(settings, "OIDC_CREATE_UNKNOWN_USER", True):
|
||||||
args = {
|
user = self.get_or_create_user(UserModel, username, openid_data, kwargs)
|
||||||
UserModel.USERNAME_FIELD: username,
|
|
||||||
# defaults _will_ be updated, these are not fallbacks
|
|
||||||
"defaults": openid_data,
|
|
||||||
}
|
|
||||||
|
|
||||||
user, created = UserModel.objects.get_or_create(**args)
|
|
||||||
|
|
||||||
if not created:
|
|
||||||
# If user exists, update existing user
|
|
||||||
self.update_existing_user(user, args["defaults"])
|
|
||||||
else:
|
|
||||||
# If user is created, configure the user
|
|
||||||
user = self.configure_user(user, **kwargs)
|
|
||||||
else:
|
else:
|
||||||
try:
|
user = self.get_user_by_username(UserModel, username)
|
||||||
user = UserModel.objects.get_by_natural_key(username)
|
|
||||||
except UserModel.DoesNotExist:
|
if user:
|
||||||
return None
|
user.on_each_login()
|
||||||
# run this callback for a each login
|
|
||||||
user.on_each_login()
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
def extract_openid_data(self, kwargs):
|
||||||
|
"""Extract OpenID data from authentication kwargs."""
|
||||||
|
return {
|
||||||
|
"last_login": timezone.now(),
|
||||||
|
"first_name": kwargs.get("given_name", ""),
|
||||||
|
"last_name": kwargs.get("family_name", ""),
|
||||||
|
"email": kwargs.get("email", ""),
|
||||||
|
"phone": kwargs.get("phone", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_or_create_user(self, UserModel, username, openid_data, kwargs):
|
||||||
|
"""Retrieve user by username or email, or create a new user."""
|
||||||
|
user = self.get_user_by_username(UserModel, username)
|
||||||
|
|
||||||
|
if not user and openid_data["email"]:
|
||||||
|
user = self.get_user_by_email(UserModel, openid_data["email"])
|
||||||
|
if user:
|
||||||
|
# if found by email, update the username
|
||||||
|
setattr(user, UserModel.USERNAME_FIELD, username)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
user = UserModel.objects.create(**{UserModel.USERNAME_FIELD: username}, **openid_data)
|
||||||
|
return self.configure_user(user, **kwargs)
|
||||||
|
|
||||||
|
self.update_existing_user(user, openid_data)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def get_user_by_username(self, UserModel, username):
|
||||||
|
"""Retrieve user by username."""
|
||||||
|
try:
|
||||||
|
return UserModel.objects.get(**{UserModel.USERNAME_FIELD: username})
|
||||||
|
except UserModel.DoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_user_by_email(self, UserModel, email):
|
||||||
|
"""Retrieve user by email."""
|
||||||
|
try:
|
||||||
|
return UserModel.objects.get(email=email)
|
||||||
|
except UserModel.DoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
def update_existing_user(self, user, kwargs):
|
def update_existing_user(self, user, kwargs):
|
||||||
"""
|
"""
|
||||||
Update user fields without overwriting certain fields.
|
Update user fields without overwriting certain fields.
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from registrar.models import User
|
from registrar.models import User
|
||||||
|
from api.tests.common import less_console_noise_decorator
|
||||||
from ..backends import OpenIdConnectBackend # Adjust the import path based on your project structure
|
from ..backends import OpenIdConnectBackend # Adjust the import path based on your project structure
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ class OpenIdConnectBackendTestCase(TestCase):
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
User.objects.all().delete()
|
User.objects.all().delete()
|
||||||
|
|
||||||
|
@less_console_noise_decorator
|
||||||
def test_authenticate_with_create_user(self):
|
def test_authenticate_with_create_user(self):
|
||||||
"""Test that authenticate creates a new user if it does not find
|
"""Test that authenticate creates a new user if it does not find
|
||||||
existing user"""
|
existing user"""
|
||||||
|
@ -32,6 +34,7 @@ class OpenIdConnectBackendTestCase(TestCase):
|
||||||
self.assertEqual(user.email, "john.doe@example.com")
|
self.assertEqual(user.email, "john.doe@example.com")
|
||||||
self.assertEqual(user.phone, "123456789")
|
self.assertEqual(user.phone, "123456789")
|
||||||
|
|
||||||
|
@less_console_noise_decorator
|
||||||
def test_authenticate_with_existing_user(self):
|
def test_authenticate_with_existing_user(self):
|
||||||
"""Test that authenticate updates an existing user if it finds one.
|
"""Test that authenticate updates an existing user if it finds one.
|
||||||
For this test, given_name and family_name are supplied"""
|
For this test, given_name and family_name are supplied"""
|
||||||
|
@ -50,6 +53,30 @@ class OpenIdConnectBackendTestCase(TestCase):
|
||||||
self.assertEqual(user.email, "john.doe@example.com")
|
self.assertEqual(user.email, "john.doe@example.com")
|
||||||
self.assertEqual(user.phone, "123456789")
|
self.assertEqual(user.phone, "123456789")
|
||||||
|
|
||||||
|
@less_console_noise_decorator
|
||||||
|
def test_authenticate_with_existing_user_same_email_different_username(self):
|
||||||
|
"""Test that authenticate updates an existing user if it finds one.
|
||||||
|
In this case, match is to an existing record with matching email but
|
||||||
|
a non-matching username. The existing record's username should be udpated.
|
||||||
|
For this test, given_name and family_name are supplied"""
|
||||||
|
# Create an existing user with the same username
|
||||||
|
User.objects.create_user(username="old_username", email="john.doe@example.com")
|
||||||
|
|
||||||
|
# Ensure that the authenticate method updates the existing user
|
||||||
|
user = self.backend.authenticate(request=None, **self.kwargs)
|
||||||
|
self.assertIsNotNone(user)
|
||||||
|
self.assertIsInstance(user, User)
|
||||||
|
|
||||||
|
# Verify that user fields are correctly updated
|
||||||
|
self.assertEqual(user.first_name, "John")
|
||||||
|
self.assertEqual(user.last_name, "Doe")
|
||||||
|
self.assertEqual(user.email, "john.doe@example.com")
|
||||||
|
self.assertEqual(user.phone, "123456789")
|
||||||
|
self.assertEqual(user.username, "test_user")
|
||||||
|
# Assert that a user no longer exists by the old username
|
||||||
|
self.assertFalse(User.objects.filter(username="old_username").exists())
|
||||||
|
|
||||||
|
@less_console_noise_decorator
|
||||||
def test_authenticate_with_existing_user_with_existing_first_last_phone(self):
|
def test_authenticate_with_existing_user_with_existing_first_last_phone(self):
|
||||||
"""Test that authenticate updates an existing user if it finds one.
|
"""Test that authenticate updates an existing user if it finds one.
|
||||||
For this test, given_name and family_name are not supplied.
|
For this test, given_name and family_name are not supplied.
|
||||||
|
@ -79,6 +106,7 @@ class OpenIdConnectBackendTestCase(TestCase):
|
||||||
self.assertEqual(user.email, "john.doe@example.com")
|
self.assertEqual(user.email, "john.doe@example.com")
|
||||||
self.assertEqual(user.phone, "9999999999")
|
self.assertEqual(user.phone, "9999999999")
|
||||||
|
|
||||||
|
@less_console_noise_decorator
|
||||||
def test_authenticate_with_existing_user_different_name_phone(self):
|
def test_authenticate_with_existing_user_different_name_phone(self):
|
||||||
"""Test that authenticate updates an existing user if it finds one.
|
"""Test that authenticate updates an existing user if it finds one.
|
||||||
For this test, given_name and family_name are supplied and overwrite"""
|
For this test, given_name and family_name are supplied and overwrite"""
|
||||||
|
@ -100,6 +128,7 @@ class OpenIdConnectBackendTestCase(TestCase):
|
||||||
self.assertEqual(user.email, "john.doe@example.com")
|
self.assertEqual(user.email, "john.doe@example.com")
|
||||||
self.assertEqual(user.phone, "123456789")
|
self.assertEqual(user.phone, "123456789")
|
||||||
|
|
||||||
|
@less_console_noise_decorator
|
||||||
def test_authenticate_with_unknown_user(self):
|
def test_authenticate_with_unknown_user(self):
|
||||||
"""Test that authenticate returns None when no kwargs are supplied"""
|
"""Test that authenticate returns None when no kwargs are supplied"""
|
||||||
# Ensure that the authenticate method handles the case when the user is not found
|
# Ensure that the authenticate method handles the case when the user is not found
|
||||||
|
|
|
@ -11,7 +11,8 @@ address,
|
||||||
}
|
}
|
||||||
|
|
||||||
h1:not(.usa-alert__heading),
|
h1:not(.usa-alert__heading),
|
||||||
h2:not(.usa-alert__heading),
|
// .module h2 excludes headers in DJA
|
||||||
|
h2:not(.usa-alert__heading, .module h2),
|
||||||
h3:not(.usa-alert__heading),
|
h3:not(.usa-alert__heading),
|
||||||
h4:not(.usa-alert__heading),
|
h4:not(.usa-alert__heading),
|
||||||
h5:not(.usa-alert__heading),
|
h5:not(.usa-alert__heading),
|
||||||
|
|
|
@ -352,12 +352,37 @@ class UserFixture:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_existing_users(users):
|
def _get_existing_users(users):
|
||||||
|
# if users match existing users in db by email address, update the users with the username
|
||||||
|
# from the db. this will prevent duplicate users (with same email) from being added to db.
|
||||||
|
# it is ok to keep the old username in the db because the username will be updated by oidc process during login
|
||||||
|
|
||||||
|
# Extract email addresses from users
|
||||||
|
emails = [user.get("email") for user in users]
|
||||||
|
|
||||||
|
# Fetch existing users by email
|
||||||
|
existing_users_by_email = User.objects.filter(email__in=emails).values_list("email", "username", "id")
|
||||||
|
|
||||||
|
# Create a dictionary to map emails to existing usernames
|
||||||
|
email_to_existing_user = {user[0]: user[1] for user in existing_users_by_email}
|
||||||
|
|
||||||
|
# Update the users list with the usernames from existing users by email
|
||||||
|
for user in users:
|
||||||
|
email = user.get("email")
|
||||||
|
if email and email in email_to_existing_user:
|
||||||
|
user["username"] = email_to_existing_user[email] # Update username with the existing one
|
||||||
|
|
||||||
|
# Get the user identifiers (username, id) for the existing users to query the database
|
||||||
user_identifiers = [(user.get("username"), user.get("id")) for user in users]
|
user_identifiers = [(user.get("username"), user.get("id")) for user in users]
|
||||||
|
|
||||||
|
# Fetch existing users by username or id
|
||||||
existing_users = User.objects.filter(
|
existing_users = User.objects.filter(
|
||||||
username__in=[user[0] for user in user_identifiers] + [user[1] for user in user_identifiers]
|
username__in=[user[0] for user in user_identifiers] + [user[1] for user in user_identifiers]
|
||||||
).values_list("username", "id")
|
).values_list("username", "id")
|
||||||
|
|
||||||
|
# Create sets for usernames and ids that exist
|
||||||
existing_usernames = set(user[0] for user in existing_users)
|
existing_usernames = set(user[0] for user in existing_users)
|
||||||
existing_user_ids = set(user[1] for user in existing_users)
|
existing_user_ids = set(user[1] for user in existing_users)
|
||||||
|
|
||||||
return existing_usernames, existing_user_ids
|
return existing_usernames, existing_user_ids
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue