diff --git a/src/registrar/management/commands/load_domain_invitations.py b/src/registrar/management/commands/load_domain_invitations.py index 28eb09def..32a63d860 100644 --- a/src/registrar/management/commands/load_domain_invitations.py +++ b/src/registrar/management/commands/load_domain_invitations.py @@ -62,7 +62,7 @@ class Command(BaseCommand): DomainInvitation( email=email_address.lower(), domain=domain, - status=DomainInvitation.INVITED, + status=DomainInvitation.DomainInvitationStatus.INVITED, ) ) logger.info("Creating %d invitations", len(to_create)) diff --git a/src/registrar/models/domain_invitation.py b/src/registrar/models/domain_invitation.py index 1e0b7fec8..395244df5 100644 --- a/src/registrar/models/domain_invitation.py +++ b/src/registrar/models/domain_invitation.py @@ -15,8 +15,11 @@ logger = logging.getLogger(__name__) class DomainInvitation(TimeStampedModel): - INVITED = "invited" - RETRIEVED = "retrieved" + + # Constants for status field + class DomainInvitationStatus(models.TextChoices): + INVITED = "invited", "Invited" + RETRIEVED = "retrieved", "Retrieved" email = models.EmailField( null=False, @@ -31,18 +34,15 @@ class DomainInvitation(TimeStampedModel): ) status = FSMField( - choices=[ - (INVITED, INVITED), - (RETRIEVED, RETRIEVED), - ], - default=INVITED, + choices=DomainInvitationStatus.choices, + default=DomainInvitationStatus.INVITED, protected=True, # can't alter state except through transition methods! ) def __str__(self): return f"Invitation for {self.email} on {self.domain} is {self.status}" - @transition(field="status", source=INVITED, target=RETRIEVED) + @transition(field="status", source=DomainInvitationStatus.INVITED, target=DomainInvitationStatus.RETRIEVED) def retrieve(self): """When an invitation is retrieved, create the corresponding permission. diff --git a/src/registrar/models/user.py b/src/registrar/models/user.py index 2daa3c253..346a97aa6 100644 --- a/src/registrar/models/user.py +++ b/src/registrar/models/user.py @@ -67,7 +67,7 @@ class User(AbstractUser): def check_domain_invitations_on_login(self): """When a user first arrives on the site, we need to retrieve any domain invitations that match their email address.""" - for invitation in DomainInvitation.objects.filter(email=self.email, status=DomainInvitation.INVITED): + for invitation in DomainInvitation.objects.filter(email=self.email, status=DomainInvitation.DomainInvitationStatus.INVITED): try: invitation.retrieve() invitation.save() diff --git a/src/registrar/tests/common.py b/src/registrar/tests/common.py index 89967ee90..f459cab60 100644 --- a/src/registrar/tests/common.py +++ b/src/registrar/tests/common.py @@ -344,7 +344,7 @@ class AuditedAdminMockData: full_arg_dict = dict( email="test_mail@mail.com", domain=self.dummy_domain(item_name, True), - status=DomainInvitation.INVITED, + status=DomainInvitation.DomainInvitationStatus.INVITED, ) return full_arg_dict diff --git a/src/registrar/tests/test_models.py b/src/registrar/tests/test_models.py index c2820696d..0a02ccb90 100644 --- a/src/registrar/tests/test_models.py +++ b/src/registrar/tests/test_models.py @@ -597,7 +597,7 @@ class TestInvitations(TestCase): # this is not an error but does produce a console warning with less_console_noise(): self.invitation.retrieve() - self.assertEqual(self.invitation.status, DomainInvitation.RETRIEVED) + self.assertEqual(self.invitation.status, DomainInvitation.DomainInvitationStatus.RETRIEVED) def test_retrieve_on_each_login(self): """A user's authenticate on_each_login callback retrieves their invitations."""