diff --git a/src/registrar/management/commands/load_organization_data.py b/src/registrar/management/commands/load_organization_data.py index 937286a07..2c07bbb0a 100644 --- a/src/registrar/management/commands/load_organization_data.py +++ b/src/registrar/management/commands/load_organization_data.py @@ -2,8 +2,6 @@ import argparse import logging -import copy -import time from django.core.management import BaseCommand from registrar.management.commands.utility.extra_transition_domain_helper import OrganizationDataLoader @@ -12,7 +10,7 @@ from registrar.management.commands.utility.transition_domain_arguments import Tr from registrar.models import TransitionDomain from registrar.models.domain import Domain from registrar.models.domain_information import DomainInformation -from ...utility.email import send_templated_email, EmailSendingError +from django.core.paginator import Paginator from typing import List logger = logging.getLogger(__name__) @@ -120,6 +118,9 @@ class Command(BaseCommand): def update_domain_information(self, desired_objects: List[TransitionDomain], debug): di_to_update = [] di_failed_to_update = [] + # These are fields that we COULD update, but fields we choose not to update. + # For instance, if the user already entered data - lets not corrupt that. + di_skipped = [] # Grab each TransitionDomain we want to change. Store it. # Fetches all TransitionDomains in one query. @@ -137,9 +138,27 @@ class Command(BaseCommand): name__in=[td.domain_name for td in transition_domains] ) + + # Start with all DomainInformation objects + filtered_domain_informations = DomainInformation.objects.all() + + changed_fields = [ + "address_line1", + "city", + "state_territory", + "zipcode", + ] + + # Chain filter calls for each field. This checks to see if the end user + # made a change to ANY field in changed_fields. If they did, don't update their information. + # We assume that if they made a change, we don't want to interfere with that. + for field in changed_fields: + # For each changed_field, check if no data exists + filtered_domain_informations = filtered_domain_informations.filter(**{f'{field}__isnull': True}) + # Then, use each domain object to map domain <--> DomainInformation # Fetches all DomainInformations in one query. - domain_informations = DomainInformation.objects.filter( + domain_informations = filtered_domain_informations.filter( domain__in=domains ) @@ -149,32 +168,52 @@ class Command(BaseCommand): for item in transition_domains: try: + should_update = True # Grab the current Domain. This ensures we are pointing towards the right place. current_domain = domains_dict[item.domain_name] # Based on the current domain, grab the right DomainInformation object. - current_domain_information = domain_informations_dict[current_domain.name] + if current_domain.name in domain_informations_dict: + current_domain_information = domain_informations_dict[current_domain.name] + current_domain_information.address_line1 = item.address_line + current_domain_information.city = item.city + current_domain_information.state_territory = item.state_territory + current_domain_information.zipcode = item.zipcode + + if debug: + logger.info(f"Updating {current_domain.name}...") - current_domain_information.address_line1 = item.address_line - current_domain_information.city = item.city - current_domain_information.state_territory = item.state_territory - current_domain_information.zipcode = item.zipcode - - if debug: - logger.info(f"Updating {current_domain.name}...") + else: + logger.info( + f"{TerminalColors.YELLOW}" + f"Domain {current_domain.name} was updated by a user. Cannot update." + f"{TerminalColors.ENDC}" + ) + should_update = False except Exception as err: logger.error(err) - di_failed_to_update.append(current_domain_information) + di_failed_to_update.append(item) else: - di_to_update.append(current_domain_information) + if should_update: + di_to_update.append(current_domain_information) + else: + # TODO either update to name for all, + # or have this filter to the right field + di_skipped.append(item) if len(di_failed_to_update) > 0: logger.error( + f"{TerminalColors.FAIL}" "Failed to update. An exception was encountered " f"on the following TransitionDomains: {[item for item in di_failed_to_update]}" + f"{TerminalColors.ENDC}" ) raise Exception("Failed to update DomainInformations") + + skipped_count = len(di_skipped) + if skipped_count > 0: + logger.info(f"Skipped updating {skipped_count} fields. User-supplied data exists") if not debug: logger.info( @@ -198,7 +237,13 @@ class Command(BaseCommand): "zipcode", ] - DomainInformation.objects.bulk_update(di_to_update, changed_fields) + batch_size = 1000 + # Create a Paginator object. Bulk_update on the full dataset + # is too memory intensive for our current app config, so we can chunk this data instead. + paginator = Paginator(di_to_update, batch_size) + for page_num in paginator.page_range: + page = paginator.page(page_num) + DomainInformation.objects.bulk_update(page.object_list, changed_fields) if not debug: logger.info( diff --git a/src/registrar/management/commands/utility/extra_transition_domain_helper.py b/src/registrar/management/commands/utility/extra_transition_domain_helper.py index 96cc550b3..be84e7681 100644 --- a/src/registrar/management/commands/utility/extra_transition_domain_helper.py +++ b/src/registrar/management/commands/utility/extra_transition_domain_helper.py @@ -10,7 +10,7 @@ import logging import os import sys from typing import Dict - +from django.core.paginator import Paginator from registrar.models.transition_domain import TransitionDomain from .epp_data_containers import ( @@ -850,7 +850,13 @@ class OrganizationDataLoader: "zipcode", ] - TransitionDomain.objects.bulk_update(update_list, changed_fields) + batch_size = 1000 + # Create a Paginator object. Bulk_update on the full dataset + # is too memory intensive for our current app config, so we can chunk this data instead. + paginator = Paginator(update_list, batch_size) + for page_num in paginator.page_range: + page = paginator.page(page_num) + TransitionDomain.objects.bulk_update(page.object_list, changed_fields) if not self.debug: logger.info(