Merge pull request #1340 from cisagov/za/fix-transition-domain-on-login

Fix transition domain on login
This commit is contained in:
zandercymatics 2023-11-15 10:05:40 -07:00 committed by GitHub
commit 093c34a787
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 71 deletions

View file

@ -5,7 +5,6 @@ from django.db import models
from .domain_invitation import DomainInvitation
from .transition_domain import TransitionDomain
from .domain_information import DomainInformation
from .domain import Domain
from phonenumber_field.modelfields import PhoneNumberField # type: ignore
@ -97,51 +96,6 @@ class User(AbstractUser):
new_domain_invitation = DomainInvitation(email=transition_domain_email.lower(), domain=new_domain)
new_domain_invitation.save()
def check_transition_domains_on_login(self):
"""When a user first arrives on the site, we need to check
if they are logging in with the same e-mail as a
transition domain and update our database accordingly."""
for transition_domain in TransitionDomain.objects.filter(username=self.email):
# Looks like the user logged in with the same e-mail as
# one or more corresponding transition domains.
# Create corresponding DomainInformation objects.
# NOTE: adding an ADMIN user role for this user
# for each domain should already be done
# in the invitation.retrieve() method.
# However, if the migration scripts for transition
# domain objects were not executed correctly,
# there could be transition domains without
# any corresponding Domain & DomainInvitation objects,
# which means the invitation.retrieve() method might
# not execute.
# Check that there is a corresponding domain object
# for this transition domain. If not, we have an error
# with our data and migrations need to be run again.
# Get the domain that corresponds with this transition domain
domain_exists = Domain.objects.filter(name=transition_domain.domain_name).exists()
if not domain_exists:
logger.warn(
"""There are transition domains without
corresponding domain objects!
Please run migration scripts for transition domains
(See data_migration.md)"""
)
# No need to throw an exception...just create a domain
# and domain invite, then proceed as normal
self.create_domain_and_invite(transition_domain)
domain = Domain.objects.get(name=transition_domain.domain_name)
# Create a domain information object, if one doesn't
# already exist
domain_info_exists = DomainInformation.objects.filter(domain=domain).exists()
if not domain_info_exists:
new_domain_info = DomainInformation(creator=self, domain=domain)
new_domain_info.save()
def on_each_login(self):
"""Callback each time the user is authenticated.
@ -152,17 +106,6 @@ class User(AbstractUser):
as a transition domain and update our domainInfo objects accordingly.
"""
# PART 1: TRANSITION DOMAINS
#
# NOTE: THIS MUST RUN FIRST
# (If we have an issue where transition domains were
# not fully converted into Domain and DomainInvitation
# objects, this method will fill in the gaps.
# This will ensure the Domain Invitations method
# runs correctly (no missing invites))
self.check_transition_domains_on_login()
# PART 2: DOMAIN INVITATIONS
self.check_domain_invitations_on_login()
class Meta:

View file

@ -627,22 +627,10 @@ class TestUser(TestCase):
TransitionDomain.objects.all().delete()
User.objects.all().delete()
def test_check_transition_domains_on_login(self):
"""A user's on_each_login callback checks transition domains.
Makes DomainInformation object."""
self.domain, _ = Domain.objects.get_or_create(name=self.domain_name)
self.user.on_each_login()
self.assertTrue(DomainInformation.objects.get(domain=self.domain))
def test_check_transition_domains_without_domains_on_login(self):
"""A user's on_each_login callback checks transition domains.
"""A user's on_each_login callback does not check transition domains.
This test makes sure that in the event a domain does not exist
for a given transition domain, both a domain and domain invitation
are created."""
self.user.on_each_login()
self.assertTrue(Domain.objects.get(name=self.domain_name))
domain = Domain.objects.get(name=self.domain_name)
self.assertTrue(DomainInvitation.objects.get(email=self.email, domain=domain))
self.assertTrue(DomainInformation.objects.get(domain=domain))
self.assertFalse(Domain.objects.filter(name=self.domain_name).exists())