Generalize code

This commit is contained in:
zandercymatics 2024-01-08 11:35:46 -07:00
parent e1384a7bb2
commit 69fc902fc4
No known key found for this signature in database
GPG key ID: FF4636ABEC9682B7
3 changed files with 144 additions and 87 deletions

View file

@ -2,9 +2,8 @@ import argparse
import logging import logging
from django.core.paginator import Paginator from django.core.paginator import Paginator
from typing import List from typing import List
from django.core.management import BaseCommand from django.core.management import BaseCommand
from registrar.management.commands.utility.terminal_helper import TerminalColors, TerminalHelper from registrar.management.commands.utility.terminal_helper import TerminalColors, TerminalHelper, ScriptDataHelper
from registrar.models import Domain from registrar.models import Domain
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -13,12 +12,6 @@ logger = logging.getLogger(__name__)
class Command(BaseCommand): class Command(BaseCommand):
help = "Loops through each valid Domain object and updates its first_created value" help = "Loops through each valid Domain object and updates its first_created value"
def __init__(self):
super().__init__()
self.to_update: List[Domain] = []
self.failed_to_update: List[Domain] = []
self.skipped: List[Domain] = []
def add_arguments(self, parser): def add_arguments(self, parser):
"""Adds command line arguments""" """Adds command line arguments"""
parser.add_argument("--debug", action=argparse.BooleanOptionalAction) parser.add_argument("--debug", action=argparse.BooleanOptionalAction)
@ -29,91 +22,36 @@ class Command(BaseCommand):
valid_states = [Domain.State.READY, Domain.State.ON_HOLD, Domain.State.DELETED] valid_states = [Domain.State.READY, Domain.State.ON_HOLD, Domain.State.DELETED]
domains = Domain.objects.filter(first_ready=None, state__in=valid_states) domains = Domain.objects.filter(first_ready=None, state__in=valid_states)
# Keep track of what we want to update, what failed, and what was skipped
to_update: List[Domain] = []
failed_to_update: List[Domain] = []
skipped: List[Domain] = []
# Code execution will stop here if the user prompts "N"
TerminalHelper.prompt_for_execution(
system_exit_on_terminate=True,
info_to_inspect=f"""
==Proposed Changes==
Number of Domain objects to change: {len(domains)}
""",
prompt_title="Do you wish to patch first_ready data?",
)
logger.info("Updating...")
for domain in domains: for domain in domains:
try: try:
self.update_first_ready_for_domain(domain, debug) update_first_ready_for_domain(domain, debug)
except Exception as err: except Exception as err:
self.failed_to_update.append(domain) failed_to_update.append(domain)
logger.error(err) logger.error(err)
logger.error( logger.error(
f"{TerminalColors.FAIL}" f"{TerminalColors.FAIL}"
f"Failed to update {domain}" f"Failed to update {domain}"
f"{TerminalColors.ENDC}" f"{TerminalColors.ENDC}"
) )
ScriptDataHelper.bulk_update_fields(Domain, to_update, ["first_ready"])
batch_size = 1000
# Create a Paginator object. Bulk_update on the full dataset
# is too memory intensive for our current app config, so we can chunk this data instead.
paginator = Paginator(self.to_update, batch_size)
for page_num in paginator.page_range:
page = paginator.page(page_num)
Domain.objects.bulk_update(page.object_list, ["first_ready"])
self.log_script_run_summary(debug) # Log what happened
TerminalHelper.log_script_run_summary(
def update_first_ready_for_domain(self, domain: Domain, debug: bool): to_update, failed_to_update, skipped, debug
"""Grabs the created_at field and associates it with the first_ready column.
Appends the result to the to_update list."""
created_at = domain.created_at
if created_at is not None:
domain.first_ready = domain.created_at
self.to_update.append(domain)
if debug:
logger.info(f"Updating {domain}")
else:
self.skipped.append(domain)
if debug:
logger.warning(f"Skipped updating {domain}")
def log_script_run_summary(self, debug: bool):
"""Prints success, failed, and skipped counts, as well as
all affected objects."""
update_success_count = len(self.to_update)
update_failed_count = len(self.failed_to_update)
update_skipped_count = len(self.skipped)
# Prepare debug messages
debug_messages = {
"success": (f"{TerminalColors.OKCYAN}Updated: {self.to_update}{TerminalColors.ENDC}\n"),
"skipped": (f"{TerminalColors.YELLOW}Skipped: {self.skipped}{TerminalColors.ENDC}\n"),
"failed": (f"{TerminalColors.FAIL}Failed: {self.failed_to_update}{TerminalColors.ENDC}\n"),
}
# Print out a list of everything that was changed, if we have any changes to log.
# Otherwise, don't print anything.
TerminalHelper.print_conditional(
debug,
f"{debug_messages.get('success') if update_success_count > 0 else ''}"
f"{debug_messages.get('skipped') if update_skipped_count > 0 else ''}"
f"{debug_messages.get('failed') if update_failed_count > 0 else ''}",
) )
if update_failed_count == 0 and update_skipped_count == 0:
logger.info(
f"""{TerminalColors.OKGREEN}
============= FINISHED ===============
Updated {update_success_count} Domain entries
{TerminalColors.ENDC}
"""
)
elif update_failed_count == 0:
logger.warning(
f"""{TerminalColors.YELLOW}
============= FINISHED ===============
Updated {update_success_count} Domain entries
----- SOME CREATED_AT DATA WAS NONE (NEEDS MANUAL PATCHING) -----
Skipped updating {update_skipped_count} Domain entries
{TerminalColors.ENDC}
"""
)
else:
logger.error(
f"""{TerminalColors.FAIL}
============= FINISHED ===============
Updated {update_success_count} Domain entries
----- UPDATE FAILED -----
Failed to update {update_failed_count} Domain entries,
Skipped updating {update_skipped_count} Domain entries
{TerminalColors.ENDC}
"""
)

View file

@ -1,8 +1,11 @@
from enum import Enum from enum import Enum
import logging import logging
import sys import sys
from django.core.paginator import Paginator
from typing import List from typing import List
from registrar import models
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,7 +44,92 @@ class TerminalColors:
BackgroundLightYellow = "\033[103m" BackgroundLightYellow = "\033[103m"
class ScriptDataHelper:
"""Helper method with utilities to speed up development of scripts that do DB operations"""
@staticmethod
def bulk_update_fields(model_class, update_list, batch_size=1000):
"""
This function performs a bulk update operation on a specified Django model class in batches.
It uses Django's Paginator to handle large datasets in a memory-efficient manner.
Parameters:
model_class: The Django model class that you want to perform the bulk update on.
This should be the actual class, not a string of the class name.
update_list: A list of model instances that you want to update. Each instance in the list
should already have the updated values set on the instance.
batch_size: The maximum number of model instances to update in a single database query.
Defaults to 1000. If you're dealing with models that have a large number of fields,
or large field values, you may need to decrease this value to prevent out-of-memory errors.
Usage:
bulk_update_fields(Domain, page.object_list, ["first_ready"])
"""
# Create a Paginator object. Bulk_update on the full dataset
# is too memory intensive for our current app config, so we can chunk this data instead.
paginator = Paginator(update_list, batch_size)
for page_num in paginator.page_range:
page = paginator.page(page_num)
model_class.objects.bulk_update(page.object_list, update_list)
class TerminalHelper: class TerminalHelper:
@staticmethod
def log_script_run_summary(to_update, failed_to_update, skipped, debug: bool):
"""Prints success, failed, and skipped counts, as well as
all affected objects."""
update_success_count = len(to_update)
update_failed_count = len(failed_to_update)
update_skipped_count = len(skipped)
# Prepare debug messages
debug_messages = {
"success": (f"{TerminalColors.OKCYAN}Updated: {to_update}{TerminalColors.ENDC}\n"),
"skipped": (f"{TerminalColors.YELLOW}Skipped: {skipped}{TerminalColors.ENDC}\n"),
"failed": (f"{TerminalColors.FAIL}Failed: {failed_to_update}{TerminalColors.ENDC}\n"),
}
# Print out a list of everything that was changed, if we have any changes to log.
# Otherwise, don't print anything.
TerminalHelper.print_conditional(
debug,
f"{debug_messages.get('success') if update_success_count > 0 else ''}"
f"{debug_messages.get('skipped') if update_skipped_count > 0 else ''}"
f"{debug_messages.get('failed') if update_failed_count > 0 else ''}",
)
if update_failed_count == 0 and update_skipped_count == 0:
logger.info(
f"""{TerminalColors.OKGREEN}
============= FINISHED ===============
Updated {update_success_count} entries
{TerminalColors.ENDC}
"""
)
elif update_failed_count == 0:
logger.warning(
f"""{TerminalColors.YELLOW}
============= FINISHED ===============
Updated {update_success_count} entries
----- SOME DATA WAS INVALID (NEEDS MANUAL PATCHING) -----
Skipped updating {update_skipped_count} entries
{TerminalColors.ENDC}
"""
)
else:
logger.error(
f"""{TerminalColors.FAIL}
============= FINISHED ===============
Updated {update_success_count} entries
----- UPDATE FAILED -----
Failed to update {update_failed_count} entries,
Skipped updating {update_skipped_count} entries
{TerminalColors.ENDC}
"""
)
@staticmethod @staticmethod
def query_yes_no(question: str, default="yes"): def query_yes_no(question: str, default="yes"):
"""Ask a yes/no question via raw_input() and return their answer. """Ask a yes/no question via raw_input() and return their answer.

View file

@ -22,6 +22,37 @@ from .common import MockEppLib, MockSESClient, less_console_noise
import boto3_mocking # type: ignore import boto3_mocking # type: ignore
class TestPopulateFirstReady(TestCase):
"""Tests for the populate_first_ready script"""
def setUp(self):
"""Creates a fake domain object"""
super().setUp()
Domain.objects.get_or_create(
name="fake.gov", state=Domain.State.READY, created_at=datetime.date(2024, 12, 31)
)
def tearDown(self):
"""Deletes all DB objects related to migrations"""
super().tearDown()
# Delete domains
Domain.objects.all().delete()
def run_populate_first_ready(self):
"""
This method executes the populate_first_ready command.
The 'call_command' function from Django's management framework is then used to
execute the populate_first_ready command with the specified arguments.
"""
with patch(
"registrar.management.commands.utility.terminal_helper.TerminalHelper.query_yes_no_exit", # noqa
return_value=True,
):
call_command("populate_first_ready")
class TestExtendExpirationDates(MockEppLib): class TestExtendExpirationDates(MockEppLib):
def setUp(self): def setUp(self):
"""Defines the file name of migration_json and the folder its contained in""" """Defines the file name of migration_json and the folder its contained in"""
@ -78,10 +109,10 @@ class TestExtendExpirationDates(MockEppLib):
def run_extend_expiration_dates(self): def run_extend_expiration_dates(self):
""" """
This method executes the transfer_transition_domains_to_domains command. This method executes the extend_expiration_dates command.
The 'call_command' function from Django's management framework is then used to The 'call_command' function from Django's management framework is then used to
execute the load_transition_domain command with the specified arguments. execute the extend_expiration_dates command with the specified arguments.
""" """
with patch( with patch(
"registrar.management.commands.utility.terminal_helper.TerminalHelper.query_yes_no_exit", # noqa "registrar.management.commands.utility.terminal_helper.TerminalHelper.query_yes_no_exit", # noqa