diff --git a/src/registrar/management/commands/create_federal_portfolio.py b/src/registrar/management/commands/create_federal_portfolio.py index 2b4f64d37..9fb8fee02 100644 --- a/src/registrar/management/commands/create_federal_portfolio.py +++ b/src/registrar/management/commands/create_federal_portfolio.py @@ -37,12 +37,14 @@ class Command(BaseCommand): logger.info(f"{TerminalColors.BOLD}{no_changes_message}{TerminalColors.ENDC}") def has_changes(self) -> bool: - num_changes = [len(self.create), len(self.update), len(self.skip), len(self.fail)] - return any([num_change > 0 for num_change in num_changes]) + changes = [self.create, self.update, self.skip, self.fail] + return any([change for change in changes if change]) def bulk_create(self): try: - ScriptDataHelper.bulk_create_fields(self.model_class, self.create, quiet=True) + res = ScriptDataHelper.bulk_create_fields(self.model_class, self.create, return_created=True, quiet=True) + self.create = res + return res except Exception as err: # In this case, just swap the fail and add lists self.fail = self.create.copy() @@ -51,7 +53,9 @@ class Command(BaseCommand): def bulk_update(self, fields_to_update): try: - ScriptDataHelper.bulk_update_fields(self.model_class, self.update, fields_to_update, quiet=True) + res = ScriptDataHelper.bulk_update_fields(self.model_class, self.update, fields_to_update, quiet=True) + self.update = res + return res except Exception as err: # In this case, just swap the fail and update lists self.fail = self.update.copy() @@ -167,12 +171,11 @@ class Command(BaseCommand): organization_name__in=agencies.values_list("agency", flat=True), organization_name__isnull=False ) existing_portfolios_set = {normalize_string(p.organization_name): p for p in existing_portfolios} - agencies_set = {normalize_string(agency.agency): agency for agency in agencies} - for federal_agency in agencies_set.values(): + agencies_dict = {normalize_string(agency.agency): agency for agency in agencies} + for federal_agency in agencies_dict.values(): portfolio_name = normalize_string(federal_agency.agency, lowercase=False) portfolio = existing_portfolios_set.get(portfolio_name, None) - new_portfolio = portfolio is None - if new_portfolio: + if portfolio is None: portfolio = Portfolio( organization_name=portfolio_name, federal_agency=federal_agency, @@ -183,36 +186,30 @@ class Command(BaseCommand): ) self.portfolio_changes.create.append(portfolio) logger.info(f"{TerminalColors.OKGREEN}Created portfolio '{portfolio}'.{TerminalColors.ENDC}") - - if skip_existing_portfolios and not new_portfolio: + elif skip_existing_portfolios: message = f"Portfolio '{portfolio}' already exists. Skipped." logger.info(f"{TerminalColors.YELLOW}{message}{TerminalColors.ENDC}") - if portfolio: - self.portfolio_changes.skip.append(portfolio) + self.portfolio_changes.skip.append(portfolio) # Create portfolios - self.portfolio_changes.bulk_create() + portfolios_to_use = self.portfolio_changes.bulk_create() # After create, get the list of all portfolios to use portfolios_to_use = set(self.portfolio_changes.create) if not skip_existing_portfolios: portfolios_to_use.update(set(existing_portfolios)) + + portfolios_to_use_dict = {normalize_string(p.organization_name): p for p in portfolios_to_use} # == Handle suborganizations == # - for portfolio in portfolios_to_use: - created_suborgs = [] - org_name = normalize_string(portfolio.organization_name) - federal_agency = agencies_set.get(org_name) - if portfolio: - created_suborgs = self.create_suborganizations(portfolio, federal_agency) - Suborganization.objects.bulk_create(created_suborgs) - self.suborganization_changes.create.extend(created_suborgs) + created_suborgs = self.create_suborganizations(portfolios_to_use_dict, agencies_dict) + if created_suborgs: + self.suborganization_changes.create.extend(created_suborgs.values()) + self.suborganization_changes.bulk_create() # == Handle domains, requests, and managers == # - for portfolio in portfolios_to_use: - org_name = normalize_string(portfolio.organization_name) - federal_agency = agencies_set.get(org_name) - + for portfolio_org_name, portfolio in portfolios_to_use_dict.items(): + federal_agency = agencies_dict.get(portfolio_org_name) if parse_domains: self.handle_portfolio_domains(portfolio, federal_agency) @@ -318,85 +315,66 @@ class Command(BaseCommand): display_as_str=True, ) - def create_suborganizations(self, portfolio, federal_agency): + def create_suborganizations(self, portfolio_dict, agency_dict): """Create Suborganizations tied to the given portfolio based on DomainInformation objects""" - base_filter = Q( - organization_name__isnull=False, - ) & ~Q(organization_name__iexact=F("portfolio__organization_name")) + created_suborgs = {} - domains = federal_agency.domaininformation_set.filter(base_filter) - requests = federal_agency.domainrequest_set.filter(base_filter) - existing_orgs = Suborganization.objects.all() + portfolios = portfolio_dict.values() + agencies = agency_dict.values() + existing_suborgs = Suborganization.objects.filter(portfolio__in=portfolios) + suborg_dict = {normalize_string(org.name): org for org in existing_suborgs} + + domains = DomainInformation.objects.filter( + # Org name must not be null, and must not be the portfolio name + Q( + organization_name__isnull=False, + ) & ~Q(organization_name__iexact=F("portfolio__organization_name")), + # Only get relevant data to the agency/portfolio we are targeting + Q(federal_agency__in=agencies) | Q(portfolio__in=portfolios), + ) + requests = DomainRequest.objects.filter( + # Org name must not be null, and must not be the portfolio name + Q( + organization_name__isnull=False, + ) & ~Q(organization_name__iexact=F("portfolio__organization_name")), + # Only get relevant data to the agency/portfolio we are targeting + Q(federal_agency__in=agencies) | Q(portfolio__in=portfolios), + ) # Normalize all suborg names so we don't add duplicate data unintentionally. - # Get all suborg names that we COULD add - org_names_normalized = {} - for domain in domains: - org_name = normalize_string(domain.organization_name) - if org_name not in org_names_normalized: - org_names_normalized[org_name] = domain.organization_name + for portfolio_name, portfolio in portfolio_dict.items(): + for domain in domains: + if normalize_string(domain.federal_agency.agency) != portfolio_name: + continue - # Get all suborg names that presently exist - existing_org_names_normalized = {} - for org in existing_orgs: - org_name = normalize_string(org.name) - if org_name not in existing_org_names_normalized: - existing_org_names_normalized[org_name] = org.name - - # Subtract existing names from ones we COULD add. - # We don't want to add existing names. - new_org_names = {} - for norm_name, name in org_names_normalized.items(): - if norm_name not in existing_org_names_normalized: - new_org_names[norm_name] = name - - # Add new suborgs assuming they aren't duplicates and don't already exist in the db. - created_suborgs = [] - for norm_name, name in new_org_names.items(): - norm_portfolio_name = normalize_string(portfolio.organization_name) - if norm_name != norm_portfolio_name: - suborg = Suborganization(name=name, portfolio=portfolio) - created_suborgs.append(suborg) + org_name = domain.organization_name + norm_org_name = normalize_string(domain.organization_name) + # If the suborg already exists or if we've already added it, don't add it again. + if norm_org_name not in suborg_dict and norm_org_name not in created_suborgs: + suborg = Suborganization(name=org_name, portfolio=portfolio) + created_suborgs[norm_org_name] = suborg # Add location information to suborgs. # This can vary per domain and request, so this is a seperate step. - # First: Filter domains and requests by those that have data - valid_domains = domains.filter( - city__isnull=False, - state_territory__isnull=False, - portfolio__isnull=False, - sub_organization__isnull=False, - ) - valid_requests = requests.filter( - ( - Q(city__isnull=False, state_territory__isnull=False) - | Q(suborganization_city__isnull=False, suborganization_state_territory__isnull=False) - ), - portfolio__isnull=False, - sub_organization__isnull=False, - ) - # Second: Group domains and requests by normalized organization name. - # This means that later down the line we can account for "duplicate" org names. + # First: Group domains and requests by normalized organization name. domains_dict = {} requests_dict = {} - for domain in valid_domains: - print(f"what is the org name? {domain.organization_name}") + for domain in domains: normalized_name = normalize_string(domain.organization_name) domains_dict.setdefault(normalized_name, []).append(domain) - for request in valid_requests: - print(f"what is the org name for requests? {request.organization_name}") + for request in requests: normalized_name = normalize_string(request.organization_name) requests_dict.setdefault(normalized_name, []).append(request) - # Fourth: Process each suborg to add city / state territory info - for suborg in created_suborgs: - self.set_suborganization_location(suborg, domains_dict, requests_dict) - + # Second: Process each suborg to add city / state territory info + for norm_name, suborg in created_suborgs.items(): + self.set_suborganization_location(norm_name, suborg, domains_dict, requests_dict) + return created_suborgs - def set_suborganization_location(self, suborg, domains_dict, requests_dict): + def set_suborganization_location(self, normalized_suborg_name, suborg, domains_dict, requests_dict): """Updates a single suborganization's location data if valid. Args: @@ -407,18 +385,19 @@ class Command(BaseCommand): Priority matches parent method. Updates are skipped if location data conflicts between multiple records of the same type. """ - normalized_suborg_name = normalize_string(suborg.name) domains = domains_dict.get(normalized_suborg_name, []) requests = requests_dict.get(normalized_suborg_name, []) - print(f"domains: {domains}") - print(f"requests: {requests}") # Try to get matching domain info domain = None if domains: reference = domains[0] use_location_for_domain = all( - d.city == reference.city and d.state_territory == reference.state_territory for d in domains + d.city + and d.state_territory + and d.city == reference.city + and d.state_territory == reference.state_territory + for d in domains ) if use_location_for_domain: domain = reference diff --git a/src/registrar/management/commands/utility/terminal_helper.py b/src/registrar/management/commands/utility/terminal_helper.py index 85af1f9e5..e8ec5e41f 100644 --- a/src/registrar/management/commands/utility/terminal_helper.py +++ b/src/registrar/management/commands/utility/terminal_helper.py @@ -52,6 +52,8 @@ class ScriptDataHelper: Usage: ScriptDataHelper.bulk_update_fields(Domain, page.object_list, ["first_ready"]) + + Returns: A queryset of the updated objets """ if not quiet: logger.info(f"{TerminalColors.YELLOW} Bulk updating fields... {TerminalColors.ENDC}") @@ -63,7 +65,7 @@ class ScriptDataHelper: model_class.objects.bulk_update(page.object_list, fields_to_update) @staticmethod - def bulk_create_fields(model_class, update_list, batch_size=1000, quiet=False): + def bulk_create_fields(model_class, update_list, batch_size=1000, return_created=False, quiet=False): """ This function performs a bulk create operation on a specified Django model class in batches. It uses Django's Paginator to handle large datasets in a memory-efficient manner. @@ -80,13 +82,22 @@ class ScriptDataHelper: or large field values, you may need to decrease this value to prevent out-of-memory errors. Usage: ScriptDataHelper.bulk_add_fields(Domain, page.object_list) + + Returns: A queryset of the added objects """ if not quiet: logger.info(f"{TerminalColors.YELLOW} Bulk adding fields... {TerminalColors.ENDC}") + + created_objs = [] paginator = Paginator(update_list, batch_size) for page_num in paginator.page_range: page = paginator.page(page_num) - model_class.objects.bulk_create(page.object_list) + all_created = model_class.objects.bulk_create(page.object_list) + if return_created: + created_objs.extend([created.id for created in all_created]) + if return_created: + return model_class.objects.filter(id__in=created_objs) + return None class PopulateScriptTemplate(ABC):