updated unit tests for the associated changes in this PR

David Kennedy 2024-06-14 12:17:08 -04:00
parent 552e434096
commit aa0a398c06
No known key found for this signature in database
GPG key ID: 6528A5386E66B96B
4 changed files with 101 additions and 102 deletions

View file

@@ -2370,14 +2370,15 @@ class PublicContactResource(resources.ModelResource):
     class Meta:
         model = models.PublicContact
-        use_bulk = True
-        batch_size = 1000
-        force_init_instance = True
+        # may want to consider these bulk options in future, so left in as comments
+        # use_bulk = True
+        # batch_size = 1000
+        # force_init_instance = True

     def __init__(self):
         """Sets global variables for code tidyness"""
         super().__init__()
-        self.skip_epp_save=False
+        self.skip_epp_save = False

     def import_data(
         self,
@@ -2387,10 +2388,10 @@ class PublicContactResource(resources.ModelResource):
         use_transactions=None,
         collect_failed_rows=False,
         rollback_on_validation_errors=False,
-        **kwargs
+        **kwargs,
     ):
         """Override import_data to set self.skip_epp_save if in kwargs"""
-        self.skip_epp_save = kwargs.get('skip_epp_save', False)
+        self.skip_epp_save = kwargs.get("skip_epp_save", False)
         return super().import_data(
             dataset,
             dry_run,
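
The override above threads a custom skip_epp_save flag through django-import-export's import_data entry point: anything extra rides along in **kwargs, and the import command (third file below) passes skip_epp_save explicitly. A minimal usage sketch, assuming a configured Django environment; the dataset headers are placeholders, not the resource's real field list:

    import tablib
    from registrar.admin import PublicContactResource

    # Placeholder headers; a real import must match the resource's declared fields
    dataset = tablib.Dataset(headers=["registry_id", "contact_type"])

    resource = PublicContactResource()
    # skip_epp_save is not part of the stock signature; it travels in **kwargs
    # and is stored on the instance before any rows are processed
    result = resource.import_data(dataset, dry_run=False, skip_epp_save=True)
    print(result.has_errors())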

View file

@@ -41,20 +41,21 @@ class Command(BaseCommand):
         with pyzipper.AESZipFile(zip_filename, "w", compression=pyzipper.ZIP_DEFLATED) as zipf:
             for table_name in table_names:
-                # Define the directory and the pattern
-                tmp_dir = 'tmp'
-                pattern = os.path.join(tmp_dir, f'{table_name}_*.csv')
-                zip_file_path = os.path.join(tmp_dir, 'exported_files.zip')
+                # Define the tmp directory and the file pattern
+                tmp_dir = "tmp"
+                pattern = f"{table_name}_"
+                zip_file_path = os.path.join(tmp_dir, "exported_files.zip")

                 # Find all files that match the pattern
-                for file_path in glob.glob(pattern):
+                matching_files = [file for file in os.listdir(tmp_dir) if file.startswith(pattern)]
+                for file_path in matching_files:
                     # Add each file to the zip archive
                     zipf.write(file_path, os.path.basename(file_path))
-                    logger.info(f'Added {file_path} to {zip_file_path}')
+                    logger.info(f"Added {file_path} to {zip_file_path}")

                     # Remove the file after adding to zip
                     os.remove(file_path)
-                    logger.info(f'Removed {file_path}')
+                    logger.info(f"Removed {file_path}")

     def export_table(self, table_name):
         """Export a given table to a csv file in the tmp directory"""
@@ -71,7 +72,8 @@ class Command(BaseCommand):
             # Calculate the number of files needed
             num_files = math.ceil(total_rows / rows_per_file)
-            logger.info(f'splitting {table_name} into {num_files} files')
+
+            logger.info(f"splitting {table_name} into {num_files} files")

             # Split the dataset and export each chunk to a separate file
             for i in range(num_files):
@@ -82,16 +84,15 @@ class Command(BaseCommand):
                 chunk = tablib.Dataset(headers=dataset.headers)
                 for row in dataset[start_row:end_row]:
                     chunk.append(row)
-                #chunk = dataset[start_row:end_row]

                 # Export the chunk to a new file
-                filename = f'tmp/{table_name}_{i + 1}.csv'
-                with open(filename, 'w') as f:
-                    f.write(chunk.export('csv'))
+                filename = f"tmp/{table_name}_{i + 1}.csv"
+                with open(filename, "w") as f:
+                    f.write(chunk.export("csv"))

-            logger.info(f'Successfully exported {table_name} into {num_files} files.')
+            logger.info(f"Successfully exported {table_name} into {num_files} files.")
         except AttributeError as ae:
-            logger.error(f"Resource class {resourcename} not found in registrar.admin: {ae}")
+            logger.error(f"Resource class {resourcename} not found in registrar.admin")
         except Exception as e:
             logger.error(f"Failed to export {table_name}: {e}")

View file

@@ -18,7 +18,7 @@ class Command(BaseCommand):
     def add_arguments(self, parser):
         """Add command line arguments."""
-        parser.add_argument('--skipEppSave', default=True, action=argparse.BooleanOptionalAction)
+        parser.add_argument("--skipEppSave", default=True, action=argparse.BooleanOptionalAction)

     def handle(self, **options):
         """Extracts CSV files from a zip archive and imports them into the respective tables"""
@@ -74,19 +74,20 @@ class Command(BaseCommand):
                 self.clean_table(table_name)

             # Define the directory and the pattern for csv filenames
-            tmp_dir = 'tmp'
-            pattern = os.path.join(tmp_dir, f'{table_name}_*.csv')
+            tmp_dir = "tmp"
+            #pattern = os.path.join(tmp_dir, f"{table_name}_*.csv")
+            pattern = f"{table_name}_"

             resourceclass = getattr(registrar.admin, resourcename)
             resource_instance = resourceclass()

             # Find all files that match the pattern
-            for csv_filename in glob.glob(pattern):
+            matching_files = [file for file in os.listdir(tmp_dir) if file.startswith(pattern)]
+            for csv_filename in matching_files:
                 try:
                     with open(csv_filename, "r") as csvfile:
                         dataset = tablib.Dataset().load(csvfile.read(), format="csv")
                     result = resource_instance.import_data(dataset, dry_run=False, skip_epp_save=self.skip_epp_save)
                     if result.has_errors():
                         logger.error(f"Errors occurred while importing {csv_filename}:")
                         for row_error in result.row_errors():
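
Because --skipEppSave is declared with argparse.BooleanOptionalAction (Python 3.9+) and default=True, argparse generates a matching --no-skipEppSave switch automatically. A standalone sketch of the flag behavior, outside Django's command plumbing:

    import argparse

    parser = argparse.ArgumentParser()
    # Mirrors the add_arguments() call in the hunk above
    parser.add_argument("--skipEppSave", default=True, action=argparse.BooleanOptionalAction)

    print(parser.parse_args([]).skipEppSave)                    # True (the default)
    print(parser.parse_args(["--skipEppSave"]).skipEppSave)     # True
    print(parser.parse_args(["--no-skipEppSave"]).skipEppSave)  # False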

View file

@@ -7,6 +7,7 @@ from django.utils.module_loading import import_string
 import logging
 import pyzipper
 from registrar.management.commands.clean_tables import Command as CleanTablesCommand
+from registrar.management.commands.export_tables import Command as ExportTablesCommand
 from registrar.models import (
     User,
     Domain,
@@ -869,88 +870,77 @@ class TestCleanTables(TestCase):
         self.logger_mock.error.assert_any_call("Error cleaning table DomainInformation: Some error")


 class TestExportTables(MockEppLib):
     """Test the export_tables script"""

     def setUp(self):
+        self.command = ExportTablesCommand()
         self.logger_patcher = patch("registrar.management.commands.export_tables.logger")
         self.logger_mock = self.logger_patcher.start()

     def tearDown(self):
         self.logger_patcher.stop()

-    @patch("registrar.management.commands.export_tables.os.makedirs")
-    @patch("registrar.management.commands.export_tables.os.path.exists")
-    @patch("registrar.management.commands.export_tables.os.remove")
-    @patch("registrar.management.commands.export_tables.pyzipper.AESZipFile")
+    @patch("os.makedirs")
+    @patch("os.path.exists")
+    @patch("os.remove")
+    @patch("pyzipper.AESZipFile")
     @patch("registrar.management.commands.export_tables.getattr")
-    @patch("builtins.open", new_callable=mock_open, read_data=b"mock_csv_data")
-    @patch("django.utils.translation.trans_real._translations", {})
-    @patch("django.utils.translation.trans_real.translation")
+    @patch("builtins.open", new_callable=mock_open)
+    @patch("os.listdir")
     def test_handle(
-        self, mock_translation, mock_file, mock_getattr, mock_zipfile, mock_remove, mock_path_exists, mock_makedirs
+        self, mock_listdir, mock_open, mock_getattr, mock_zipfile, mock_remove, mock_path_exists, mock_makedirs
     ):
         """test that the handle method properly exports tables"""
-        with less_console_noise():
-            # Mock os.makedirs to do nothing
-            mock_makedirs.return_value = None
+        # Mock os.makedirs to do nothing
+        mock_makedirs.return_value = None

-            # Mock os.path.exists to always return True
-            mock_path_exists.return_value = True
+        # Mock os.path.exists to always return True
+        mock_path_exists.return_value = True

-            # Mock the resource class and its export method
-            mock_resource_class = MagicMock()
-            mock_dataset = MagicMock()
-            mock_dataset.csv = b"mock_csv_data"
-            mock_resource_class().export.return_value = mock_dataset
-            mock_getattr.return_value = mock_resource_class
+        # Check that the export_table function was called for each table
+        table_names = [
+            "User", "Contact", "Domain", "DomainRequest", "DomainInformation", "FederalAgency",
+            "UserDomainRole", "DraftDomain", "Website", "HostIp", "Host", "PublicContact",
+        ]

-            # Mock translation function to return a dummy translation object
-            mock_translation.return_value = MagicMock()
+        # Mock directory listing
+        mock_listdir.side_effect = lambda path: [f"{table}_1.csv" for table in table_names]

-            call_command("export_tables")
+        # Mock the resource class and its export method
+        mock_dataset = tablib.Dataset()
+        mock_dataset.headers = ["header1", "header2"]
+        mock_dataset.append(["row1_col1", "row1_col2"])
+        mock_resource_class = MagicMock()
+        mock_resource_class().export.return_value = mock_dataset
+        mock_getattr.return_value = mock_resource_class

-            # Check that os.makedirs was called once to create the tmp directory
-            mock_makedirs.assert_called_once_with("tmp", exist_ok=True)
+        command_instance = ExportTablesCommand()
+        command_instance.handle()

-            # Check that the export_table function was called for each table
-            table_names = [
-                "User",
-                "Contact",
-                "Domain",
-                "DomainRequest",
-                "DomainInformation",
-                "UserDomainRole",
-                "DraftDomain",
-                "Website",
-                "HostIp",
-                "Host",
-                "PublicContact",
-            ]
+        # Check that os.makedirs was called once to create the tmp directory
+        mock_makedirs.assert_called_once_with("tmp", exist_ok=True)

-            # Check that the CSV file was written
-            for table_name in table_names:
-                mock_file().write.assert_any_call(b"mock_csv_data")
-                # Check that os.path.exists was called
-                mock_path_exists.assert_any_call(f"tmp/{table_name}.csv")
-                # Check that os.remove was called
-                mock_remove.assert_any_call(f"tmp/{table_name}.csv")
+        # Check that the CSV file was written
+        for table_name in table_names:
+            # Check that os.remove was called
+            mock_remove.assert_any_call(f"{table_name}_1.csv")

-            # Check that the zipfile was created and files were added
-            mock_zipfile.assert_called_once_with("tmp/exported_tables.zip", "w", compression=pyzipper.ZIP_DEFLATED)
-            zipfile_instance = mock_zipfile.return_value.__enter__.return_value
-            for table_name in table_names:
-                zipfile_instance.write.assert_any_call(f"tmp/{table_name}.csv", f"{table_name}.csv")
+        # Check that the zipfile was created and files were added
+        mock_zipfile.assert_called_once_with("tmp/exported_tables.zip", "w", compression=pyzipper.ZIP_DEFLATED)
+        zipfile_instance = mock_zipfile.return_value.__enter__.return_value
+        for table_name in table_names:
+            zipfile_instance.write.assert_any_call(f"{table_name}_1.csv", f"{table_name}_1.csv")

-            # Verify logging for added files
-            for table_name in table_names:
-                self.logger_mock.info.assert_any_call(
-                    f"Added tmp/{table_name}.csv to zip archive tmp/exported_tables.zip"
-                )
+        # Verify logging for added files
+        for table_name in table_names:
+            self.logger_mock.info.assert_any_call(f"Added {table_name}_1.csv to tmp/exported_files.zip")

-            # Verify logging for removed files
-            for table_name in table_names:
-                self.logger_mock.info.assert_any_call(f"Removed temporary file tmp/{table_name}.csv")
+        # Verify logging for removed files
+        for table_name in table_names:
+            self.logger_mock.info.assert_any_call(f"Removed {table_name}_1.csv")

     @patch("registrar.management.commands.export_tables.getattr")
     def test_export_table_handles_missing_resource_class(self, mock_getattr):
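
Note that the patch targets in this test moved from module-qualified paths such as registrar.management.commands.export_tables.os.makedirs to plain os.makedirs. Both spellings should end up replacing the same attribute, because the command module's `import os` binds a reference to the one os module object, and os.makedirs is resolved on it at call time. A minimal sketch of the idea; make_tmp_dir is a hypothetical stand-in for code inside the command:

    import os
    from unittest.mock import patch

    def make_tmp_dir():
        # Stand-in for the management command's directory setup
        os.makedirs("tmp", exist_ok=True)

    with patch("os.makedirs") as mock_makedirs:
        make_tmp_dir()  # hits the mock; nothing is created on disk
        mock_makedirs.assert_called_once_with("tmp", exist_ok=True)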
@@ -995,8 +985,10 @@ class TestImportTables(TestCase):
     @patch("registrar.management.commands.import_tables.logger")
     @patch("registrar.management.commands.import_tables.getattr")
     @patch("django.apps.apps.get_model")
+    @patch("os.listdir")
     def test_handle(
         self,
+        mock_listdir,
         mock_get_model,
         mock_getattr,
         mock_logger,
@@ -1019,6 +1011,24 @@ class TestImportTables(TestCase):
         mock_zipfile_instance = mock_zipfile.return_value.__enter__.return_value
         mock_zipfile_instance.extractall.return_value = None

+        # Check that the import_table function was called for each table
+        table_names = [
+            "User",
+            "Contact",
+            "Domain",
+            "DomainRequest",
+            "DomainInformation",
+            "UserDomainRole",
+            "DraftDomain",
+            "Website",
+            "HostIp",
+            "Host",
+            "PublicContact",
+        ]
+
+        # Mock directory listing
+        mock_listdir.side_effect = lambda path: [f"{table}_1.csv" for table in table_names]
+
         # Mock the CSV file content
         csv_content = b"mock_csv_data"
@@ -1054,23 +1064,9 @@ class TestImportTables(TestCase):
         # Check that extractall was called once to extract the zip file contents
         mock_zipfile_instance.extractall.assert_called_once_with("tmp")

-        # Check that the import_table function was called for each table
-        table_names = [
-            "User",
-            "Contact",
-            "Domain",
-            "DomainRequest",
-            "DomainInformation",
-            "UserDomainRole",
-            "DraftDomain",
-            "Website",
-            "HostIp",
-            "Host",
-            "PublicContact",
-        ]
-
         # Check that os.path.exists was called for each table
         for table_name in table_names:
-            mock_path_exists.assert_any_call(f"tmp/{table_name}.csv")
+            mock_path_exists.assert_any_call(f"{table_name}_1.csv")

         # Check that clean_tables is called for Contact
         mock_get_model.assert_any_call("registrar", "Contact")
@@ -1079,18 +1075,18 @@ class TestImportTables(TestCase):
         # Check that logger.info was called for each successful import
         for table_name in table_names:
-            mock_logger.info.assert_any_call(f"Successfully imported tmp/{table_name}.csv into {table_name}")
+            mock_logger.info.assert_any_call(f"Successfully imported {table_name}_1.csv into {table_name}")

         # Check that logger.error was not called for resource class not found
         mock_logger.error.assert_not_called()

         # Check that os.remove was called for each CSV file
         for table_name in table_names:
-            mock_remove.assert_any_call(f"tmp/{table_name}.csv")
+            mock_remove.assert_any_call(f"{table_name}_1.csv")

         # Check that logger.info was called for each CSV file removal
         for table_name in table_names:
-            mock_logger.info.assert_any_call(f"Removed temporary file tmp/{table_name}.csv")
+            mock_logger.info.assert_any_call(f"Removed temporary file {table_name}_1.csv")

     @patch("registrar.management.commands.import_tables.logger")
     @patch("registrar.management.commands.import_tables.os.makedirs")