diff --git a/src/registrar/utility/csv_export.py b/src/registrar/utility/csv_export.py index 64d960337..e90b27c29 100644 --- a/src/registrar/utility/csv_export.py +++ b/src/registrar/utility/csv_export.py @@ -15,6 +15,7 @@ from django.db.models import Case, CharField, Count, DateField, F, ManyToManyFie from django.utils import timezone from django.db.models.functions import Concat, Coalesce from django.contrib.postgres.aggregates import StringAgg +from registrar.models.user_portfolio_permission import UserPortfolioPermission from registrar.models.utility.generic_helper import convert_queryset_to_dict from registrar.templatetags.custom_filters import get_region from registrar.utility.constants import BranchChoices @@ -50,11 +51,7 @@ def format_end_date(end_date): return timezone.make_aware(datetime.strptime(end_date, "%Y-%m-%d")) if end_date else get_default_end_date() -class BaseExport(ABC): - """ - A generic class for exporting data which returns a csv file for the given model. - Base class in an inheritance tree of 3. - """ +class BaseModelDict(ABC): @classmethod @abstractmethod @@ -65,13 +62,6 @@ class BaseExport(ABC): """ pass - @classmethod - def get_columns(cls): - """ - Returns the columns for CSV export. Override in subclasses as needed. - """ - return [] - @classmethod def get_sort_fields(cls): """ @@ -116,7 +106,7 @@ class BaseExport(ABC): return Q() @classmethod - def get_computed_fields(cls): + def get_annotated_fields(cls): """ Get a dict of computed fields. These are fields that do not exist on the model normally and will be passed to .annotate() when building a queryset. @@ -136,25 +126,10 @@ class BaseExport(ABC): Get a list of fields from related tables. """ return [] - - @classmethod - def update_queryset(cls, queryset, **kwargs): - """ - Returns an updated queryset. Override in subclass to update queryset. - """ - return queryset - - @classmethod - def write_csv_before(cls, csv_writer, **export_kwargs): - """ - Write to csv file before the write_csv method. - Override in subclasses where needed. - """ - pass - + @classmethod def annotate_and_retrieve_fields( - cls, initial_queryset, computed_fields, related_table_fields=None, include_many_to_many=False, **kwargs + cls, initial_queryset, annotated_fields, related_table_fields=None, include_many_to_many=False, **kwargs ) -> QuerySet: """ Applies annotations to a queryset and retrieves specified fields, @@ -162,7 +137,7 @@ class BaseExport(ABC): Parameters: initial_queryset (QuerySet): Initial queryset. - computed_fields (dict, optional): Fields to compute {field_name: expression}. + annotated_fields (dict, optional): Fields to compute {field_name: expression}. related_table_fields (list, optional): Extra fields to retrieve; defaults to annotation keys if None. include_many_to_many (bool, optional): Determines if we should include many to many fields or not **kwargs: Additional keyword arguments for specific parameters (e.g., public_contacts, domain_invitations, @@ -176,8 +151,8 @@ class BaseExport(ABC): # We can infer that if we're passing in annotations, # we want to grab the result of said annotation. - if computed_fields: - related_table_fields.extend(computed_fields.keys()) + if annotated_fields: + related_table_fields.extend(annotated_fields.keys()) # Get prexisting fields on the model model_fields = set() @@ -187,10 +162,109 @@ class BaseExport(ABC): if many_to_many or not isinstance(field, ManyToManyField): model_fields.add(field.name) - queryset = initial_queryset.annotate(**computed_fields).values(*model_fields, *related_table_fields) + queryset = initial_queryset.annotate(**annotated_fields).values(*model_fields, *related_table_fields) return cls.update_queryset(queryset, **kwargs) + @classmethod + def update_queryset(cls, queryset, **kwargs): + """ + Returns an updated queryset. Override in subclass to update queryset. + """ + return queryset + + @classmethod + def get_model_dict(cls): + sort_fields = cls.get_sort_fields() + kwargs = cls.get_additional_args() + select_related = cls.get_select_related() + prefetch_related = cls.get_prefetch_related() + exclusions = cls.get_exclusions() + annotations_for_sort = cls.get_annotations_for_sort() + filter_conditions = cls.get_filter_conditions(**kwargs) + annotated_fields = cls.get_annotated_fields() + related_table_fields = cls.get_related_table_fields() + + model_queryset = ( + cls.model() + .objects + .select_related(*select_related) + .prefetch_related(*prefetch_related) + .filter(filter_conditions) + .exclude(exclusions) + .annotate(**annotations_for_sort) + .order_by(*sort_fields) + .distinct() + ) + annotated_queryset = cls.annotate_and_retrieve_fields( + model_queryset, annotated_fields, related_table_fields, **kwargs + ) + models_dict = convert_queryset_to_dict(annotated_queryset, is_model=False) + + return models_dict + + +class UserPortfolioPermissionModelDict(BaseModelDict): + + @classmethod + def model(cls): + # Return the model class that this export handles + return UserPortfolioPermission + + @classmethod + def get_filter_conditions(cls, **export_kwargs): + """ + Get a Q object of filter conditions to filter when building queryset. + """ + return Q() + + @classmethod + def get_exclusions(cls): + """ + Get a Q object of exclusion conditions to pass to .exclude() when building queryset. + """ + return Q() + + @classmethod + def get_annotated_fields(cls): + """ + Get a dict of computed fields. These are fields that do not exist on the model normally + and will be passed to .annotate() when building a queryset. + """ + return {} + + @classmethod + def parse_row(cls, columns, model): + """ + Given a set of columns and a model dictionary, generate a new row from cleaned column data. + """ + FIELDS = {"Not yet defined": "Not yet defined"} + + row = [FIELDS.get(column, "") for column in columns] + return row + + +class BaseExport(BaseModelDict): + """ + A generic class for exporting data which returns a csv file for the given model. + Base class in an inheritance tree of 3. + """ + + @classmethod + def get_columns(cls): + """ + Returns the columns for CSV export. Override in subclasses as needed. + """ + return [] + + @classmethod + def write_csv_before(cls, csv_writer, **export_kwargs): + """ + Write to csv file before the write_csv method. + Override in subclasses where needed. + """ + pass + @classmethod def export_data_to_csv(cls, csv_file, **export_kwargs): """ @@ -199,32 +273,7 @@ class BaseExport(ABC): """ writer = csv.writer(csv_file) columns = cls.get_columns() - sort_fields = cls.get_sort_fields() - kwargs = cls.get_additional_args() - select_related = cls.get_select_related() - prefetch_related = cls.get_prefetch_related() - exclusions = cls.get_exclusions() - annotations_for_sort = cls.get_annotations_for_sort() - filter_conditions = cls.get_filter_conditions(**export_kwargs) - computed_fields = cls.get_computed_fields() - related_table_fields = cls.get_related_table_fields() - - model_queryset = ( - cls.model() - .objects.select_related(*select_related) - .prefetch_related(*prefetch_related) - .filter(filter_conditions) - .exclude(exclusions) - .annotate(**annotations_for_sort) - .order_by(*sort_fields) - .distinct() - ) - - # Convert the queryset to a dictionary (including annotated fields) - annotated_queryset = cls.annotate_and_retrieve_fields( - model_queryset, computed_fields, related_table_fields, **kwargs - ) - models_dict = convert_queryset_to_dict(annotated_queryset, is_model=False) + models_dict = cls.get_model_dict() # Write to csv file before the write_csv cls.write_csv_before(writer, **export_kwargs) @@ -534,7 +583,7 @@ class DomainDataType(DomainExport): return ["permissions"] @classmethod - def get_computed_fields(cls, delimiter=", "): + def get_annotated_fields(cls, delimiter=", "): """ Get a dict of computed fields. """ @@ -751,7 +800,7 @@ class DomainDataFull(DomainExport): ) @classmethod - def get_computed_fields(cls, delimiter=", "): + def get_annotated_fields(cls, delimiter=", "): """ Get a dict of computed fields. """ @@ -846,7 +895,7 @@ class DomainDataFederal(DomainExport): ) @classmethod - def get_computed_fields(cls, delimiter=", "): + def get_annotated_fields(cls, delimiter=", "): """ Get a dict of computed fields. """ @@ -1465,7 +1514,7 @@ class DomainRequestDataFull(DomainRequestExport): ] @classmethod - def get_computed_fields(cls, delimiter=", "): + def get_annotated_fields(cls, delimiter=", "): """ Get a dict of computed fields. """