Unit tests

This commit is contained in:
Rachid Mrad 2024-06-28 16:34:39 -04:00
parent 59e912b69a
commit b600a26eb8
No known key found for this signature in database
5 changed files with 534 additions and 704 deletions

View file

@ -15,12 +15,10 @@ from django.db.models import QuerySet, Value, CharField, Count, Q, F
from django.db.models import Case, When, DateField
from django.db.models import ManyToManyField
from django.utils import timezone
from django.core.paginator import Paginator
from django.db.models.functions import Concat, Coalesce
from django.contrib.postgres.aggregates import StringAgg
from registrar.models.utility.generic_helper import convert_queryset_to_dict
from registrar.templatetags.custom_filters import get_region
from registrar.utility.enums import DefaultEmail
from registrar.utility.constants import BranchChoices
@ -34,14 +32,17 @@ def write_header(writer, columns):
"""
writer.writerow(columns)
def get_default_start_date():
# Default to a date that's prior to our first deployment
"""Default to a date that's prior to our first deployment"""
return timezone.make_aware(datetime(2023, 11, 1))
def get_default_end_date():
# Default to now()
"""Default to now()"""
return timezone.now()
def format_start_date(start_date):
return timezone.make_aware(datetime.strptime(start_date, "%Y-%m-%d")) if start_date else get_default_start_date()
@ -49,9 +50,11 @@ def format_start_date(start_date):
def format_end_date(end_date):
return timezone.make_aware(datetime.strptime(end_date, "%Y-%m-%d")) if end_date else get_default_end_date()
class BaseExport(ABC):
"""
A generic class for exporting data which returns a csv file for the given model.
Base class in an inheritance tree of 3.
"""
@classmethod
@ -69,14 +72,14 @@ class BaseExport(ABC):
Returns the columns for CSV export. Override in subclasses as needed.
"""
return []
@classmethod
def get_sort_fields(cls):
"""
Returns the sort fields for the CSV export. Override in subclasses as needed.
"""
return []
@classmethod
def get_additional_args(cls):
"""
@ -84,63 +87,63 @@ class BaseExport(ABC):
Override in subclasses to provide specific arguments.
"""
return {}
@classmethod
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return []
@classmethod
def get_prefetch_related(cls):
"""
Get a list of tables to pass to prefetch_related when building queryset.
"""
return []
@classmethod
def get_exclusions(cls):
"""
Get a Q object of exclusion conditions to use when building queryset.
"""
return Q()
@classmethod
def get_filter_conditions(cls, start_date=None, end_date=None):
"""
Get a Q object of filter conditions to filter when building queryset.
"""
return Q()
@classmethod
def get_computed_fields(cls):
"""
Get a dict of computed fields.
"""
return {}
@classmethod
def get_annotations_for_sort(cls):
"""
Get a dict of annotations to make available for order_by clause.
"""
return {}
@classmethod
def get_related_table_fields(cls):
"""
Get a list of fields from related tables.
"""
return []
@classmethod
def update_queryset(cls, queryset, **kwargs):
"""
Returns an updated queryset. Override in subclass to update queryset.
"""
return queryset
@classmethod
def write_csv_before(cls, csv_writer, start_date=None, end_date=None):
"""
@ -187,7 +190,7 @@ class BaseExport(ABC):
queryset = initial_queryset.annotate(**computed_fields).values(*model_fields, *related_table_fields)
return cls.update_queryset(queryset, **kwargs)
@classmethod
def export_data_to_csv(cls, csv_file, start_date=None, end_date=None):
"""
@ -207,7 +210,8 @@ class BaseExport(ABC):
related_table_fields = cls.get_related_table_fields()
model_queryset = (
cls.model().objects.select_related(*select_related)
cls.model()
.objects.select_related(*select_related)
.prefetch_related(*prefetch_related)
.filter(filter_conditions)
.exclude(exclusions)
@ -217,7 +221,9 @@ class BaseExport(ABC):
)
# Convert the queryset to a dictionary (including annotated fields)
annotated_queryset = cls.annotate_and_retrieve_fields(model_queryset, computed_fields, related_table_fields, **kwargs)
annotated_queryset = cls.annotate_and_retrieve_fields(
model_queryset, computed_fields, related_table_fields, **kwargs
)
models_dict = convert_queryset_to_dict(annotated_queryset, is_model=False)
# Write to csv file before the write_csv
@ -259,10 +265,11 @@ class BaseExport(ABC):
"""
pass
class DomainExport(BaseExport):
"""
A collection of functions which return csv files regarding the Domain model.
Second class in an inheritance tree of 3.
"""
@classmethod
@ -279,9 +286,9 @@ class DomainExport(BaseExport):
based on public_contacts, domain_invitations and user_domain_roles
passed through kwargs.
"""
public_contacts = kwargs.get('public_contacts', {})
domain_invitations = kwargs.get('domain_invitations', {})
user_domain_roles = kwargs.get('user_domain_roles', {})
public_contacts = kwargs.get("public_contacts", {})
domain_invitations = kwargs.get("domain_invitations", {})
user_domain_roles = kwargs.get("user_domain_roles", {})
annotated_domain_infos = []
@ -296,16 +303,18 @@ class DomainExport(BaseExport):
# Annotate with security_contact from public_contacts
for domain_info in queryset:
domain_info['security_contact_email'] = public_contacts.get(domain_info.get('domain__security_contact_registry_id'))
domain_info['invited_users'] = ', '.join(invited_users_dict.get(domain_info.get('domain__name'), []))
domain_info['managers'] = ', '.join(managers_dict.get(domain_info.get('domain__name'), []))
domain_info["security_contact_email"] = public_contacts.get(
domain_info.get("domain__security_contact_registry_id")
)
domain_info["invited_users"] = ", ".join(invited_users_dict.get(domain_info.get("domain__name"), []))
domain_info["managers"] = ", ".join(managers_dict.get(domain_info.get("domain__name"), []))
annotated_domain_infos.append(domain_info)
if annotated_domain_infos:
return annotated_domain_infos
return queryset
# ============================================================= #
# Helper functions for django ORM queries. #
# We are using these rather than pure python for speed reasons. #
@ -316,15 +325,15 @@ class DomainExport(BaseExport):
"""
Fetch all PublicContact entries and return a mapping of registry_id to email.
"""
public_contacts = PublicContact.objects.values_list('registry_id', 'email')
public_contacts = PublicContact.objects.values_list("registry_id", "email")
return {registry_id: email for registry_id, email in public_contacts}
@classmethod
def get_all_domain_invitations(cls):
"""
Fetch all DomainInvitation entries and return a mapping of domain to email.
"""
domain_invitations = DomainInvitation.objects.filter(status="invited").values_list('domain__name', 'email')
domain_invitations = DomainInvitation.objects.filter(status="invited").values_list("domain__name", "email")
return list(domain_invitations)
@classmethod
@ -332,7 +341,7 @@ class DomainExport(BaseExport):
"""
Fetch all UserDomainRole entries and return a mapping of domain to user__email.
"""
user_domain_roles = UserDomainRole.objects.select_related('user').values_list('domain__name', 'user__email')
user_domain_roles = UserDomainRole.objects.select_related("user").values_list("domain__name", "user__email")
return list(user_domain_roles)
@classmethod
@ -360,19 +369,9 @@ class DomainExport(BaseExport):
if domain_federal_type and domain_org_type == DomainRequest.OrgChoicesElectionOffice.FEDERAL:
domain_type = f"{human_readable_domain_org_type} - {human_readable_domain_federal_type}"
if model.get("domain__name") == "18f.gov":
print(f'domain_type {domain_type}')
print(f'federal_agency {model.get("federal_agency")}')
print(f'city {model.get("city")}')
print(f'agency {model.get("agency")}')
print(f'federal_agency__agency {model.get("federal_agency__agency")}')
# create a dictionary of fields which can be included in output.
# "extra_fields" are precomputed fields (generated in the DB or parsed).
FIELDS = {
"Domain name": model.get("domain__name"),
"Status": human_readable_status,
"First ready on": first_ready_on,
@ -434,6 +433,10 @@ class DomainExport(BaseExport):
class DomainDataType(DomainExport):
"""
Shows security contacts, domain managers, ao
Inherits from BaseExport -> DomainExport
"""
@classmethod
def get_columns(cls):
@ -456,7 +459,7 @@ class DomainDataType(DomainExport):
"Domain managers",
"Invited domain managers",
]
@classmethod
def get_sort_fields(cls):
"""
@ -488,29 +491,24 @@ class DomainDataType(DomainExport):
user_domain_roles = cls.get_all_user_domain_roles()
return {
'public_contacts': public_contacts,
'domain_invitations': domain_invitations,
'user_domain_roles': user_domain_roles,
"public_contacts": public_contacts,
"domain_invitations": domain_invitations,
"user_domain_roles": user_domain_roles,
}
@classmethod
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return [
"domain",
"authorizing_official"
]
return ["domain", "authorizing_official"]
@classmethod
def get_prefetch_related(cls):
"""
Get a list of tables to pass to prefetch_related when building queryset.
"""
return [
"permissions"
]
return ["permissions"]
@classmethod
def get_computed_fields(cls, delimiter=", "):
@ -525,7 +523,7 @@ class DomainDataType(DomainExport):
output_field=CharField(),
),
}
@classmethod
def get_related_table_fields(cls):
"""
@ -542,9 +540,13 @@ class DomainDataType(DomainExport):
"authorizing_official__email",
"federal_agency__agency",
]
class DomainDataFull(DomainExport):
"""
Shows security contacts, filtered by state
Inherits from BaseExport -> DomainExport
"""
@classmethod
def get_columns(cls):
@ -560,7 +562,7 @@ class DomainDataFull(DomainExport):
"State",
"Security contact email",
]
@classmethod
def get_sort_fields(cls):
"""
@ -586,17 +588,15 @@ class DomainDataFull(DomainExport):
public_contacts = cls.get_all_security_emails()
return {
'public_contacts': public_contacts,
"public_contacts": public_contacts,
}
@classmethod
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return [
"domain"
]
return ["domain"]
@classmethod
def get_filter_conditions(cls, start_date=None, end_date=None):
@ -604,13 +604,13 @@ class DomainDataFull(DomainExport):
Get a Q object of filter conditions to filter when building queryset.
"""
return Q(
domain__state__in = [
domain__state__in=[
Domain.State.READY,
Domain.State.DNS_NEEDED,
Domain.State.ON_HOLD,
],
)
@classmethod
def get_computed_fields(cls, delimiter=", "):
"""
@ -624,7 +624,7 @@ class DomainDataFull(DomainExport):
output_field=CharField(),
),
}
@classmethod
def get_related_table_fields(cls):
"""
@ -638,6 +638,10 @@ class DomainDataFull(DomainExport):
class DomainDataFederal(DomainExport):
"""
Shows security contacts, filtered by state and org type
Inherits from BaseExport -> DomainExport
"""
@classmethod
def get_columns(cls):
@ -653,7 +657,7 @@ class DomainDataFederal(DomainExport):
"State",
"Security contact email",
]
@classmethod
def get_sort_fields(cls):
"""
@ -679,17 +683,15 @@ class DomainDataFederal(DomainExport):
public_contacts = cls.get_all_security_emails()
return {
'public_contacts': public_contacts,
"public_contacts": public_contacts,
}
@classmethod
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return [
"domain"
]
return ["domain"]
@classmethod
def get_filter_conditions(cls, start_date=None, end_date=None):
@ -702,9 +704,9 @@ class DomainDataFederal(DomainExport):
Domain.State.READY,
Domain.State.DNS_NEEDED,
Domain.State.ON_HOLD,
]
],
)
@classmethod
def get_computed_fields(cls, delimiter=", "):
"""
@ -718,7 +720,7 @@ class DomainDataFederal(DomainExport):
output_field=CharField(),
),
}
@classmethod
def get_related_table_fields(cls):
"""
@ -732,6 +734,10 @@ class DomainDataFederal(DomainExport):
class DomainGrowth(DomainExport):
"""
Shows ready and deleted domains within a date range, sorted
Inherits from BaseExport -> DomainExport
"""
@classmethod
def get_columns(cls):
@ -751,7 +757,7 @@ class DomainGrowth(DomainExport):
"First ready",
"Deleted",
]
@classmethod
def get_annotations_for_sort(cls, delimiter=", "):
"""
@ -760,10 +766,10 @@ class DomainGrowth(DomainExport):
today = timezone.now().date()
return {
"custom_sort": Case(
When(domain__state=Domain.State.READY, then='domain__first_ready'),
When(domain__state=Domain.State.DELETED, then='domain__deleted'),
When(domain__state=Domain.State.READY, then="domain__first_ready"),
When(domain__state=Domain.State.DELETED, then="domain__deleted"),
default=Value(today), # Default value if no conditions match
output_field=DateField()
output_field=DateField(),
)
}
@ -773,19 +779,17 @@ class DomainGrowth(DomainExport):
Returns the sort fields.
"""
return [
'-domain__state',
'custom_sort',
'domain__name',
"-domain__state",
"custom_sort",
"domain__name",
]
@classmethod
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return [
"domain"
]
return ["domain"]
@classmethod
def get_filter_conditions(cls, start_date=None, end_date=None):
@ -795,15 +799,13 @@ class DomainGrowth(DomainExport):
filter_ready = Q(
domain__state__in=[Domain.State.READY],
domain__first_ready__gte=start_date,
domain__first_ready__lte=end_date
domain__first_ready__lte=end_date,
)
filter_deleted = Q(
domain__state__in=[Domain.State.DELETED],
domain__deleted__gte=start_date,
domain__deleted__lte=end_date
domain__state__in=[Domain.State.DELETED], domain__deleted__gte=start_date, domain__deleted__lte=end_date
)
return filter_ready | filter_deleted
@classmethod
def get_related_table_fields(cls):
"""
@ -821,6 +823,10 @@ class DomainGrowth(DomainExport):
class DomainManaged(DomainExport):
"""
Shows managed domains by an end date, sorted
Inherits from BaseExport -> DomainExport
"""
@classmethod
def get_columns(cls):
@ -833,34 +839,30 @@ class DomainManaged(DomainExport):
"Domain managers",
"Invited domain managers",
]
@classmethod
def get_sort_fields(cls):
"""
Returns the sort fields.
"""
return [
'domain__name',
"domain__name",
]
@classmethod
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return [
"domain"
]
return ["domain"]
@classmethod
def get_prefetch_related(cls):
"""
Get a list of tables to pass to prefetch_related when building queryset.
"""
return [
"permissions"
]
return ["permissions"]
@classmethod
def get_filter_conditions(cls, start_date=None, end_date=None):
"""
@ -871,7 +873,6 @@ class DomainManaged(DomainExport):
domain__permissions__isnull=False,
domain__first_ready__lte=end_date_formatted,
)
@classmethod
def get_additional_args(cls):
@ -889,10 +890,10 @@ class DomainManaged(DomainExport):
user_domain_roles = cls.get_all_user_domain_roles()
return {
'domain_invitations': domain_invitations,
'user_domain_roles': user_domain_roles,
"domain_invitations": domain_invitations,
"user_domain_roles": user_domain_roles,
}
@classmethod
def get_related_table_fields(cls):
"""
@ -901,7 +902,7 @@ class DomainManaged(DomainExport):
return [
"domain__name",
]
@classmethod
def write_csv_before(cls, csv_writer, start_date=None, end_date=None):
"""
@ -959,6 +960,10 @@ class DomainManaged(DomainExport):
class DomainUnmanaged(DomainExport):
"""
Shows unmanaged domains by an end date, sorted
Inherits from BaseExport -> DomainExport
"""
@classmethod
def get_columns(cls):
@ -969,34 +974,30 @@ class DomainUnmanaged(DomainExport):
"Domain name",
"Domain type",
]
@classmethod
def get_sort_fields(cls):
"""
Returns the sort fields.
"""
return [
'domain__name',
"domain__name",
]
@classmethod
def get_select_related(cls):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return [
"domain"
]
return ["domain"]
@classmethod
def get_prefetch_related(cls):
"""
Get a list of tables to pass to prefetch_related when building queryset.
"""
return [
"permissions"
]
return ["permissions"]
@classmethod
def get_filter_conditions(cls, start_date=None, end_date=None):
"""
@ -1007,7 +1008,7 @@ class DomainUnmanaged(DomainExport):
domain__permissions__isnull=True,
domain__first_ready__lte=end_date_formatted,
)
@classmethod
def get_related_table_fields(cls):
"""
@ -1016,12 +1017,12 @@ class DomainUnmanaged(DomainExport):
return [
"domain__name",
]
@classmethod
def write_csv_before(cls, csv_writer, start_date=None, end_date=None):
"""
Write to csv file before the write_csv method.
"""
start_date_formatted = format_start_date(start_date)
end_date_formatted = format_end_date(end_date)
@ -1075,6 +1076,10 @@ class DomainUnmanaged(DomainExport):
class DomainRequestExport(BaseExport):
"""
A collection of functions which return csv files regarding the DomainRequest model.
Second class in an inheritance tree of 3.
"""
@classmethod
def model(cls):
@ -1197,6 +1202,10 @@ class DomainRequestExport(BaseExport):
class DomainRequestGrowth(DomainRequestExport):
"""
Shows submitted requests within a date range, sorted
Inherits from BaseExport -> DomainRequestExport
"""
@classmethod
def get_columns(cls):
@ -1218,7 +1227,7 @@ class DomainRequestGrowth(DomainRequestExport):
return [
"requested_domain__name",
]
@classmethod
def get_filter_conditions(cls, start_date=None, end_date=None):
"""
@ -1238,12 +1247,14 @@ class DomainRequestGrowth(DomainRequestExport):
"""
Get a list of fields from related tables.
"""
return [
"requested_domain__name"
]
return ["requested_domain__name"]
class DomainRequestDataFull(DomainRequestExport):
"""
Shows all but STARTED requests
Inherits from BaseExport -> DomainRequestExport
"""
@classmethod
def get_columns(cls):
@ -1285,34 +1296,22 @@ class DomainRequestDataFull(DomainRequestExport):
"""
Get a list of tables to pass to select_related when building queryset.
"""
return [
"creator",
"authorizing_official",
"federal_agency",
"investigator",
"requested_domain"
]
return ["creator", "authorizing_official", "federal_agency", "investigator", "requested_domain"]
@classmethod
def get_prefetch_related(cls):
"""
Get a list of tables to pass to prefetch_related when building queryset.
"""
return [
"current_websites",
"other_contacts",
"alternative_domains"
]
return ["current_websites", "other_contacts", "alternative_domains"]
@classmethod
def get_exclusions(cls):
"""
Get a Q object of exclusion conditions to use when building queryset.
"""
return Q(
status__in=[DomainRequest.DomainRequestStatus.STARTED]
)
return Q(status__in=[DomainRequest.DomainRequestStatus.STARTED])
@classmethod
def get_sort_fields(cls):
"""
@ -1322,7 +1321,7 @@ class DomainRequestDataFull(DomainRequestExport):
"status",
"requested_domain__name",
]
@classmethod
def get_computed_fields(cls, delimiter=", "):
"""
@ -1346,7 +1345,7 @@ class DomainRequestDataFull(DomainRequestExport):
distinct=True,
),
}
@classmethod
def get_related_table_fields(cls):
"""
@ -1364,7 +1363,6 @@ class DomainRequestDataFull(DomainRequestExport):
"creator__email",
"investigator__email",
]
# ============================================================= #
# Helper functions for django ORM queries. #