diff --git a/src/registrar/models/domain_invitation.py b/src/registrar/models/domain_invitation.py index 1ad150e4a..59157349f 100644 --- a/src/registrar/models/domain_invitation.py +++ b/src/registrar/models/domain_invitation.py @@ -1,5 +1,7 @@ """People are invited by email to administer domains.""" +import logging + from django.contrib.auth import get_user_model from django.db import models, IntegrityError @@ -9,6 +11,9 @@ from .utility.time_stamped_model import TimeStampedModel from .user_domain_role import UserDomainRole +logger = logging.getLogger(__name__) + + class DomainInvitation(TimeStampedModel): INVITED = "invited" RETRIEVED = "retrieved" @@ -39,7 +44,11 @@ class DomainInvitation(TimeStampedModel): @transition(field="status", source=INVITED, target=RETRIEVED) def retrieve(self): - """When an invitation is retrieved, create the corresponding permission.""" + """When an invitation is retrieved, create the corresponding permission. + + Raises: + RuntimeError if no matching user can be found. + """ # get a user with this email address User = get_user_model() @@ -54,12 +63,12 @@ class DomainInvitation(TimeStampedModel): # and create a role for that user on this domain try: - UserDomainRole.objects.create( + role = UserDomainRole.objects.get( user=user, domain=self.domain, role=UserDomainRole.Roles.ADMIN ) - except IntegrityError: - # should not happen because this user shouldn't retrieve this invitation - # more than once. - raise RuntimeError( - "Invitation would create a role that already exists for this user." - ) + except UserDomainRole.DoesNotExist: + UserDomainRole.objects.create(user=user, domain=self.domain, role=UserDomainRole.Roles.ADMIN) + else: + # something strange happened and this role already existed when + # the invitation was retrieved. Log that this occurred. + logger.warn("Invitation %s was retrieved for a role that already exists.", self) diff --git a/src/registrar/models/user.py b/src/registrar/models/user.py index fb51af30c..3448d712d 100644 --- a/src/registrar/models/user.py +++ b/src/registrar/models/user.py @@ -1,3 +1,5 @@ +import logging + from django.contrib.auth.models import AbstractUser from django.db import models @@ -6,6 +8,9 @@ from .domain_invitation import DomainInvitation from phonenumber_field.modelfields import PhoneNumberField # type: ignore +logger = logging.getLogger(__name__) + + class User(AbstractUser): """ A custom user model that performs identically to the default user model @@ -43,5 +48,11 @@ class User(AbstractUser): for invitation in DomainInvitation.objects.filter( email=self.email, status=DomainInvitation.INVITED ): - invitation.retrieve() - invitation.save() + try: + invitation.retrieve() + invitation.save() + except RuntimeError: + # retrieving should not fail because of a missing user, but + # if it does fail, log the error so a new user can continue + # logging in + logger.warn("Failed to retrieve invitation %s", invitation, exc_info=True) diff --git a/src/registrar/tests/test_models.py b/src/registrar/tests/test_models.py index b60898e50..5649172b2 100644 --- a/src/registrar/tests/test_models.py +++ b/src/registrar/tests/test_models.py @@ -177,6 +177,9 @@ class TestInvitations(TestCase): ) self.user, _ = User.objects.get_or_create(email=self.email) + # clean out the roles each time + UserDomainRole.objects.all().delete() + def test_retrieval_creates_role(self): self.invitation.retrieve() self.assertTrue(UserDomainRole.objects.get(user=self.user, domain=self.domain)) @@ -187,11 +190,13 @@ class TestInvitations(TestCase): with self.assertRaises(RuntimeError): self.invitation.retrieve() - def test_retrieve_existing_role_error(self): + def test_retrieve_existing_role_no_error(self): # make the overlapping role - UserDomainRole.objects.get_or_create(user=self.user, domain=self.domain) - with self.assertRaises(RuntimeError): + UserDomainRole.objects.get_or_create(user=self.user, domain=self.domain, role=UserDomainRole.Roles.ADMIN) + # this is not an error but does produce a console warning + with less_console_noise(): self.invitation.retrieve() + self.assertEqual(self.invitation.status, DomainInvitation.RETRIEVED) def test_retrieve_on_first_login(self): """A new user's first_login callback retrieves their invitations."""