Add PersistenceXmlUtility and refactor related code (#357)

* Refactor GenerateSqlSchemaCommand

* Add and throw UncheckedClassNotFoundException
This commit is contained in:
Shicong Huang 2019-11-12 15:08:54 -05:00 committed by GitHub
parent a392100852
commit 09aef04117
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 88 additions and 102 deletions

View file

@ -21,15 +21,12 @@ import com.google.common.collect.Maps;
import java.io.File;
import java.util.EnumSet;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Stream;
import javax.persistence.AttributeConverter;
import org.hibernate.boot.MetadataSources;
import org.hibernate.boot.registry.StandardServiceRegistry;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
import org.hibernate.cfg.Environment;
import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor;
import org.hibernate.jpa.boot.internal.PersistenceXmlParser;
import org.hibernate.tool.hbm2ddl.SchemaExport;
import org.hibernate.tool.schema.TargetType;
@ -82,25 +79,7 @@ public class HibernateSchemaExporter {
}
private ImmutableList<Class> findAllConverters() {
ParsedPersistenceXmlDescriptor descriptor =
PersistenceXmlParser.locatePersistenceUnits(new Properties()).stream()
.filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName()))
.findFirst()
.orElseThrow(
() ->
new IllegalArgumentException(
String.format(
"Could not find persistence unit with name %s",
PersistenceModule.PERSISTENCE_UNIT_NAME)));
return descriptor.getManagedClassNames().stream()
.map(
className -> {
try {
return Class.forName(className);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
})
return PersistenceXmlUtility.getManagedClasses().stream()
.filter(AttributeConverter.class::isAssignableFrom)
.collect(toImmutableList());
}

View file

@ -0,0 +1,59 @@
// Copyright 2019 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.persistence;
import static com.google.common.collect.ImmutableList.toImmutableList;
import com.google.common.collect.ImmutableList;
import java.util.Properties;
import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor;
import org.hibernate.jpa.boot.internal.PersistenceXmlParser;
/** Utility class that provides methods to manipulate persistence.xml file. */
public class PersistenceXmlUtility {
private PersistenceXmlUtility() {}
/**
* Returns the {@link ParsedPersistenceXmlDescriptor} instance constructed from persistence.xml.
*/
public static ParsedPersistenceXmlDescriptor getParsedPersistenceXmlDescriptor() {
return PersistenceXmlParser.locatePersistenceUnits(new Properties()).stream()
.filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName()))
.findFirst()
.orElseThrow(
() ->
new IllegalArgumentException(
String.format(
"Could not find persistence unit with name %s",
PersistenceModule.PERSISTENCE_UNIT_NAME)));
}
/** Returns all managed classes defined in persistence.xml. */
public static ImmutableList<Class> getManagedClasses() {
return getParsedPersistenceXmlDescriptor().getManagedClassNames().stream()
.map(
className -> {
try {
return Class.forName(className);
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException(
String.format(
"Could not load class with name %s present in persistence.xml", className),
e);
}
})
.collect(toImmutableList());
}
}

View file

@ -19,23 +19,11 @@ import static java.nio.charset.StandardCharsets.UTF_8;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;
import com.google.common.annotations.VisibleForTesting;
import google.registry.persistence.NomulusNamingStrategy;
import google.registry.persistence.NomulusPostgreSQLDialect;
import google.registry.persistence.PersistenceModule;
import google.registry.persistence.HibernateSchemaExporter;
import google.registry.persistence.PersistenceXmlUtility;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.hibernate.boot.MetadataSources;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
import org.hibernate.cfg.Environment;
import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor;
import org.hibernate.jpa.boot.internal.PersistenceXmlParser;
import org.hibernate.tool.hbm2ddl.SchemaExport;
import org.hibernate.tool.schema.TargetType;
import org.testcontainers.containers.PostgreSQLContainer;
/**
@ -47,6 +35,9 @@ import org.testcontainers.containers.PostgreSQLContainer;
*/
@Parameters(separators = " =", commandDescription = "Generate PostgreSQL schema.")
public class GenerateSqlSchemaCommand implements Command {
private static final String DB_NAME = "postgres";
private static final String DB_USERNAME = "postgres";
private static final String DB_PASSWORD = "domain-registry";
@VisibleForTesting
public static final String DB_OPTIONS_CLASH =
@ -91,10 +82,11 @@ public class GenerateSqlSchemaCommand implements Command {
}
// Start the container and store the address information.
postgresContainer = new PostgreSQLContainer()
.withDatabaseName("postgres")
.withUsername("postgres")
.withPassword("domain-registry");
postgresContainer =
new PostgreSQLContainer()
.withDatabaseName(DB_NAME)
.withUsername(DB_USERNAME)
.withPassword(DB_PASSWORD);
postgresContainer.start();
databaseHost = postgresContainer.getContainerIpAddress();
databasePort = postgresContainer.getMappedPort(POSTGRESQL_PORT);
@ -119,29 +111,7 @@ public class GenerateSqlSchemaCommand implements Command {
}
try {
// Configure Hibernate settings.
Map<String, String> settings = new HashMap<>();
settings.put("hibernate.dialect", NomulusPostgreSQLDialect.class.getName());
settings.put(
"hibernate.connection.url",
"jdbc:postgresql://" + databaseHost + ":" + databasePort + "/postgres?useSSL=false");
settings.put("hibernate.connection.username", "postgres");
settings.put("hibernate.connection.password", "domain-registry");
settings.put("hibernate.hbm2ddl.auto", "none");
settings.put("show_sql", "true");
settings.put(
Environment.PHYSICAL_NAMING_STRATEGY, NomulusNamingStrategy.class.getCanonicalName());
MetadataSources metadata =
new MetadataSources(new StandardServiceRegistryBuilder().applySettings(settings).build());
addAnnotatedClasses(metadata, settings);
SchemaExport schemaExport = new SchemaExport();
schemaExport.setHaltOnError(true);
schemaExport.setFormat(true);
schemaExport.setDelimiter(";");
schemaExport.setOutputFile(outFile);
File outputFile = new File(outFile);
// Generate the copyright header (this file gets checked for copyright). The schema exporter
// appends to the existing file, so this has the additional desired effect of clearing any
@ -161,44 +131,30 @@ public class GenerateSqlSchemaCommand implements Command {
+ "-- See the License for the specific language governing permissions and\n"
+ "-- limitations under the License.\n";
try {
Files.write(Paths.get(outFile), copyright.getBytes(UTF_8));
Files.write(outputFile.toPath(), copyright.getBytes(UTF_8));
} catch (IOException e) {
System.err.println("Error writing sql file: " + e);
e.printStackTrace();
System.exit(1);
}
schemaExport.createOnly(EnumSet.of(TargetType.SCRIPT), metadata.buildMetadata());
HibernateSchemaExporter exporter =
HibernateSchemaExporter.create(
"jdbc:postgresql://"
+ databaseHost
+ ":"
+ databasePort
+ "/"
+ DB_NAME
+ "?useSSL=false",
DB_USERNAME,
DB_PASSWORD);
exporter.export(PersistenceXmlUtility.getManagedClasses(), outputFile);
} finally {
if (postgresContainer != null) {
postgresContainer.stop();
}
}
}
private void addAnnotatedClasses(MetadataSources metadata, Map<String, String> settings) {
ParsedPersistenceXmlDescriptor descriptor =
PersistenceXmlParser.locatePersistenceUnits(settings).stream()
.filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName()))
.findFirst()
.orElseThrow(
() ->
new IllegalArgumentException(
String.format(
"Could not find persistence unit with name %s",
PersistenceModule.PERSISTENCE_UNIT_NAME)));
List<String> classNames = descriptor.getManagedClassNames();
for (String className : classNames) {
try {
Class<?> clazz = Class.forName(className);
metadata.addAnnotatedClass(clazz);
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException(
String.format(
"Could not load class with name %s present in persistence.xml", className),
e);
}
}
}
}

View file

@ -26,6 +26,7 @@ import com.google.common.collect.Maps;
import com.google.common.io.Resources;
import google.registry.persistence.HibernateSchemaExporter;
import google.registry.persistence.PersistenceModule;
import google.registry.persistence.PersistenceXmlUtility;
import google.registry.testing.FakeClock;
import java.io.File;
import java.io.IOException;
@ -45,7 +46,6 @@ import java.util.Properties;
import javax.persistence.EntityManagerFactory;
import org.hibernate.cfg.Environment;
import org.hibernate.jpa.boot.internal.ParsedPersistenceXmlDescriptor;
import org.hibernate.jpa.boot.internal.PersistenceXmlParser;
import org.hibernate.jpa.boot.spi.Bootstrap;
import org.joda.time.DateTime;
import org.junit.rules.ExternalResource;
@ -214,15 +214,7 @@ public class JpaTransactionManagerRule extends ExternalResource {
properties.put(Environment.PASS, password);
ParsedPersistenceXmlDescriptor descriptor =
PersistenceXmlParser.locatePersistenceUnits(properties).stream()
.filter(unit -> PersistenceModule.PERSISTENCE_UNIT_NAME.equals(unit.getName()))
.findFirst()
.orElseThrow(
() ->
new IllegalArgumentException(
String.format(
"Could not find persistence unit with name %s",
PersistenceModule.PERSISTENCE_UNIT_NAME)));
PersistenceXmlUtility.getParsedPersistenceXmlDescriptor();
extraEntityClasses.stream().map(Class::getName).forEach(descriptor::addClasses);
return Bootstrap.getEntityManagerFactoryBuilder(descriptor, properties).build();