This commit is contained in:
zandercymatics 2024-11-12 14:32:07 -07:00
parent 67f87e1a02
commit b0dc18cc3f
No known key found for this signature in database
GPG key ID: FF4636ABEC9682B7

View file

@ -15,6 +15,7 @@ from django.db.models import Case, CharField, Count, DateField, F, ManyToManyFie
from django.utils import timezone from django.utils import timezone
from django.db.models.functions import Concat, Coalesce from django.db.models.functions import Concat, Coalesce
from django.contrib.postgres.aggregates import StringAgg from django.contrib.postgres.aggregates import StringAgg
from registrar.models.user_portfolio_permission import UserPortfolioPermission
from registrar.models.utility.generic_helper import convert_queryset_to_dict from registrar.models.utility.generic_helper import convert_queryset_to_dict
from registrar.templatetags.custom_filters import get_region from registrar.templatetags.custom_filters import get_region
from registrar.utility.constants import BranchChoices from registrar.utility.constants import BranchChoices
@ -50,11 +51,7 @@ def format_end_date(end_date):
return timezone.make_aware(datetime.strptime(end_date, "%Y-%m-%d")) if end_date else get_default_end_date() return timezone.make_aware(datetime.strptime(end_date, "%Y-%m-%d")) if end_date else get_default_end_date()
class BaseExport(ABC): class BaseModelDict(ABC):
"""
A generic class for exporting data which returns a csv file for the given model.
Base class in an inheritance tree of 3.
"""
@classmethod @classmethod
@abstractmethod @abstractmethod
@ -65,13 +62,6 @@ class BaseExport(ABC):
""" """
pass pass
@classmethod
def get_columns(cls):
"""
Returns the columns for CSV export. Override in subclasses as needed.
"""
return []
@classmethod @classmethod
def get_sort_fields(cls): def get_sort_fields(cls):
""" """
@ -116,7 +106,7 @@ class BaseExport(ABC):
return Q() return Q()
@classmethod @classmethod
def get_computed_fields(cls): def get_annotated_fields(cls):
""" """
Get a dict of computed fields. These are fields that do not exist on the model normally Get a dict of computed fields. These are fields that do not exist on the model normally
and will be passed to .annotate() when building a queryset. and will be passed to .annotate() when building a queryset.
@ -136,25 +126,10 @@ class BaseExport(ABC):
Get a list of fields from related tables. Get a list of fields from related tables.
""" """
return [] return []
@classmethod
def update_queryset(cls, queryset, **kwargs):
"""
Returns an updated queryset. Override in subclass to update queryset.
"""
return queryset
@classmethod
def write_csv_before(cls, csv_writer, **export_kwargs):
"""
Write to csv file before the write_csv method.
Override in subclasses where needed.
"""
pass
@classmethod @classmethod
def annotate_and_retrieve_fields( def annotate_and_retrieve_fields(
cls, initial_queryset, computed_fields, related_table_fields=None, include_many_to_many=False, **kwargs cls, initial_queryset, annotated_fields, related_table_fields=None, include_many_to_many=False, **kwargs
) -> QuerySet: ) -> QuerySet:
""" """
Applies annotations to a queryset and retrieves specified fields, Applies annotations to a queryset and retrieves specified fields,
@ -162,7 +137,7 @@ class BaseExport(ABC):
Parameters: Parameters:
initial_queryset (QuerySet): Initial queryset. initial_queryset (QuerySet): Initial queryset.
computed_fields (dict, optional): Fields to compute {field_name: expression}. annotated_fields (dict, optional): Fields to compute {field_name: expression}.
related_table_fields (list, optional): Extra fields to retrieve; defaults to annotation keys if None. related_table_fields (list, optional): Extra fields to retrieve; defaults to annotation keys if None.
include_many_to_many (bool, optional): Determines if we should include many to many fields or not include_many_to_many (bool, optional): Determines if we should include many to many fields or not
**kwargs: Additional keyword arguments for specific parameters (e.g., public_contacts, domain_invitations, **kwargs: Additional keyword arguments for specific parameters (e.g., public_contacts, domain_invitations,
@ -176,8 +151,8 @@ class BaseExport(ABC):
# We can infer that if we're passing in annotations, # We can infer that if we're passing in annotations,
# we want to grab the result of said annotation. # we want to grab the result of said annotation.
if computed_fields: if annotated_fields:
related_table_fields.extend(computed_fields.keys()) related_table_fields.extend(annotated_fields.keys())
# Get prexisting fields on the model # Get prexisting fields on the model
model_fields = set() model_fields = set()
@ -187,10 +162,109 @@ class BaseExport(ABC):
if many_to_many or not isinstance(field, ManyToManyField): if many_to_many or not isinstance(field, ManyToManyField):
model_fields.add(field.name) model_fields.add(field.name)
queryset = initial_queryset.annotate(**computed_fields).values(*model_fields, *related_table_fields) queryset = initial_queryset.annotate(**annotated_fields).values(*model_fields, *related_table_fields)
return cls.update_queryset(queryset, **kwargs) return cls.update_queryset(queryset, **kwargs)
@classmethod
def update_queryset(cls, queryset, **kwargs):
"""
Returns an updated queryset. Override in subclass to update queryset.
"""
return queryset
@classmethod
def get_model_dict(cls):
sort_fields = cls.get_sort_fields()
kwargs = cls.get_additional_args()
select_related = cls.get_select_related()
prefetch_related = cls.get_prefetch_related()
exclusions = cls.get_exclusions()
annotations_for_sort = cls.get_annotations_for_sort()
filter_conditions = cls.get_filter_conditions(**kwargs)
annotated_fields = cls.get_annotated_fields()
related_table_fields = cls.get_related_table_fields()
model_queryset = (
cls.model()
.objects
.select_related(*select_related)
.prefetch_related(*prefetch_related)
.filter(filter_conditions)
.exclude(exclusions)
.annotate(**annotations_for_sort)
.order_by(*sort_fields)
.distinct()
)
annotated_queryset = cls.annotate_and_retrieve_fields(
model_queryset, annotated_fields, related_table_fields, **kwargs
)
models_dict = convert_queryset_to_dict(annotated_queryset, is_model=False)
return models_dict
class UserPortfolioPermissionModelDict(BaseModelDict):
@classmethod
def model(cls):
# Return the model class that this export handles
return UserPortfolioPermission
@classmethod
def get_filter_conditions(cls, **export_kwargs):
"""
Get a Q object of filter conditions to filter when building queryset.
"""
return Q()
@classmethod
def get_exclusions(cls):
"""
Get a Q object of exclusion conditions to pass to .exclude() when building queryset.
"""
return Q()
@classmethod
def get_annotated_fields(cls):
"""
Get a dict of computed fields. These are fields that do not exist on the model normally
and will be passed to .annotate() when building a queryset.
"""
return {}
@classmethod
def parse_row(cls, columns, model):
"""
Given a set of columns and a model dictionary, generate a new row from cleaned column data.
"""
FIELDS = {"Not yet defined": "Not yet defined"}
row = [FIELDS.get(column, "") for column in columns]
return row
class BaseExport(BaseModelDict):
"""
A generic class for exporting data which returns a csv file for the given model.
Base class in an inheritance tree of 3.
"""
@classmethod
def get_columns(cls):
"""
Returns the columns for CSV export. Override in subclasses as needed.
"""
return []
@classmethod
def write_csv_before(cls, csv_writer, **export_kwargs):
"""
Write to csv file before the write_csv method.
Override in subclasses where needed.
"""
pass
@classmethod @classmethod
def export_data_to_csv(cls, csv_file, **export_kwargs): def export_data_to_csv(cls, csv_file, **export_kwargs):
""" """
@ -199,32 +273,7 @@ class BaseExport(ABC):
""" """
writer = csv.writer(csv_file) writer = csv.writer(csv_file)
columns = cls.get_columns() columns = cls.get_columns()
sort_fields = cls.get_sort_fields() models_dict = cls.get_model_dict()
kwargs = cls.get_additional_args()
select_related = cls.get_select_related()
prefetch_related = cls.get_prefetch_related()
exclusions = cls.get_exclusions()
annotations_for_sort = cls.get_annotations_for_sort()
filter_conditions = cls.get_filter_conditions(**export_kwargs)
computed_fields = cls.get_computed_fields()
related_table_fields = cls.get_related_table_fields()
model_queryset = (
cls.model()
.objects.select_related(*select_related)
.prefetch_related(*prefetch_related)
.filter(filter_conditions)
.exclude(exclusions)
.annotate(**annotations_for_sort)
.order_by(*sort_fields)
.distinct()
)
# Convert the queryset to a dictionary (including annotated fields)
annotated_queryset = cls.annotate_and_retrieve_fields(
model_queryset, computed_fields, related_table_fields, **kwargs
)
models_dict = convert_queryset_to_dict(annotated_queryset, is_model=False)
# Write to csv file before the write_csv # Write to csv file before the write_csv
cls.write_csv_before(writer, **export_kwargs) cls.write_csv_before(writer, **export_kwargs)
@ -534,7 +583,7 @@ class DomainDataType(DomainExport):
return ["permissions"] return ["permissions"]
@classmethod @classmethod
def get_computed_fields(cls, delimiter=", "): def get_annotated_fields(cls, delimiter=", "):
""" """
Get a dict of computed fields. Get a dict of computed fields.
""" """
@ -751,7 +800,7 @@ class DomainDataFull(DomainExport):
) )
@classmethod @classmethod
def get_computed_fields(cls, delimiter=", "): def get_annotated_fields(cls, delimiter=", "):
""" """
Get a dict of computed fields. Get a dict of computed fields.
""" """
@ -846,7 +895,7 @@ class DomainDataFederal(DomainExport):
) )
@classmethod @classmethod
def get_computed_fields(cls, delimiter=", "): def get_annotated_fields(cls, delimiter=", "):
""" """
Get a dict of computed fields. Get a dict of computed fields.
""" """
@ -1465,7 +1514,7 @@ class DomainRequestDataFull(DomainRequestExport):
] ]
@classmethod @classmethod
def get_computed_fields(cls, delimiter=", "): def get_annotated_fields(cls, delimiter=", "):
""" """
Get a dict of computed fields. Get a dict of computed fields.
""" """