From aa0a398c06509b21e7d0cdfda2f8afeadfc4f0c9 Mon Sep 17 00:00:00 2001 From: David Kennedy Date: Fri, 14 Jun 2024 12:17:08 -0400 Subject: [PATCH] updated unit tests for the associated changes in this PR --- src/registrar/admin.py | 13 +- .../management/commands/export_tables.py | 31 ++-- .../management/commands/import_tables.py | 11 +- .../tests/test_management_scripts.py | 148 +++++++++--------- 4 files changed, 101 insertions(+), 102 deletions(-) diff --git a/src/registrar/admin.py b/src/registrar/admin.py index 16aba7ec0..46b468eb2 100644 --- a/src/registrar/admin.py +++ b/src/registrar/admin.py @@ -2370,14 +2370,15 @@ class PublicContactResource(resources.ModelResource): class Meta: model = models.PublicContact - use_bulk = True - batch_size = 1000 - force_init_instance = True + # may want to consider these bulk options in future, so left in as comments + # use_bulk = True + # batch_size = 1000 + # force_init_instance = True def __init__(self): """Sets global variables for code tidyness""" super().__init__() - self.skip_epp_save=False + self.skip_epp_save = False def import_data( self, @@ -2387,10 +2388,10 @@ class PublicContactResource(resources.ModelResource): use_transactions=None, collect_failed_rows=False, rollback_on_validation_errors=False, - **kwargs + **kwargs, ): """Override import_data to set self.skip_epp_save if in kwargs""" - self.skip_epp_save = kwargs.get('skip_epp_save', False) + self.skip_epp_save = kwargs.get("skip_epp_save", False) return super().import_data( dataset, dry_run, diff --git a/src/registrar/management/commands/export_tables.py b/src/registrar/management/commands/export_tables.py index 8c2cf7dc4..e5c940a40 100644 --- a/src/registrar/management/commands/export_tables.py +++ b/src/registrar/management/commands/export_tables.py @@ -41,20 +41,21 @@ class Command(BaseCommand): with pyzipper.AESZipFile(zip_filename, "w", compression=pyzipper.ZIP_DEFLATED) as zipf: for table_name in table_names: - # Define the directory and the pattern - tmp_dir = 'tmp' - pattern = os.path.join(tmp_dir, f'{table_name}_*.csv') - zip_file_path = os.path.join(tmp_dir, 'exported_files.zip') + # Define the tmp directory and the file pattern + tmp_dir = "tmp" + pattern = f"{table_name}_" + zip_file_path = os.path.join(tmp_dir, "exported_files.zip") # Find all files that match the pattern - for file_path in glob.glob(pattern): + matching_files = [file for file in os.listdir(tmp_dir) if file.startswith(pattern)] + for file_path in matching_files: # Add each file to the zip archive zipf.write(file_path, os.path.basename(file_path)) - logger.info(f'Added {file_path} to {zip_file_path}') - + logger.info(f"Added {file_path} to {zip_file_path}") + # Remove the file after adding to zip os.remove(file_path) - logger.info(f'Removed {file_path}') + logger.info(f"Removed {file_path}") def export_table(self, table_name): """Export a given table to a csv file in the tmp directory""" @@ -71,7 +72,8 @@ class Command(BaseCommand): # Calculate the number of files needed num_files = math.ceil(total_rows / rows_per_file) - logger.info(f'splitting {table_name} into {num_files} files') + + logger.info(f"splitting {table_name} into {num_files} files") # Split the dataset and export each chunk to a separate file for i in range(num_files): @@ -82,16 +84,15 @@ class Command(BaseCommand): chunk = tablib.Dataset(headers=dataset.headers) for row in dataset[start_row:end_row]: chunk.append(row) - #chunk = dataset[start_row:end_row] # Export the chunk to a new file - filename = f'tmp/{table_name}_{i + 1}.csv' - with open(filename, 'w') as f: - f.write(chunk.export('csv')) + filename = f"tmp/{table_name}_{i + 1}.csv" + with open(filename, "w") as f: + f.write(chunk.export("csv")) - logger.info(f'Successfully exported {table_name} into {num_files} files.') + logger.info(f"Successfully exported {table_name} into {num_files} files.") except AttributeError as ae: - logger.error(f"Resource class {resourcename} not found in registrar.admin: {ae}") + logger.error(f"Resource class {resourcename} not found in registrar.admin") except Exception as e: logger.error(f"Failed to export {table_name}: {e}") diff --git a/src/registrar/management/commands/import_tables.py b/src/registrar/management/commands/import_tables.py index d04d0dbb2..abe26830f 100644 --- a/src/registrar/management/commands/import_tables.py +++ b/src/registrar/management/commands/import_tables.py @@ -18,7 +18,7 @@ class Command(BaseCommand): def add_arguments(self, parser): """Add command line arguments.""" - parser.add_argument('--skipEppSave', default=True, action=argparse.BooleanOptionalAction) + parser.add_argument("--skipEppSave", default=True, action=argparse.BooleanOptionalAction) def handle(self, **options): """Extracts CSV files from a zip archive and imports them into the respective tables""" @@ -74,19 +74,20 @@ class Command(BaseCommand): self.clean_table(table_name) # Define the directory and the pattern for csv filenames - tmp_dir = 'tmp' - pattern = os.path.join(tmp_dir, f'{table_name}_*.csv') + tmp_dir = "tmp" + #pattern = os.path.join(tmp_dir, f"{table_name}_*.csv") + pattern = f"{table_name}_" resourceclass = getattr(registrar.admin, resourcename) resource_instance = resourceclass() # Find all files that match the pattern - for csv_filename in glob.glob(pattern): + matching_files = [file for file in os.listdir(tmp_dir) if file.startswith(pattern)] + for csv_filename in matching_files: try: with open(csv_filename, "r") as csvfile: dataset = tablib.Dataset().load(csvfile.read(), format="csv") result = resource_instance.import_data(dataset, dry_run=False, skip_epp_save=self.skip_epp_save) - if result.has_errors(): logger.error(f"Errors occurred while importing {csv_filename}:") for row_error in result.row_errors(): diff --git a/src/registrar/tests/test_management_scripts.py b/src/registrar/tests/test_management_scripts.py index 500953f02..006a39231 100644 --- a/src/registrar/tests/test_management_scripts.py +++ b/src/registrar/tests/test_management_scripts.py @@ -7,6 +7,7 @@ from django.utils.module_loading import import_string import logging import pyzipper from registrar.management.commands.clean_tables import Command as CleanTablesCommand +from registrar.management.commands.export_tables import Command as ExportTablesCommand from registrar.models import ( User, Domain, @@ -869,88 +870,77 @@ class TestCleanTables(TestCase): self.logger_mock.error.assert_any_call("Error cleaning table DomainInformation: Some error") + + class TestExportTables(MockEppLib): """Test the export_tables script""" def setUp(self): + self.command = ExportTablesCommand() self.logger_patcher = patch("registrar.management.commands.export_tables.logger") self.logger_mock = self.logger_patcher.start() def tearDown(self): self.logger_patcher.stop() - @patch("registrar.management.commands.export_tables.os.makedirs") - @patch("registrar.management.commands.export_tables.os.path.exists") - @patch("registrar.management.commands.export_tables.os.remove") - @patch("registrar.management.commands.export_tables.pyzipper.AESZipFile") + @patch("os.makedirs") + @patch("os.path.exists") + @patch("os.remove") + @patch("pyzipper.AESZipFile") @patch("registrar.management.commands.export_tables.getattr") - @patch("builtins.open", new_callable=mock_open, read_data=b"mock_csv_data") - @patch("django.utils.translation.trans_real._translations", {}) - @patch("django.utils.translation.trans_real.translation") + @patch("builtins.open", new_callable=mock_open) + @patch("os.listdir") def test_handle( - self, mock_translation, mock_file, mock_getattr, mock_zipfile, mock_remove, mock_path_exists, mock_makedirs + self, mock_listdir, mock_open, mock_getattr, mock_zipfile, mock_remove, mock_path_exists, mock_makedirs ): """test that the handle method properly exports tables""" - with less_console_noise(): - # Mock os.makedirs to do nothing - mock_makedirs.return_value = None + # Mock os.makedirs to do nothing + mock_makedirs.return_value = None - # Mock os.path.exists to always return True - mock_path_exists.return_value = True + # Mock os.path.exists to always return True + mock_path_exists.return_value = True - # Mock the resource class and its export method - mock_resource_class = MagicMock() - mock_dataset = MagicMock() - mock_dataset.csv = b"mock_csv_data" - mock_resource_class().export.return_value = mock_dataset - mock_getattr.return_value = mock_resource_class + # Check that the export_table function was called for each table + table_names = [ + "User", "Contact", "Domain", "DomainRequest", "DomainInformation", "FederalAgency", + "UserDomainRole", "DraftDomain", "Website", "HostIp", "Host", "PublicContact", + ] - # Mock translation function to return a dummy translation object - mock_translation.return_value = MagicMock() + # Mock directory listing + mock_listdir.side_effect = lambda path: [f"{table}_1.csv" for table in table_names] - call_command("export_tables") + # Mock the resource class and its export method + mock_dataset = tablib.Dataset() + mock_dataset.headers = ["header1", "header2"] + mock_dataset.append(["row1_col1", "row1_col2"]) + mock_resource_class = MagicMock() + mock_resource_class().export.return_value = mock_dataset + mock_getattr.return_value = mock_resource_class - # Check that os.makedirs was called once to create the tmp directory - mock_makedirs.assert_called_once_with("tmp", exist_ok=True) + command_instance = ExportTablesCommand() + command_instance.handle() - # Check that the export_table function was called for each table - table_names = [ - "User", - "Contact", - "Domain", - "DomainRequest", - "DomainInformation", - "UserDomainRole", - "DraftDomain", - "Website", - "HostIp", - "Host", - "PublicContact", - ] + # Check that os.makedirs was called once to create the tmp directory + mock_makedirs.assert_called_once_with("tmp", exist_ok=True) - # Check that the CSV file was written - for table_name in table_names: - mock_file().write.assert_any_call(b"mock_csv_data") - # Check that os.path.exists was called - mock_path_exists.assert_any_call(f"tmp/{table_name}.csv") - # Check that os.remove was called - mock_remove.assert_any_call(f"tmp/{table_name}.csv") + # Check that the CSV file was written + for table_name in table_names: + # Check that os.remove was called + mock_remove.assert_any_call(f"{table_name}_1.csv") - # Check that the zipfile was created and files were added - mock_zipfile.assert_called_once_with("tmp/exported_tables.zip", "w", compression=pyzipper.ZIP_DEFLATED) - zipfile_instance = mock_zipfile.return_value.__enter__.return_value - for table_name in table_names: - zipfile_instance.write.assert_any_call(f"tmp/{table_name}.csv", f"{table_name}.csv") + # Check that the zipfile was created and files were added + mock_zipfile.assert_called_once_with("tmp/exported_tables.zip", "w", compression=pyzipper.ZIP_DEFLATED) + zipfile_instance = mock_zipfile.return_value.__enter__.return_value + for table_name in table_names: + zipfile_instance.write.assert_any_call(f"{table_name}_1.csv", f"{table_name}_1.csv") - # Verify logging for added files - for table_name in table_names: - self.logger_mock.info.assert_any_call( - f"Added tmp/{table_name}.csv to zip archive tmp/exported_tables.zip" - ) + # Verify logging for added files + for table_name in table_names: + self.logger_mock.info.assert_any_call(f"Added {table_name}_1.csv to tmp/exported_files.zip") - # Verify logging for removed files - for table_name in table_names: - self.logger_mock.info.assert_any_call(f"Removed temporary file tmp/{table_name}.csv") + # Verify logging for removed files + for table_name in table_names: + self.logger_mock.info.assert_any_call(f"Removed {table_name}_1.csv") @patch("registrar.management.commands.export_tables.getattr") def test_export_table_handles_missing_resource_class(self, mock_getattr): @@ -995,8 +985,10 @@ class TestImportTables(TestCase): @patch("registrar.management.commands.import_tables.logger") @patch("registrar.management.commands.import_tables.getattr") @patch("django.apps.apps.get_model") + @patch("os.listdir") def test_handle( self, + mock_listdir, mock_get_model, mock_getattr, mock_logger, @@ -1019,6 +1011,24 @@ class TestImportTables(TestCase): mock_zipfile_instance = mock_zipfile.return_value.__enter__.return_value mock_zipfile_instance.extractall.return_value = None + # Check that the import_table function was called for each table + table_names = [ + "User", + "Contact", + "Domain", + "DomainRequest", + "DomainInformation", + "UserDomainRole", + "DraftDomain", + "Website", + "HostIp", + "Host", + "PublicContact", + ] + + # Mock directory listing + mock_listdir.side_effect = lambda path: [f"{table}_1.csv" for table in table_names] + # Mock the CSV file content csv_content = b"mock_csv_data" @@ -1054,23 +1064,9 @@ class TestImportTables(TestCase): # Check that extractall was called once to extract the zip file contents mock_zipfile_instance.extractall.assert_called_once_with("tmp") - # Check that the import_table function was called for each table - table_names = [ - "User", - "Contact", - "Domain", - "DomainRequest", - "DomainInformation", - "UserDomainRole", - "DraftDomain", - "Website", - "HostIp", - "Host", - "PublicContact", - ] # Check that os.path.exists was called for each table for table_name in table_names: - mock_path_exists.assert_any_call(f"tmp/{table_name}.csv") + mock_path_exists.assert_any_call(f"{table_name}_1.csv") # Check that clean_tables is called for Contact mock_get_model.assert_any_call("registrar", "Contact") @@ -1079,18 +1075,18 @@ class TestImportTables(TestCase): # Check that logger.info was called for each successful import for table_name in table_names: - mock_logger.info.assert_any_call(f"Successfully imported tmp/{table_name}.csv into {table_name}") + mock_logger.info.assert_any_call(f"Successfully imported {table_name}_1.csv into {table_name}") # Check that logger.error was not called for resource class not found mock_logger.error.assert_not_called() # Check that os.remove was called for each CSV file for table_name in table_names: - mock_remove.assert_any_call(f"tmp/{table_name}.csv") + mock_remove.assert_any_call(f"{table_name}_1.csv") # Check that logger.info was called for each CSV file removal for table_name in table_names: - mock_logger.info.assert_any_call(f"Removed temporary file tmp/{table_name}.csv") + mock_logger.info.assert_any_call(f"Removed temporary file {table_name}_1.csv") @patch("registrar.management.commands.import_tables.logger") @patch("registrar.management.commands.import_tables.os.makedirs")