moved reverse_joins definition centrally to contact model

This commit is contained in:
David Kennedy 2024-01-12 18:02:25 -05:00
parent 579b890996
commit 73b0b33ee8
No known key found for this signature in database
GPG key ID: 6528A5386E66B96B
3 changed files with 21 additions and 50 deletions

View file

@ -100,7 +100,6 @@ class RegistrarFormSet(forms.BaseFormSet):
self, self,
obj: DomainApplication, obj: DomainApplication,
join: str, join: str,
reverse_joins: list,
should_delete: Callable, should_delete: Callable,
pre_update: Callable, pre_update: Callable,
pre_create: Callable, pre_create: Callable,
@ -137,14 +136,14 @@ class RegistrarFormSet(forms.BaseFormSet):
# matching database object exists, update it # matching database object exists, update it
if db_obj is not None and cleaned: if db_obj is not None and cleaned:
if should_delete(cleaned): if should_delete(cleaned):
if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name): if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(related_name):
# Remove the specific relationship without deleting the object # Remove the specific relationship without deleting the object
getattr(db_obj, related_name).remove(self.application) getattr(db_obj, related_name).remove(self.application)
else: else:
# If there are no other relationships, delete the object # If there are no other relationships, delete the object
db_obj.delete() db_obj.delete()
else: else:
if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name): if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(related_name):
# create a new db_obj and disconnect existing one # create a new db_obj and disconnect existing one
getattr(db_obj, related_name).remove(self.application) getattr(db_obj, related_name).remove(self.application)
kwargs = pre_create(db_obj, cleaned) kwargs = pre_create(db_obj, cleaned)
@ -330,21 +329,12 @@ class AboutYourOrganizationForm(RegistrarForm):
class AuthorizingOfficialForm(RegistrarForm): class AuthorizingOfficialForm(RegistrarForm):
JOIN = "authorizing_official" JOIN = "authorizing_official"
REVERSE_JOINS = [
"user",
"authorizing_official",
"submitted_applications",
"contact_applications",
"information_authorizing_official",
"submitted_applications_information",
"contact_applications_information",
]
def to_database(self, obj): def to_database(self, obj):
if not self.is_valid(): if not self.is_valid():
return return
contact = getattr(obj, "authorizing_official", None) contact = getattr(obj, "authorizing_official", None)
if contact is not None and not contact.has_more_than_one_join(self.REVERSE_JOINS, "authorizing_official"): if contact is not None and not contact.has_more_than_one_join("authorizing_official"):
# if contact exists in the database and is not joined to other entities # if contact exists in the database and is not joined to other entities
super().to_database(contact) super().to_database(contact)
else: else:
@ -403,7 +393,7 @@ class BaseCurrentSitesFormSet(RegistrarFormSet):
def to_database(self, obj: DomainApplication): def to_database(self, obj: DomainApplication):
# If we want to test against multiple joins for a website object, replace the empty array # If we want to test against multiple joins for a website object, replace the empty array
# and change the JOIN in the models to allow for reverse references # and change the JOIN in the models to allow for reverse references
self._to_database(obj, self.JOIN, [], self.should_delete, self.pre_update, self.pre_create) self._to_database(obj, self.JOIN, self.should_delete, self.pre_update, self.pre_create)
@classmethod @classmethod
def from_database(cls, obj): def from_database(cls, obj):
@ -462,7 +452,7 @@ class BaseAlternativeDomainFormSet(RegistrarFormSet):
def to_database(self, obj: DomainApplication): def to_database(self, obj: DomainApplication):
# If we want to test against multiple joins for a website object, replace the empty array and # If we want to test against multiple joins for a website object, replace the empty array and
# change the JOIN in the models to allow for reverse references # change the JOIN in the models to allow for reverse references
self._to_database(obj, self.JOIN, [], self.should_delete, self.pre_update, self.pre_create) self._to_database(obj, self.JOIN, self.should_delete, self.pre_update, self.pre_create)
@classmethod @classmethod
def on_fetch(cls, query): def on_fetch(cls, query):
@ -542,21 +532,12 @@ class PurposeForm(RegistrarForm):
class YourContactForm(RegistrarForm): class YourContactForm(RegistrarForm):
JOIN = "submitter" JOIN = "submitter"
REVERSE_JOINS = [
"user",
"authorizing_official",
"submitted_applications",
"contact_applications",
"information_authorizing_official",
"submitted_applications_information",
"contact_applications_information",
]
def to_database(self, obj): def to_database(self, obj):
if not self.is_valid(): if not self.is_valid():
return return
contact = getattr(obj, "submitter", None) contact = getattr(obj, "submitter", None)
if contact is not None and not contact.has_more_than_one_join(self.REVERSE_JOINS, "submitted_applications"): if contact is not None and not contact.has_more_than_one_join("submitted_applications"):
# if contact exists in the database and is not joined to other entities # if contact exists in the database and is not joined to other entities
super().to_database(contact) super().to_database(contact)
else: else:
@ -711,20 +692,10 @@ class BaseOtherContactsFormSet(RegistrarFormSet):
must co-exist. must co-exist.
Also, other_contacts have db relationships to multiple db objects. When attempting Also, other_contacts have db relationships to multiple db objects. When attempting
to delete an other_contact from an application, those db relationships must be to delete an other_contact from an application, those db relationships must be
tested and handled; this is configured with REVERSE_JOINS, which is an array of tested and handled.
strings representing the relationships between contact model and other models.
""" """
JOIN = "other_contacts" JOIN = "other_contacts"
REVERSE_JOINS = [
"user",
"authorizing_official",
"submitted_applications",
"contact_applications",
"information_authorizing_official",
"submitted_applications_information",
"contact_applications_information",
]
def get_deletion_widget(self): def get_deletion_widget(self):
return forms.HiddenInput(attrs={"class": "deletion"}) return forms.HiddenInput(attrs={"class": "deletion"})
@ -756,7 +727,7 @@ class BaseOtherContactsFormSet(RegistrarFormSet):
return cleaned return cleaned
def to_database(self, obj: DomainApplication): def to_database(self, obj: DomainApplication):
self._to_database(obj, self.JOIN, self.REVERSE_JOINS, self.should_delete, self.pre_update, self.pre_create) self._to_database(obj, self.JOIN, self.should_delete, self.pre_update, self.pre_create)
@classmethod @classmethod
def from_database(cls, obj): def from_database(cls, obj):

View file

@ -213,15 +213,6 @@ class ContactForm(forms.ModelForm):
class AuthorizingOfficialContactForm(ContactForm): class AuthorizingOfficialContactForm(ContactForm):
"""Form for updating authorizing official contacts.""" """Form for updating authorizing official contacts."""
JOIN = "authorizing_official" JOIN = "authorizing_official"
REVERSE_JOINS = [
"user",
"authorizing_official",
"submitted_applications",
"contact_applications",
"information_authorizing_official",
"submitted_applications_information",
"contact_applications_information",
]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -258,7 +249,7 @@ class AuthorizingOfficialContactForm(ContactForm):
# get db object # get db object
db_ao = Contact.objects.get(id=self.instance.id) db_ao = Contact.objects.get(id=self.instance.id)
logger.info(f"db_ao.information_authorizing_official {db_ao.information_authorizing_official}") logger.info(f"db_ao.information_authorizing_official {db_ao.information_authorizing_official}")
if self.domainInfo and db_ao.has_more_than_one_join(self.REVERSE_JOINS, "information_authorizing_official"): if self.domainInfo and db_ao.has_more_than_one_join("information_authorizing_official"):
logger.info(f"domain info => {self.domainInfo}") logger.info(f"domain info => {self.domainInfo}")
logger.info(f"authorizing official id => {self.domainInfo.authorizing_official.id}") logger.info(f"authorizing official id => {self.domainInfo.authorizing_official.id}")
contact = Contact() contact = Contact()

View file

@ -54,10 +54,19 @@ class Contact(TimeStampedModel):
db_index=True, db_index=True,
) )
def has_more_than_one_join(self, all_relations, expected_relation): def has_more_than_one_join(self, expected_relation):
"""Helper for finding whether an object is joined more than once. """Helper for finding whether an object is joined more than once.
all_relations is the list of all_relations to be checked for existing joins.
expected_relation is the one relation with one expected join""" expected_relation is the one relation with one expected join"""
# all_relations is the list of all_relations (from contact) to be checked for existing joins
all_relations = [
"user",
"authorizing_official",
"submitted_applications",
"contact_applications",
"information_authorizing_official",
"submitted_applications_information",
"contact_applications_information",
]
return any(self._has_more_than_one_join_per_relation(rel, expected_relation) for rel in all_relations) return any(self._has_more_than_one_join_per_relation(rel, expected_relation) for rel in all_relations)
def _has_more_than_one_join_per_relation(self, relation, expected_relation): def _has_more_than_one_join_per_relation(self, relation, expected_relation):
@ -70,7 +79,7 @@ class Contact(TimeStampedModel):
threshold = 1 if relation == expected_relation else 0 threshold = 1 if relation == expected_relation else 0
# Raise a KeyError if rel is not a defined field on the db_obj model # Raise a KeyError if rel is not a defined field on the db_obj model
# This will help catch any errors in reverse_join config on forms # This will help catch any errors in relation passed.
if relation not in [field.name for field in self._meta.get_fields()]: if relation not in [field.name for field in self._meta.get_fields()]:
raise KeyError(f"{relation} is not a defined field on the {self._meta.model_name} model.") raise KeyError(f"{relation} is not a defined field on the {self._meta.model_name} model.")