changed parameters for has_more_than_one_join to pass array of reverse_joins rather than individual join

This commit is contained in:
David Kennedy 2024-01-12 17:51:02 -05:00
parent ab05da9c2d
commit 0920246593
No known key found for this signature in database
GPG key ID: 6528A5386E66B96B
2 changed files with 17 additions and 11 deletions

View file

@ -137,14 +137,14 @@ class RegistrarFormSet(forms.BaseFormSet):
# matching database object exists, update it
if db_obj is not None and cleaned:
if should_delete(cleaned):
if hasattr(db_obj, "has_more_than_one_join") and any(db_obj.has_more_than_one_join(rel, related_name) for rel in reverse_joins):
if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name):
# Remove the specific relationship without deleting the object
getattr(db_obj, related_name).remove(self.application)
else:
# If there are no other relationships, delete the object
db_obj.delete()
else:
if hasattr(db_obj, "has_more_than_one_join") and any(db_obj.has_more_than_one_join(rel, related_name) for rel in reverse_joins):
if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name):
# create a new db_obj and disconnect existing one
getattr(db_obj, related_name).remove(self.application)
kwargs = pre_create(db_obj, cleaned)
@ -344,7 +344,7 @@ class AuthorizingOfficialForm(RegistrarForm):
if not self.is_valid():
return
contact = getattr(obj, "authorizing_official", None)
if contact is not None and not any(contact.has_more_than_one_join(rel, "authorizing_official") for rel in self.REVERSE_JOINS):
if contact is not None and not contact.has_more_than_one_join(self.REVERSE_JOINS, "authorizing_official"):
# if contact exists in the database and is not joined to other entities
super().to_database(contact)
else:

View file

@ -54,23 +54,29 @@ class Contact(TimeStampedModel):
db_index=True,
)
def has_more_than_one_join(self, rel, related_name):
def has_more_than_one_join(self, all_relations, expected_relation):
"""Helper for finding whether an object is joined more than once.
all_relations is the list of all_relations to be checked for existing joins.
expected_relation is the one relation with one expected join"""
return any(self._has_more_than_one_join_per_relation(rel, expected_relation) for rel in all_relations)
def _has_more_than_one_join_per_relation(self, relation, expected_relation):
"""Helper for finding whether an object is joined more than once."""
# threshold is the number of related objects that are acceptable
# when determining if related objects exist. threshold is 0 for most
# relationships. if the relationship is related_name, we know that
# relationships. if the relationship is expected_relation, we know that
# there is already exactly 1 acceptable relationship (the one we are
# attempting to delete), so the threshold is 1
threshold = 1 if rel == related_name else 0
threshold = 1 if relation == expected_relation else 0
# Raise a KeyError if rel is not a defined field on the db_obj model
# This will help catch any errors in reverse_join config on forms
if rel not in [field.name for field in self._meta.get_fields()]:
raise KeyError(f"{rel} is not a defined field on the {self._meta.model_name} model.")
if relation not in [field.name for field in self._meta.get_fields()]:
raise KeyError(f"{relation} is not a defined field on the {self._meta.model_name} model.")
# if attr rel in db_obj is not None, then test if reference object(s) exist
if getattr(self, rel) is not None:
field = self._meta.get_field(rel)
if getattr(self, relation) is not None:
field = self._meta.get_field(relation)
if isinstance(field, models.OneToOneField):
# if the rel field is a OneToOne field, then we have already
# determined that the object exists (is not None)
@ -79,7 +85,7 @@ class Contact(TimeStampedModel):
# if the rel field is a ManyToOne or ManyToMany, then we need
# to determine if the count of related objects is greater than
# the threshold
return getattr(self, rel).count() > threshold
return getattr(self, relation).count() > threshold
return False
def get_formatted_name(self):