mirror of
https://github.com/cisagov/manage.get.gov.git
synced 2025-07-03 09:43:33 +02:00
changed parameters for has_more_than_one_join to pass array of reverse_joins rather than individual join
This commit is contained in:
parent
ab05da9c2d
commit
0920246593
2 changed files with 17 additions and 11 deletions
|
@ -137,14 +137,14 @@ class RegistrarFormSet(forms.BaseFormSet):
|
||||||
# matching database object exists, update it
|
# matching database object exists, update it
|
||||||
if db_obj is not None and cleaned:
|
if db_obj is not None and cleaned:
|
||||||
if should_delete(cleaned):
|
if should_delete(cleaned):
|
||||||
if hasattr(db_obj, "has_more_than_one_join") and any(db_obj.has_more_than_one_join(rel, related_name) for rel in reverse_joins):
|
if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name):
|
||||||
# Remove the specific relationship without deleting the object
|
# Remove the specific relationship without deleting the object
|
||||||
getattr(db_obj, related_name).remove(self.application)
|
getattr(db_obj, related_name).remove(self.application)
|
||||||
else:
|
else:
|
||||||
# If there are no other relationships, delete the object
|
# If there are no other relationships, delete the object
|
||||||
db_obj.delete()
|
db_obj.delete()
|
||||||
else:
|
else:
|
||||||
if hasattr(db_obj, "has_more_than_one_join") and any(db_obj.has_more_than_one_join(rel, related_name) for rel in reverse_joins):
|
if hasattr(db_obj, "has_more_than_one_join") and db_obj.has_more_than_one_join(reverse_joins, related_name):
|
||||||
# create a new db_obj and disconnect existing one
|
# create a new db_obj and disconnect existing one
|
||||||
getattr(db_obj, related_name).remove(self.application)
|
getattr(db_obj, related_name).remove(self.application)
|
||||||
kwargs = pre_create(db_obj, cleaned)
|
kwargs = pre_create(db_obj, cleaned)
|
||||||
|
@ -344,7 +344,7 @@ class AuthorizingOfficialForm(RegistrarForm):
|
||||||
if not self.is_valid():
|
if not self.is_valid():
|
||||||
return
|
return
|
||||||
contact = getattr(obj, "authorizing_official", None)
|
contact = getattr(obj, "authorizing_official", None)
|
||||||
if contact is not None and not any(contact.has_more_than_one_join(rel, "authorizing_official") for rel in self.REVERSE_JOINS):
|
if contact is not None and not contact.has_more_than_one_join(self.REVERSE_JOINS, "authorizing_official"):
|
||||||
# if contact exists in the database and is not joined to other entities
|
# if contact exists in the database and is not joined to other entities
|
||||||
super().to_database(contact)
|
super().to_database(contact)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -54,23 +54,29 @@ class Contact(TimeStampedModel):
|
||||||
db_index=True,
|
db_index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def has_more_than_one_join(self, rel, related_name):
|
def has_more_than_one_join(self, all_relations, expected_relation):
|
||||||
|
"""Helper for finding whether an object is joined more than once.
|
||||||
|
all_relations is the list of all_relations to be checked for existing joins.
|
||||||
|
expected_relation is the one relation with one expected join"""
|
||||||
|
return any(self._has_more_than_one_join_per_relation(rel, expected_relation) for rel in all_relations)
|
||||||
|
|
||||||
|
def _has_more_than_one_join_per_relation(self, relation, expected_relation):
|
||||||
"""Helper for finding whether an object is joined more than once."""
|
"""Helper for finding whether an object is joined more than once."""
|
||||||
# threshold is the number of related objects that are acceptable
|
# threshold is the number of related objects that are acceptable
|
||||||
# when determining if related objects exist. threshold is 0 for most
|
# when determining if related objects exist. threshold is 0 for most
|
||||||
# relationships. if the relationship is related_name, we know that
|
# relationships. if the relationship is expected_relation, we know that
|
||||||
# there is already exactly 1 acceptable relationship (the one we are
|
# there is already exactly 1 acceptable relationship (the one we are
|
||||||
# attempting to delete), so the threshold is 1
|
# attempting to delete), so the threshold is 1
|
||||||
threshold = 1 if rel == related_name else 0
|
threshold = 1 if relation == expected_relation else 0
|
||||||
|
|
||||||
# Raise a KeyError if rel is not a defined field on the db_obj model
|
# Raise a KeyError if rel is not a defined field on the db_obj model
|
||||||
# This will help catch any errors in reverse_join config on forms
|
# This will help catch any errors in reverse_join config on forms
|
||||||
if rel not in [field.name for field in self._meta.get_fields()]:
|
if relation not in [field.name for field in self._meta.get_fields()]:
|
||||||
raise KeyError(f"{rel} is not a defined field on the {self._meta.model_name} model.")
|
raise KeyError(f"{relation} is not a defined field on the {self._meta.model_name} model.")
|
||||||
|
|
||||||
# if attr rel in db_obj is not None, then test if reference object(s) exist
|
# if attr rel in db_obj is not None, then test if reference object(s) exist
|
||||||
if getattr(self, rel) is not None:
|
if getattr(self, relation) is not None:
|
||||||
field = self._meta.get_field(rel)
|
field = self._meta.get_field(relation)
|
||||||
if isinstance(field, models.OneToOneField):
|
if isinstance(field, models.OneToOneField):
|
||||||
# if the rel field is a OneToOne field, then we have already
|
# if the rel field is a OneToOne field, then we have already
|
||||||
# determined that the object exists (is not None)
|
# determined that the object exists (is not None)
|
||||||
|
@ -79,7 +85,7 @@ class Contact(TimeStampedModel):
|
||||||
# if the rel field is a ManyToOne or ManyToMany, then we need
|
# if the rel field is a ManyToOne or ManyToMany, then we need
|
||||||
# to determine if the count of related objects is greater than
|
# to determine if the count of related objects is greater than
|
||||||
# the threshold
|
# the threshold
|
||||||
return getattr(self, rel).count() > threshold
|
return getattr(self, relation).count() > threshold
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_formatted_name(self):
|
def get_formatted_name(self):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue