diff --git a/src/registrar/forms/application_wizard.py b/src/registrar/forms/application_wizard.py index a1d454f74..3c1f0cc9c 100644 --- a/src/registrar/forms/application_wizard.py +++ b/src/registrar/forms/application_wizard.py @@ -137,14 +137,14 @@ class RegistrarFormSet(forms.BaseFormSet): # matching database object exists, update it if db_obj is not None and cleaned: if should_delete(cleaned): - if hasattr(db_obj, "has_more_than_one_join") and any(db_obj.has_more_than_one_join(rel, related_name) for rel in reverse_joins): + if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name): # Remove the specific relationship without deleting the object getattr(db_obj, related_name).remove(self.application) else: # If there are no other relationships, delete the object db_obj.delete() else: - if hasattr(db_obj, "has_more_than_one_join") and any(db_obj.has_more_than_one_join(rel, related_name) for rel in reverse_joins): + if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name): # create a new db_obj and disconnect existing one getattr(db_obj, related_name).remove(self.application) kwargs = pre_create(db_obj, cleaned) @@ -344,7 +344,7 @@ class AuthorizingOfficialForm(RegistrarForm): if not self.is_valid(): return contact = getattr(obj, "authorizing_official", None) - if contact is not None and not any(contact.has_more_than_one_join(rel, "authorizing_official") for rel in self.REVERSE_JOINS): + if contact is not None and not contact.has_more_than_one_join(self.REVERSE_JOINS, "authorizing_official"): # if contact exists in the database and is not joined to other entities super().to_database(contact) else: diff --git a/src/registrar/models/contact.py b/src/registrar/models/contact.py index 4352e0a16..483752c56 100644 --- a/src/registrar/models/contact.py +++ b/src/registrar/models/contact.py @@ -54,23 +54,29 @@ class Contact(TimeStampedModel): db_index=True, ) - def has_more_than_one_join(self, rel, related_name): + def has_more_than_one_join(self, all_relations, expected_relation): + """Helper for finding whether an object is joined more than once. + all_relations is the list of all_relations to be checked for existing joins. + expected_relation is the one relation with one expected join""" + return any(self._has_more_than_one_join_per_relation(rel, expected_relation) for rel in all_relations) + + def _has_more_than_one_join_per_relation(self, relation, expected_relation): """Helper for finding whether an object is joined more than once.""" # threshold is the number of related objects that are acceptable # when determining if related objects exist. threshold is 0 for most - # relationships. if the relationship is related_name, we know that + # relationships. if the relationship is expected_relation, we know that # there is already exactly 1 acceptable relationship (the one we are # attempting to delete), so the threshold is 1 - threshold = 1 if rel == related_name else 0 + threshold = 1 if relation == expected_relation else 0 # Raise a KeyError if rel is not a defined field on the db_obj model # This will help catch any errors in reverse_join config on forms - if rel not in [field.name for field in self._meta.get_fields()]: - raise KeyError(f"{rel} is not a defined field on the {self._meta.model_name} model.") + if relation not in [field.name for field in self._meta.get_fields()]: + raise KeyError(f"{relation} is not a defined field on the {self._meta.model_name} model.") # if attr rel in db_obj is not None, then test if reference object(s) exist - if getattr(self, rel) is not None: - field = self._meta.get_field(rel) + if getattr(self, relation) is not None: + field = self._meta.get_field(relation) if isinstance(field, models.OneToOneField): # if the rel field is a OneToOne field, then we have already # determined that the object exists (is not None) @@ -79,7 +85,7 @@ class Contact(TimeStampedModel): # if the rel field is a ManyToOne or ManyToMany, then we need # to determine if the count of related objects is greater than # the threshold - return getattr(self, rel).count() > threshold + return getattr(self, relation).count() > threshold return False def get_formatted_name(self):