Refactor part 1

This commit is contained in:
zandercymatics 2024-11-13 13:10:11 -07:00
parent b0dc18cc3f
commit 2265b70b50
No known key found for this signature in database
GPG key ID: FF4636ABEC9682B7
3 changed files with 239 additions and 130 deletions

View file

@ -0,0 +1,6 @@
from django.db.models.expressions import Func
class ArrayRemove(Func):
"""Custom Func to use array_remove to remove null values"""
function = "array_remove"
template = "%(function)s(%(expressions)s, NULL)"

View file

@ -10,16 +10,20 @@ from registrar.models import (
DomainInformation, DomainInformation,
PublicContact, PublicContact,
UserDomainRole, UserDomainRole,
PortfolioInvitation,
) )
from django.db.models import Case, CharField, Count, DateField, F, ManyToManyField, Q, QuerySet, Value, When from django.db.models import Case, CharField, Count, DateField, F, ManyToManyField, Q, QuerySet, Value, When, TextField, OuterRef, Subquery
from django.db.models.functions import Cast
from django.utils import timezone from django.utils import timezone
from django.db.models.functions import Concat, Coalesce from django.db.models.functions import Concat, Coalesce
from django.contrib.postgres.aggregates import StringAgg from django.contrib.postgres.aggregates import StringAgg
from registrar.models.user_portfolio_permission import UserPortfolioPermission from registrar.models.user_portfolio_permission import UserPortfolioPermission
from registrar.models.utility.generic_helper import convert_queryset_to_dict from registrar.models.utility.generic_helper import convert_queryset_to_dict
from registrar.models.utility.orm_helper import ArrayRemove
from registrar.templatetags.custom_filters import get_region from registrar.templatetags.custom_filters import get_region
from registrar.utility.constants import BranchChoices from registrar.utility.constants import BranchChoices
from registrar.utility.enums import DefaultEmail from registrar.utility.enums import DefaultEmail
from django.contrib.postgres.aggregates import ArrayAgg
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -167,14 +171,7 @@ class BaseModelDict(ABC):
return cls.update_queryset(queryset, **kwargs) return cls.update_queryset(queryset, **kwargs)
@classmethod @classmethod
def update_queryset(cls, queryset, **kwargs): def get_annotated_queryset(cls, request=None):
"""
Returns an updated queryset. Override in subclass to update queryset.
"""
return queryset
@classmethod
def get_model_dict(cls):
sort_fields = cls.get_sort_fields() sort_fields = cls.get_sort_fields()
kwargs = cls.get_additional_args() kwargs = cls.get_additional_args()
select_related = cls.get_select_related() select_related = cls.get_select_related()
@ -196,12 +193,21 @@ class BaseModelDict(ABC):
.order_by(*sort_fields) .order_by(*sort_fields)
.distinct() .distinct()
) )
annotated_queryset = cls.annotate_and_retrieve_fields(
return cls.annotate_and_retrieve_fields(
model_queryset, annotated_fields, related_table_fields, **kwargs model_queryset, annotated_fields, related_table_fields, **kwargs
) )
models_dict = convert_queryset_to_dict(annotated_queryset, is_model=False)
return models_dict @classmethod
def update_queryset(cls, queryset, **kwargs):
"""
Returns an updated queryset. Override in subclass to update queryset.
"""
return queryset
@classmethod
def get_models_dict(cls, request=None):
return convert_queryset_to_dict(cls.get_annotated_queryset(request), is_model=False)
class UserPortfolioPermissionModelDict(BaseModelDict): class UserPortfolioPermissionModelDict(BaseModelDict):
@ -212,36 +218,170 @@ class UserPortfolioPermissionModelDict(BaseModelDict):
return UserPortfolioPermission return UserPortfolioPermission
@classmethod @classmethod
def get_filter_conditions(cls, **export_kwargs): def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return ["user"]
@classmethod
def get_filter_conditions(cls, portfolio):
""" """
Get a Q object of filter conditions to filter when building queryset. Get a Q object of filter conditions to filter when building queryset.
""" """
return Q() if not portfolio:
# Return nothing
return Q(id__in=[])
# Get all members on this portfolio
return Q(portfolio=portfolio)
@classmethod @classmethod
def get_exclusions(cls): def get_annotated_fields(cls, portfolio):
"""
Get a Q object of exclusion conditions to pass to .exclude() when building queryset.
"""
return Q()
@classmethod
def get_annotated_fields(cls):
""" """
Get a dict of computed fields. These are fields that do not exist on the model normally Get a dict of computed fields. These are fields that do not exist on the model normally
and will be passed to .annotate() when building a queryset. and will be passed to .annotate() when building a queryset.
""" """
return {} if not portfolio:
# Return nothing
return {}
return {
"first_name": F("user__first_name"),
"last_name": F("user__last_name"),
"email_display": F("user__email"),
"last_active": Coalesce(
Cast(F("user__last_login"), output_field=TextField()),
Value("Invalid date"),
output_field=TextField(),
),
"additional_permissions_display": F("additional_permissions"),
"member_display": Case(
When(
Q(user__email__isnull=False) & ~Q(user__email=""),
then=F("user__email")
),
When(
Q(user__first_name__isnull=False) | Q(user__last_name__isnull=False),
then=Concat(
Coalesce(F("user__first_name"), Value("")),
Value(" "),
Coalesce(F("user__last_name"), Value("")),
),
),
default=Value(""),
output_field=CharField(),
),
"domain_info": ArrayAgg(
Concat(
F("user__permissions__domain_id"),
Value(":"),
F("user__permissions__domain__name"),
output_field=CharField(),
),
distinct=True,
filter=Q(user__permissions__domain__isnull=False)
& Q(user__permissions__domain__domain_info__portfolio=portfolio),
),
"source": Value("permission", output_field=CharField()),
}
@classmethod @classmethod
def parse_row(cls, columns, model): def get_annotated_queryset(cls, portfolio):
""" """Override of the base annotated queryset to pass in portfolio"""
Given a set of columns and a model dictionary, generate a new row from cleaned column data. model_queryset = (
""" cls.model()
FIELDS = {"Not yet defined": "Not yet defined"} .objects
.select_related(*cls.get_select_related())
.prefetch_related(*cls.get_prefetch_related())
.filter(cls.get_filter_conditions(portfolio))
.exclude(cls.get_exclusions())
.annotate(**cls.get_annotations_for_sort())
.order_by(*cls.get_sort_fields())
.distinct()
)
row = [FIELDS.get(column, "") for column in columns] annotated_fields = cls.get_annotated_fields(portfolio)
return row related_table_fields = cls.get_related_table_fields()
return cls.annotate_and_retrieve_fields(
model_queryset, annotated_fields, related_table_fields
)
class PortfolioInvitationModelDict(BaseModelDict):
@classmethod
def model(cls):
# Return the model class that this export handles
return PortfolioInvitation
@classmethod
def get_filter_conditions(cls, portfolio):
"""
Get a Q object of filter conditions to filter when building queryset.
"""
if not portfolio:
# Return nothing
return Q(id__in=[])
# Get all members on this portfolio
return Q(
# Check if email matches the OuterRef("email")
email=OuterRef("email"),
# Check if the domain's portfolio matches the given portfolio)
domain__domain_info__portfolio=portfolio,
)
@classmethod
def get_annotated_fields(cls, portfolio):
"""
Get a dict of computed fields. These are fields that do not exist on the model normally
and will be passed to .annotate() when building a queryset.
"""
if not portfolio:
# Return nothing
return {}
domain_invitations = DomainInvitation.objects.filter(
email=OuterRef("email"), # Check if email matches the OuterRef("email")
domain__domain_info__portfolio=portfolio, # Check if the domain's portfolio matches the given portfolio
).annotate(domain_info=Concat(F("domain__id"), Value(":"), F("domain__name"), output_field=CharField()))
return {
"first_name": Value(None, output_field=CharField()),
"last_name": Value(None, output_field=CharField()),
"email_display": F("email"),
"last_active": Value("Invited", output_field=TextField()),
"additional_permissions_display": F("additional_permissions"),
"member_display": F("email"),
"domain_info": ArrayRemove(
ArrayAgg(
Subquery(domain_invitations.values("domain_info")),
distinct=True,
)
),
"source": Value("invitation", output_field=CharField()),
}
@classmethod
def get_annotated_queryset(cls, portfolio):
"""Override of the base annotated queryset to pass in portfolio"""
model_queryset = (
cls.model()
.objects
.select_related(*cls.get_select_related())
.prefetch_related(*cls.get_prefetch_related())
.filter(cls.get_filter_conditions(portfolio))
.exclude(cls.get_exclusions())
.annotate(**cls.get_annotations_for_sort())
.order_by(*cls.get_sort_fields())
.distinct()
)
annotated_fields = cls.get_annotated_fields(portfolio)
related_table_fields = cls.get_related_table_fields()
return cls.annotate_and_retrieve_fields(
model_queryset, annotated_fields, related_table_fields
)
class BaseExport(BaseModelDict): class BaseExport(BaseModelDict):
@ -266,14 +406,14 @@ class BaseExport(BaseModelDict):
pass pass
@classmethod @classmethod
def export_data_to_csv(cls, csv_file, **export_kwargs): def export_data_to_csv(cls, csv_file, request=None, **export_kwargs):
""" """
All domain metadata: All domain metadata:
Exports domains of all statuses plus domain managers. Exports domains of all statuses plus domain managers.
""" """
writer = csv.writer(csv_file) writer = csv.writer(csv_file)
columns = cls.get_columns() columns = cls.get_columns()
models_dict = cls.get_model_dict() models_dict = cls.get_models_dict()
# Write to csv file before the write_csv # Write to csv file before the write_csv
cls.write_csv_before(writer, **export_kwargs) cls.write_csv_before(writer, **export_kwargs)
@ -321,6 +461,43 @@ class BaseExport(BaseModelDict):
""" """
pass pass
class MemberExport(BaseExport):
@classmethod
def model(self):
"""
No model is defined for the member report as it is a combination of multiple fields.
This is a special edge case, but the base report requires this to be defined.
"""
return None
@classmethod
def get_models_dict(cls, request=None):
portfolio = request.session.get("portfolio")
if not portfolio:
return {}
# Union the two querysets to combine UserPortfolioPermission + invites
permissions = UserPortfolioPermissionModelDict.get_annotated_queryset(portfolio)
invitations = PortfolioInvitationModelDict.get_annotated_queryset(portfolio)
objects = permissions.union(invitations)
return convert_queryset_to_dict(objects, is_model=False)
@classmethod
def get_columns(cls):
"""
Returns the columns for CSV export. Override in subclasses as needed.
"""
return []
@classmethod
@abstractmethod
def parse_row(cls, columns, model):
"""
Given a set of columns and a model dictionary, generate a new row from cleaned column data.
Must be implemented by subclasses
"""
pass
class DomainExport(BaseExport): class DomainExport(BaseExport):
""" """
@ -1597,3 +1774,4 @@ class DomainRequestDataFull(DomainRequestExport):
distinct=True, distinct=True,
) )
return query return query

View file

@ -7,10 +7,9 @@ from django.contrib.postgres.aggregates import ArrayAgg
from django.urls import reverse from django.urls import reverse
from django.views import View from django.views import View
from registrar.models.domain_invitation import DomainInvitation from registrar.models import UserPortfolioPermission
from registrar.models.portfolio_invitation import PortfolioInvitation
from registrar.models.user_portfolio_permission import UserPortfolioPermission
from registrar.models.utility.portfolio_helper import UserPortfolioPermissionChoices, UserPortfolioRoleChoices from registrar.models.utility.portfolio_helper import UserPortfolioPermissionChoices, UserPortfolioRoleChoices
from registrar.utility.csv_export import UserPortfolioPermissionModelDict, PortfolioInvitationModelDict
from registrar.views.utility.mixins import PortfolioMembersPermission from registrar.views.utility.mixins import PortfolioMembersPermission
@ -39,7 +38,7 @@ class PortfolioMembersJson(PortfolioMembersPermission, View):
page_number = request.GET.get("page", 1) page_number = request.GET.get("page", 1)
page_obj = paginator.get_page(page_number) page_obj = paginator.get_page(page_number)
members = [self.serialize_members(request, portfolio, item, request.user) for item in page_obj.object_list] members = [self.serialize_members(portfolio, item, request.user) for item in page_obj.object_list]
return JsonResponse( return JsonResponse(
{ {
@ -56,92 +55,25 @@ class PortfolioMembersJson(PortfolioMembersPermission, View):
def initial_permissions_search(self, portfolio): def initial_permissions_search(self, portfolio):
"""Perform initial search for permissions before applying any filters.""" """Perform initial search for permissions before applying any filters."""
permissions = UserPortfolioPermission.objects.filter(portfolio=portfolio) queryset = UserPortfolioPermissionModelDict.get_annotated_queryset(portfolio)
permissions = ( return queryset.values(
permissions.select_related("user") "id",
.annotate( "first_name",
first_name=F("user__first_name"), "last_name",
last_name=F("user__last_name"), "email_display",
email_display=F("user__email"), "last_active",
last_active=Coalesce( "roles",
Cast(F("user__last_login"), output_field=TextField()), # Cast last_login to text "additional_permissions_display",
Value("Invalid date"), "member_display",
output_field=TextField(), "domain_info",
), "source",
additional_permissions_display=F("additional_permissions"), )
member_display=Case(
# If email is present and not blank, use email def initial_invitations_search(self, portfolio):
When(Q(user__email__isnull=False) & ~Q(user__email=""), then=F("user__email")), """Perform initial invitations search and get related DomainInvitation data based on the email."""
# If first name or last name is present, use concatenation of first_name + " " + last_name # Get DomainInvitation query for matching email and for the portfolio
When( queryset = PortfolioInvitationModelDict.get_annotated_queryset(portfolio)
Q(user__first_name__isnull=False) | Q(user__last_name__isnull=False), return queryset.values(
then=Concat(
Coalesce(F("user__first_name"), Value("")),
Value(" "),
Coalesce(F("user__last_name"), Value("")),
),
),
# If neither, use an empty string
default=Value(""),
output_field=CharField(),
),
domain_info=ArrayAgg(
# an array of domains, with id and name, colon separated
Concat(
F("user__permissions__domain_id"),
Value(":"),
F("user__permissions__domain__name"),
# specify the output_field to ensure union has same column types
output_field=CharField(),
),
distinct=True,
filter=Q(user__permissions__domain__isnull=False) # filter out null values
& Q(
user__permissions__domain__domain_info__portfolio=portfolio
), # only include domains in portfolio
),
source=Value("permission", output_field=CharField()),
)
.values(
"id",
"first_name",
"last_name",
"email_display",
"last_active",
"roles",
"additional_permissions_display",
"member_display",
"domain_info",
"source",
)
)
return permissions
def initial_invitations_search(self, portfolio):
"""Perform initial invitations search and get related DomainInvitation data based on the email."""
# Get DomainInvitation query for matching email and for the portfolio
domain_invitations = DomainInvitation.objects.filter(
email=OuterRef("email"), # Check if email matches the OuterRef("email")
domain__domain_info__portfolio=portfolio, # Check if the domain's portfolio matches the given portfolio
).annotate(domain_info=Concat(F("domain__id"), Value(":"), F("domain__name"), output_field=CharField()))
# PortfolioInvitation query
invitations = PortfolioInvitation.objects.filter(portfolio=portfolio)
invitations = invitations.annotate(
first_name=Value(None, output_field=CharField()),
last_name=Value(None, output_field=CharField()),
email_display=F("email"),
last_active=Value("Invited", output_field=TextField()),
additional_permissions_display=F("additional_permissions"),
member_display=F("email"),
# Use ArrayRemove to return an empty list when no domain invitations are found
domain_info=ArrayRemove(
ArrayAgg(
Subquery(domain_invitations.values("domain_info")),
distinct=True,
)
),
source=Value("invitation", output_field=CharField()),
).values(
"id", "id",
"first_name", "first_name",
"last_name", "last_name",
@ -153,7 +85,6 @@ class PortfolioMembersJson(PortfolioMembersPermission, View):
"domain_info", "domain_info",
"source", "source",
) )
return invitations
def apply_search_term(self, queryset, request): def apply_search_term(self, queryset, request):
"""Apply search term to the queryset.""" """Apply search term to the queryset."""
@ -179,7 +110,7 @@ class PortfolioMembersJson(PortfolioMembersPermission, View):
queryset = queryset.order_by(sort_by) queryset = queryset.order_by(sort_by)
return queryset return queryset
def serialize_members(self, request, portfolio, item, user): def serialize_members(self, portfolio, item, user):
# Check if the user can edit other users # Check if the user can edit other users
user_can_edit_other_users = any( user_can_edit_other_users = any(
user.has_perm(perm) for perm in ["registrar.full_access_permission", "registrar.change_user"] user.has_perm(perm) for perm in ["registrar.full_access_permission", "registrar.change_user"]
@ -213,9 +144,3 @@ class PortfolioMembersJson(PortfolioMembersPermission, View):
"svg_icon": ("visibility" if view_only else "settings"), "svg_icon": ("visibility" if view_only else "settings"),
} }
return member_json return member_json
# Custom Func to use array_remove to remove null values
class ArrayRemove(Func):
function = "array_remove"
template = "%(function)s(%(expressions)s, NULL)"