diff --git a/core/src/main/java/google/registry/persistence/HibernateSchemaExporter.java b/core/src/main/java/google/registry/persistence/HibernateSchemaExporter.java index 50ad9485e..0753e8073 100644 --- a/core/src/main/java/google/registry/persistence/HibernateSchemaExporter.java +++ b/core/src/main/java/google/registry/persistence/HibernateSchemaExporter.java @@ -21,15 +21,12 @@ import com.google.common.collect.Maps; import java.io.File; import java.util.EnumSet; import java.util.Map; -import java.util.Properties; import java.util.stream.Stream; import javax.persistence.AttributeConverter; import org.hibernate.boot.MetadataSources; import org.hibernate.boot.registry.StandardServiceRegistry; import org.hibernate.boot.registry.StandardServiceRegistryBuilder; import org.hibernate.cfg.Environment; -import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor; -import org.hibernate.jpa.boot.internal.PersistenceXmlParser; import org.hibernate.tool.hbm2ddl.SchemaExport; import org.hibernate.tool.schema.TargetType; @@ -82,25 +79,7 @@ public class HibernateSchemaExporter { } private ImmutableList findAllConverters() { - ParsedPersistenceXmlDescriptor descriptor = - PersistenceXmlParser.locatePersistenceUnits(new Properties()).stream() - .filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName())) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format( - "Could not find persistence unit with name %s", - PersistenceModule.PERSISTENCE_UNIT_NAME))); - return descriptor.getManagedClassNames().stream() - .map( - className -> { - try { - return Class.forName(className); - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - }) + return PersistenceXmlUtility.getManagedClasses().stream() .filter(AttributeConverter.class::isAssignableFrom) .collect(toImmutableList()); } diff --git a/core/src/main/java/google/registry/persistence/PersistenceXmlUtility.java b/core/src/main/java/google/registry/persistence/PersistenceXmlUtility.java new file mode 100644 index 000000000..8bd3f90d7 --- /dev/null +++ b/core/src/main/java/google/registry/persistence/PersistenceXmlUtility.java @@ -0,0 +1,59 @@ +// Copyright 2019 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.persistence; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.common.collect.ImmutableList; +import java.util.Properties; +import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor; +import org.hibernate.jpa.boot.internal.PersistenceXmlParser; + +/** Utility class that provides methods to manipulate persistence.xml file. */ +public class PersistenceXmlUtility { + private PersistenceXmlUtility() {} + + /** + * Returns the {@link ParsedPersistenceXmlDescriptor} instance constructed from persistence.xml. + */ + public static ParsedPersistenceXmlDescriptor getParsedPersistenceXmlDescriptor() { + return PersistenceXmlParser.locatePersistenceUnits(new Properties()).stream() + .filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName())) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format( + "Could not find persistence unit with name %s", + PersistenceModule.PERSISTENCE_UNIT_NAME))); + } + + /** Returns all managed classes defined in persistence.xml. */ + public static ImmutableList getManagedClasses() { + return getParsedPersistenceXmlDescriptor().getManagedClassNames().stream() + .map( + className -> { + try { + return Class.forName(className); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException( + String.format( + "Could not load class with name %s present in persistence.xml", className), + e); + } + }) + .collect(toImmutableList()); + } +} diff --git a/core/src/main/java/google/registry/tools/GenerateSqlSchemaCommand.java b/core/src/main/java/google/registry/tools/GenerateSqlSchemaCommand.java index 6e39fd5e5..04a9beebc 100644 --- a/core/src/main/java/google/registry/tools/GenerateSqlSchemaCommand.java +++ b/core/src/main/java/google/registry/tools/GenerateSqlSchemaCommand.java @@ -19,23 +19,11 @@ import static java.nio.charset.StandardCharsets.UTF_8; import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameters; import com.google.common.annotations.VisibleForTesting; -import google.registry.persistence.NomulusNamingStrategy; -import google.registry.persistence.NomulusPostgreSQLDialect; -import google.registry.persistence.PersistenceModule; +import google.registry.persistence.HibernateSchemaExporter; +import google.registry.persistence.PersistenceXmlUtility; +import java.io.File; import java.io.IOException; import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.hibernate.boot.MetadataSources; -import org.hibernate.boot.registry.StandardServiceRegistryBuilder; -import org.hibernate.cfg.Environment; -import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor; -import org.hibernate.jpa.boot.internal.PersistenceXmlParser; -import org.hibernate.tool.hbm2ddl.SchemaExport; -import org.hibernate.tool.schema.TargetType; import org.testcontainers.containers.PostgreSQLContainer; /** @@ -47,6 +35,9 @@ import org.testcontainers.containers.PostgreSQLContainer; */ @Parameters(separators = " =", commandDescription = "Generate PostgreSQL schema.") public class GenerateSqlSchemaCommand implements Command { + private static final String DB_NAME = "postgres"; + private static final String DB_USERNAME = "postgres"; + private static final String DB_PASSWORD = "domain-registry"; @VisibleForTesting public static final String DB_OPTIONS_CLASH = @@ -91,10 +82,11 @@ public class GenerateSqlSchemaCommand implements Command { } // Start the container and store the address information. - postgresContainer = new PostgreSQLContainer() - .withDatabaseName("postgres") - .withUsername("postgres") - .withPassword("domain-registry"); + postgresContainer = + new PostgreSQLContainer() + .withDatabaseName(DB_NAME) + .withUsername(DB_USERNAME) + .withPassword(DB_PASSWORD); postgresContainer.start(); databaseHost = postgresContainer.getContainerIpAddress(); databasePort = postgresContainer.getMappedPort(POSTGRESQL_PORT); @@ -119,29 +111,7 @@ public class GenerateSqlSchemaCommand implements Command { } try { - // Configure Hibernate settings. - Map settings = new HashMap<>(); - settings.put("hibernate.dialect", NomulusPostgreSQLDialect.class.getName()); - settings.put( - "hibernate.connection.url", - "jdbc:postgresql://" + databaseHost + ":" + databasePort + "/postgres?useSSL=false"); - settings.put("hibernate.connection.username", "postgres"); - settings.put("hibernate.connection.password", "domain-registry"); - settings.put("hibernate.hbm2ddl.auto", "none"); - settings.put("show_sql", "true"); - settings.put( - Environment.PHYSICAL_NAMING_STRATEGY, NomulusNamingStrategy.class.getCanonicalName()); - - MetadataSources metadata = - new MetadataSources(new StandardServiceRegistryBuilder().applySettings(settings).build()); - - addAnnotatedClasses(metadata, settings); - - SchemaExport schemaExport = new SchemaExport(); - schemaExport.setHaltOnError(true); - schemaExport.setFormat(true); - schemaExport.setDelimiter(";"); - schemaExport.setOutputFile(outFile); + File outputFile = new File(outFile); // Generate the copyright header (this file gets checked for copyright). The schema exporter // appends to the existing file, so this has the additional desired effect of clearing any @@ -161,44 +131,30 @@ public class GenerateSqlSchemaCommand implements Command { + "-- See the License for the specific language governing permissions and\n" + "-- limitations under the License.\n"; try { - Files.write(Paths.get(outFile), copyright.getBytes(UTF_8)); + Files.write(outputFile.toPath(), copyright.getBytes(UTF_8)); } catch (IOException e) { System.err.println("Error writing sql file: " + e); e.printStackTrace(); System.exit(1); } - schemaExport.createOnly(EnumSet.of(TargetType.SCRIPT), metadata.buildMetadata()); + HibernateSchemaExporter exporter = + HibernateSchemaExporter.create( + "jdbc:postgresql://" + + databaseHost + + ":" + + databasePort + + "/" + + DB_NAME + + "?useSSL=false", + DB_USERNAME, + DB_PASSWORD); + exporter.export(PersistenceXmlUtility.getManagedClasses(), outputFile); + } finally { if (postgresContainer != null) { postgresContainer.stop(); } } } - - private void addAnnotatedClasses(MetadataSources metadata, Map settings) { - ParsedPersistenceXmlDescriptor descriptor = - PersistenceXmlParser.locatePersistenceUnits(settings).stream() - .filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName())) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format( - "Could not find persistence unit with name %s", - PersistenceModule.PERSISTENCE_UNIT_NAME))); - - List classNames = descriptor.getManagedClassNames(); - for (String className : classNames) { - try { - Class clazz = Class.forName(className); - metadata.addAnnotatedClass(clazz); - } catch (ClassNotFoundException e) { - throw new IllegalArgumentException( - String.format( - "Could not load class with name %s present in persistence.xml", className), - e); - } - } - } } diff --git a/core/src/test/java/google/registry/model/transaction/JpaTransactionManagerRule.java b/core/src/test/java/google/registry/model/transaction/JpaTransactionManagerRule.java index 3562adb41..72691ee04 100644 --- a/core/src/test/java/google/registry/model/transaction/JpaTransactionManagerRule.java +++ b/core/src/test/java/google/registry/model/transaction/JpaTransactionManagerRule.java @@ -26,6 +26,7 @@ import com.google.common.collect.Maps; import com.google.common.io.Resources; import google.registry.persistence.HibernateSchemaExporter; import google.registry.persistence.PersistenceModule; +import google.registry.persistence.PersistenceXmlUtility; import google.registry.testing.FakeClock; import java.io.File; import java.io.IOException; @@ -45,7 +46,6 @@ import java.util.Properties; import javax.persistence.EntityManagerFactory; import org.hibernate.cfg.Environment; import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor; -import org.hibernate.jpa.boot.internal.PersistenceXmlParser; import org.hibernate.jpa.boot.spi.Bootstrap; import org.joda.time.DateTime; import org.junit.rules.ExternalResource; @@ -214,15 +214,7 @@ public class JpaTransactionManagerRule extends ExternalResource { properties.put(Environment.PASS, password); ParsedPersistenceXmlDescriptor descriptor = - PersistenceXmlParser.locatePersistenceUnits(properties).stream() - .filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName())) - .findFirst() - .orElseThrow( - () -> - new IllegalArgumentException( - String.format( - "Could not find persistence unit with name %s", - PersistenceModule.PERSISTENCE_UNIT_NAME))); + PersistenceXmlUtility.getParsedPersistenceXmlDescriptor(); extraEntityClasses.stream().map(Class::getName).forEach(descriptor::addClasses); return Bootstrap.getEntityManagerFactoryBuilder(descriptor, properties).build();