Refactor part 1

This commit is contained in:
zandercymatics 2024-11-13 13:10:11 -07:00
parent b0dc18cc3f
commit 2265b70b50
No known key found for this signature in database
GPG key ID: FF4636ABEC9682B7
3 changed files with 239 additions and 130 deletions

View file

@ -10,16 +10,20 @@ from registrar.models import (
DomainInformation,
PublicContact,
UserDomainRole,
PortfolioInvitation,
)
from django.db.models import Case, CharField, Count, DateField, F, ManyToManyField, Q, QuerySet, Value, When
from django.db.models import Case, CharField, Count, DateField, F, ManyToManyField, Q, QuerySet, Value, When, TextField, OuterRef, Subquery
from django.db.models.functions import Cast
from django.utils import timezone
from django.db.models.functions import Concat, Coalesce
from django.contrib.postgres.aggregates import StringAgg
from registrar.models.user_portfolio_permission import UserPortfolioPermission
from registrar.models.utility.generic_helper import convert_queryset_to_dict
from registrar.models.utility.orm_helper import ArrayRemove
from registrar.templatetags.custom_filters import get_region
from registrar.utility.constants import BranchChoices
from registrar.utility.enums import DefaultEmail
from django.contrib.postgres.aggregates import ArrayAgg
logger = logging.getLogger(__name__)
@ -167,14 +171,7 @@ class BaseModelDict(ABC):
return cls.update_queryset(queryset, **kwargs)
@classmethod
def update_queryset(cls, queryset, **kwargs):
"""
Returns an updated queryset. Override in subclass to update queryset.
"""
return queryset
@classmethod
def get_model_dict(cls):
def get_annotated_queryset(cls, request=None):
sort_fields = cls.get_sort_fields()
kwargs = cls.get_additional_args()
select_related = cls.get_select_related()
@ -196,12 +193,21 @@ class BaseModelDict(ABC):
.order_by(*sort_fields)
.distinct()
)
annotated_queryset = cls.annotate_and_retrieve_fields(
return cls.annotate_and_retrieve_fields(
model_queryset, annotated_fields, related_table_fields, **kwargs
)
models_dict = convert_queryset_to_dict(annotated_queryset, is_model=False)
return models_dict
@classmethod
def update_queryset(cls, queryset, **kwargs):
"""
Returns an updated queryset. Override in subclass to update queryset.
"""
return queryset
@classmethod
def get_models_dict(cls, request=None):
return convert_queryset_to_dict(cls.get_annotated_queryset(request), is_model=False)
class UserPortfolioPermissionModelDict(BaseModelDict):
@ -212,36 +218,170 @@ class UserPortfolioPermissionModelDict(BaseModelDict):
return UserPortfolioPermission
@classmethod
def get_filter_conditions(cls, **export_kwargs):
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return ["user"]
@classmethod
def get_filter_conditions(cls, portfolio):
"""
Get a Q object of filter conditions to filter when building queryset.
"""
return Q()
if not portfolio:
# Return nothing
return Q(id__in=[])
# Get all members on this portfolio
return Q(portfolio=portfolio)
@classmethod
def get_exclusions(cls):
"""
Get a Q object of exclusion conditions to pass to .exclude() when building queryset.
"""
return Q()
@classmethod
def get_annotated_fields(cls):
def get_annotated_fields(cls, portfolio):
"""
Get a dict of computed fields. These are fields that do not exist on the model normally
and will be passed to .annotate() when building a queryset.
"""
return {}
if not portfolio:
# Return nothing
return {}
return {
"first_name": F("user__first_name"),
"last_name": F("user__last_name"),
"email_display": F("user__email"),
"last_active": Coalesce(
Cast(F("user__last_login"), output_field=TextField()),
Value("Invalid date"),
output_field=TextField(),
),
"additional_permissions_display": F("additional_permissions"),
"member_display": Case(
When(
Q(user__email__isnull=False) & ~Q(user__email=""),
then=F("user__email")
),
When(
Q(user__first_name__isnull=False) | Q(user__last_name__isnull=False),
then=Concat(
Coalesce(F("user__first_name"), Value("")),
Value(" "),
Coalesce(F("user__last_name"), Value("")),
),
),
default=Value(""),
output_field=CharField(),
),
"domain_info": ArrayAgg(
Concat(
F("user__permissions__domain_id"),
Value(":"),
F("user__permissions__domain__name"),
output_field=CharField(),
),
distinct=True,
filter=Q(user__permissions__domain__isnull=False)
& Q(user__permissions__domain__domain_info__portfolio=portfolio),
),
"source": Value("permission", output_field=CharField()),
}
@classmethod
def get_annotated_queryset(cls, portfolio):
"""Override of the base annotated queryset to pass in portfolio"""
model_queryset = (
cls.model()
.objects
.select_related(*cls.get_select_related())
.prefetch_related(*cls.get_prefetch_related())
.filter(cls.get_filter_conditions(portfolio))
.exclude(cls.get_exclusions())
.annotate(**cls.get_annotations_for_sort())
.order_by(*cls.get_sort_fields())
.distinct()
)
annotated_fields = cls.get_annotated_fields(portfolio)
related_table_fields = cls.get_related_table_fields()
return cls.annotate_and_retrieve_fields(
model_queryset, annotated_fields, related_table_fields
)
class PortfolioInvitationModelDict(BaseModelDict):
@classmethod
def parse_row(cls, columns, model):
"""
Given a set of columns and a model dictionary, generate a new row from cleaned column data.
"""
FIELDS = {"Not yet defined": "Not yet defined"}
def model(cls):
# Return the model class that this export handles
return PortfolioInvitation
row = [FIELDS.get(column, "") for column in columns]
return row
@classmethod
def get_filter_conditions(cls, portfolio):
"""
Get a Q object of filter conditions to filter when building queryset.
"""
if not portfolio:
# Return nothing
return Q(id__in=[])
# Get all members on this portfolio
return Q(
# Check if email matches the OuterRef("email")
email=OuterRef("email"),
# Check if the domain's portfolio matches the given portfolio)
domain__domain_info__portfolio=portfolio,
)
@classmethod
def get_annotated_fields(cls, portfolio):
"""
Get a dict of computed fields. These are fields that do not exist on the model normally
and will be passed to .annotate() when building a queryset.
"""
if not portfolio:
# Return nothing
return {}
domain_invitations = DomainInvitation.objects.filter(
email=OuterRef("email"), # Check if email matches the OuterRef("email")
domain__domain_info__portfolio=portfolio, # Check if the domain's portfolio matches the given portfolio
).annotate(domain_info=Concat(F("domain__id"), Value(":"), F("domain__name"), output_field=CharField()))
return {
"first_name": Value(None, output_field=CharField()),
"last_name": Value(None, output_field=CharField()),
"email_display": F("email"),
"last_active": Value("Invited", output_field=TextField()),
"additional_permissions_display": F("additional_permissions"),
"member_display": F("email"),
"domain_info": ArrayRemove(
ArrayAgg(
Subquery(domain_invitations.values("domain_info")),
distinct=True,
)
),
"source": Value("invitation", output_field=CharField()),
}
@classmethod
def get_annotated_queryset(cls, portfolio):
"""Override of the base annotated queryset to pass in portfolio"""
model_queryset = (
cls.model()
.objects
.select_related(*cls.get_select_related())
.prefetch_related(*cls.get_prefetch_related())
.filter(cls.get_filter_conditions(portfolio))
.exclude(cls.get_exclusions())
.annotate(**cls.get_annotations_for_sort())
.order_by(*cls.get_sort_fields())
.distinct()
)
annotated_fields = cls.get_annotated_fields(portfolio)
related_table_fields = cls.get_related_table_fields()
return cls.annotate_and_retrieve_fields(
model_queryset, annotated_fields, related_table_fields
)
class BaseExport(BaseModelDict):
@ -266,14 +406,14 @@ class BaseExport(BaseModelDict):
pass
@classmethod
def export_data_to_csv(cls, csv_file, **export_kwargs):
def export_data_to_csv(cls, csv_file, request=None, **export_kwargs):
"""
All domain metadata:
Exports domains of all statuses plus domain managers.
"""
writer = csv.writer(csv_file)
columns = cls.get_columns()
models_dict = cls.get_model_dict()
models_dict = cls.get_models_dict()
# Write to csv file before the write_csv
cls.write_csv_before(writer, **export_kwargs)
@ -321,6 +461,43 @@ class BaseExport(BaseModelDict):
"""
pass
class MemberExport(BaseExport):
@classmethod
def model(self):
"""
No model is defined for the member report as it is a combination of multiple fields.
This is a special edge case, but the base report requires this to be defined.
"""
return None
@classmethod
def get_models_dict(cls, request=None):
portfolio = request.session.get("portfolio")
if not portfolio:
return {}
# Union the two querysets to combine UserPortfolioPermission + invites
permissions = UserPortfolioPermissionModelDict.get_annotated_queryset(portfolio)
invitations = PortfolioInvitationModelDict.get_annotated_queryset(portfolio)
objects = permissions.union(invitations)
return convert_queryset_to_dict(objects, is_model=False)
@classmethod
def get_columns(cls):
"""
Returns the columns for CSV export. Override in subclasses as needed.
"""
return []
@classmethod
@abstractmethod
def parse_row(cls, columns, model):
"""
Given a set of columns and a model dictionary, generate a new row from cleaned column data.
Must be implemented by subclasses
"""
pass
class DomainExport(BaseExport):
"""
@ -1597,3 +1774,4 @@ class DomainRequestDataFull(DomainRequestExport):
distinct=True,
)
return query