diff --git a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java index ff52095e2..8b575e124 100644 --- a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java +++ b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java @@ -14,13 +14,16 @@ package google.registry.persistence.transaction; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertWithMessage; import static org.testcontainers.containers.PostgreSQLContainer.POSTGRESQL_PORT; import com.google.common.base.Charsets; +import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import com.google.common.collect.Streams; import com.google.common.io.Resources; import google.registry.persistence.HibernateSchemaExporter; import google.registry.persistence.NomulusPostgreSql; @@ -40,6 +43,8 @@ import java.sql.Statement; import java.util.HashMap; import java.util.Optional; import java.util.Properties; +import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.persistence.EntityManagerFactory; import org.hibernate.cfg.Environment; import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor; @@ -77,8 +82,15 @@ abstract class JpaTransactionManagerRule extends ExternalResource { private static final HibernateSchemaExporter exporter = HibernateSchemaExporter.create( database.getJdbcUrl(), database.getUsername(), database.getPassword()); - private EntityManagerFactory emf; + // The EntityManagerFactory for the current schema in the test db. This instance may be + // reused between test methods if the requested schema remains the same. + private static EntityManagerFactory emf; + // Hash of the ORM entity names in the current schema in the test db. + private static int emfEntityHash; + private JpaTransactionManager cachedTm; + // Hash of the ORM entity names requested by this rule instance. + private int entityHash; protected JpaTransactionManagerRule( Clock clock, @@ -89,6 +101,7 @@ abstract class JpaTransactionManagerRule extends ExternalResource { this.initScriptPath = initScriptPath; this.extraEntityClasses = extraEntityClasses; this.userProperties = userProperties; + this.entityHash = getOrmEntityHash(initScriptPath, extraEntityClasses); } private static JdbcDatabaseContainer create() { @@ -99,16 +112,34 @@ abstract class JpaTransactionManagerRule extends ExternalResource { return container; } - @Override - public void before() throws Exception { - executeSql(POSTGRES_DB_NAME, readSqlInClassPath(DB_CLEANUP_SQL_PATH)); - initScriptPath.ifPresent(path -> executeSql(POSTGRES_DB_NAME, readSqlInClassPath(path))); + private static int getOrmEntityHash( + Optional initScriptPath, ImmutableList extraEntityClasses) { + return Streams.concat( + Stream.of(initScriptPath.orElse("")), + extraEntityClasses.stream().map(Class::getCanonicalName)) + .sorted() + .collect(Collectors.toList()) + .hashCode(); + } + + /** + * Drops and recreates the 'public' schema and all tables, then creates a new {@link + * EntityManagerFactory} and save it in {@link #emf}. + */ + private void recreateSchema() throws Exception { + if (emf != null) { + emf.close(); + emf = null; + emfEntityHash = 0; + assertReasonableNumDbConnections(); + } + executeSql(readSqlInClassPath(DB_CLEANUP_SQL_PATH)); + initScriptPath.ifPresent(path -> executeSql(readSqlInClassPath(path))); if (!extraEntityClasses.isEmpty()) { File tempSqlFile = File.createTempFile("tempSqlFile", ".sql"); tempSqlFile.deleteOnExit(); exporter.export(extraEntityClasses, tempSqlFile); executeSql( - POSTGRES_DB_NAME, new String(Files.readAllBytes(tempSqlFile.toPath()), StandardCharsets.UTF_8)); } @@ -125,11 +156,22 @@ abstract class JpaTransactionManagerRule extends ExternalResource { assertReasonableNumDbConnections(); emf = createEntityManagerFactory( - getJdbcUrlFor(POSTGRES_DB_NAME), + getJdbcUrl(), database.getUsername(), database.getPassword(), properties, extraEntityClasses); + emfEntityHash = entityHash; + } + + @Override + public void before() throws Exception { + if (entityHash == emfEntityHash) { + checkState(emf != null, "Missing EntityManagerFactory."); + resetTablesAndSequences(); + } else { + recreateSchema(); + } JpaTransactionManagerImpl txnManager = new JpaTransactionManagerImpl(emf, clock); cachedTm = TransactionManagerFactory.jpaTm(); TransactionManagerFactory.setJpaTm(txnManager); @@ -138,12 +180,26 @@ abstract class JpaTransactionManagerRule extends ExternalResource { @Override public void after() { TransactionManagerFactory.setJpaTm(cachedTm); - if (emf != null) { - emf.close(); - emf = null; - } cachedTm = null; - assertReasonableNumDbConnections(); + } + + private void resetTablesAndSequences() { + try (Connection conn = createConnection(); + Statement statement = conn.createStatement()) { + ResultSet rs = + statement.executeQuery( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"); + ImmutableList.Builder tableNames = new ImmutableList.Builder<>(); + while (rs.next()) { + tableNames.add('"' + rs.getString(1) + '"'); + } + String sql = + String.format( + "TRUNCATE %s RESTART IDENTITY CASCADE", Joiner.on(',').join(tableNames.build())); + executeSql(sql); + } catch (Exception e) { + throw new RuntimeException(e); + } } /** @@ -155,7 +211,7 @@ abstract class JpaTransactionManagerRule extends ExternalResource { * is less than 5 to reduce flakiness. */ private void assertReasonableNumDbConnections() { - try (Connection conn = createConnection(POSTGRES_DB_NAME); + try (Connection conn = createConnection(); Statement statement = conn.createStatement()) { // Note: Since we use the admin user (returned by container's getUserName() method) // in tests, we need to filter connections by database name and/or backend type to filter out @@ -184,8 +240,8 @@ abstract class JpaTransactionManagerRule extends ExternalResource { } } - private void executeSql(String dbName, String sqlScript) { - try (Connection conn = createConnection(dbName); + private static void executeSql(String sqlScript) { + try (Connection conn = createConnection(); Statement statement = conn.createStatement()) { statement.execute(sqlScript); } catch (Exception e) { @@ -193,24 +249,24 @@ abstract class JpaTransactionManagerRule extends ExternalResource { } } - private static String getJdbcUrlFor(String dbName) { + private static String getJdbcUrl() { // Disable Postgres driver use of java.util.logging to reduce noise at startup time return "jdbc:postgresql://" + database.getContainerIpAddress() + ":" + database.getMappedPort(POSTGRESQL_PORT) + "/" - + dbName + + POSTGRES_DB_NAME + "?loggerLevel=OFF"; } - private static Connection createConnection(String dbName) { + private static Connection createConnection() { final Properties info = new Properties(); info.put("user", database.getUsername()); info.put("password", database.getPassword()); final Driver jdbcDriverInstance = database.getJdbcDriverInstance(); try { - return jdbcDriverInstance.connect(getJdbcUrlFor(dbName), info); + return jdbcDriverInstance.connect(getJdbcUrl(), info); } catch (SQLException e) { throw new RuntimeException(e); } @@ -227,6 +283,8 @@ abstract class JpaTransactionManagerRule extends ExternalResource { properties.put(Environment.URL, jdbcUrl); properties.put(Environment.USER, username); properties.put(Environment.PASS, password); + // Tell Postgresql JDBC driver to expect out-of-band schema change. + properties.put("hibernate.hikari.dataSource.autosave", "conservative"); ParsedPersistenceXmlDescriptor descriptor = PersistenceXmlUtility.getParsedPersistenceXmlDescriptor();