diff --git a/core/src/main/java/google/registry/beam/initsql/BeamJpaModule.java b/core/src/main/java/google/registry/beam/initsql/BeamJpaModule.java index 31555d19a..a757d4859 100644 --- a/core/src/main/java/google/registry/beam/initsql/BeamJpaModule.java +++ b/core/src/main/java/google/registry/beam/initsql/BeamJpaModule.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.isNullOrEmpty; import com.google.common.base.Splitter; -import dagger.Binds; import dagger.Component; import dagger.Lazy; import dagger.Module; @@ -32,10 +31,6 @@ import google.registry.persistence.PersistenceModule; import google.registry.persistence.PersistenceModule.JdbcJpaTm; import google.registry.persistence.PersistenceModule.SocketFactoryJpaTm; import google.registry.persistence.transaction.JpaTransactionManager; -import google.registry.util.Clock; -import google.registry.util.Sleeper; -import google.registry.util.SystemClock; -import google.registry.util.SystemSleeper; import google.registry.util.UtilsModule; import java.io.BufferedReader; import java.io.IOException; @@ -159,19 +154,10 @@ public class BeamJpaModule { @Provides @Config("beamHibernateHikariMaximumPoolSize") static int getBeamHibernateHikariMaximumPoolSize() { + // TODO(weiminyu): make this configurable. Should be equal to number of cores. return 4; } - @Module - interface BindModule { - - @Binds - Sleeper sleeper(SystemSleeper sleeper); - - @Binds - Clock clock(SystemClock clock); - } - @Singleton @Component( modules = { diff --git a/core/src/main/java/google/registry/beam/initsql/Transforms.java b/core/src/main/java/google/registry/beam/initsql/Transforms.java index 1afb034c5..7f302cfb3 100644 --- a/core/src/main/java/google/registry/beam/initsql/Transforms.java +++ b/core/src/main/java/google/registry/beam/initsql/Transforms.java @@ -19,24 +19,37 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static google.registry.beam.initsql.BackupPaths.getCommitLogTimestamp; import static google.registry.beam.initsql.BackupPaths.getExportFilePatterns; +import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; +import static google.registry.persistence.transaction.TransactionManagerFactory.setJpaTm; import static google.registry.util.DateTimeUtils.START_OF_TIME; import static google.registry.util.DateTimeUtils.isBeforeOrAt; import static java.util.Comparator.comparing; +import static org.apache.beam.sdk.values.TypeDescriptors.integers; import static org.apache.beam.sdk.values.TypeDescriptors.kvs; import static org.apache.beam.sdk.values.TypeDescriptors.strings; import avro.shaded.com.google.common.collect.Iterators; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Streams; +import google.registry.backup.AppEngineEnvironment; import google.registry.backup.CommitLogImports; import google.registry.backup.VersionedEntity; +import google.registry.model.ofy.ObjectifyService; +import google.registry.model.ofy.Ofy; +import google.registry.persistence.transaction.JpaTransactionManager; import google.registry.tools.LevelDbLogReader; +import google.registry.util.SystemSleeper; +import java.io.Serializable; import java.util.Collection; import java.util.Iterator; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Supplier; +import javax.persistence.OptimisticLockException; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.Compression; import org.apache.beam.sdk.io.FileIO; @@ -47,6 +60,7 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.GroupIntoBatches; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -56,10 +70,12 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.DateTime; +import org.joda.time.Duration; /** * {@link PTransform Pipeline transforms} used in pipelines that load from both Datastore export @@ -245,6 +261,38 @@ public final class Transforms { .iterator())); } + /** + * Returns a {@link PTransform} that writes a {@link PCollection} of entities to a SQL database. + * + * @param transformId a unique ID for an instance of the returned transform + * @param maxWriters the max number of concurrent writes to SQL, which also determines the max + * number of connection pools created + * @param batchSize the number of entities to write in each operation + * @param jpaSupplier supplier of a {@link JpaTransactionManager} + */ + public static PTransform, PDone> writeToSql( + String transformId, + int maxWriters, + int batchSize, + SerializableSupplier jpaSupplier) { + return new PTransform, PDone>() { + @Override + public PDone expand(PCollection input) { + input + .apply( + "Shard data for " + transformId, + MapElements.into(kvs(integers(), TypeDescriptor.of(VersionedEntity.class))) + .via(ve -> KV.of(ThreadLocalRandom.current().nextInt(maxWriters), ve))) + .apply("Batch output by shard " + transformId, GroupIntoBatches.ofSize(batchSize)) + .apply("Write in batch for " + transformId, ParDo.of(new SqlBatchWriter(jpaSupplier))); + return PDone.in(input.getPipeline()); + } + }; + } + + /** Interface for serializable {@link Supplier suppliers}. */ + public interface SerializableSupplier extends Supplier, Serializable {} + /** * Returns a {@link PTransform} that produces a {@link PCollection} containing all elements in the * given {@link Iterable}. @@ -322,4 +370,104 @@ public final class Transforms { } } } + + /** + * Writes a batch of entities to a SQL database. + * + *

Note that an arbitrary number of instances of this class may be created and freed in + * arbitrary order in a single JVM. Due to the tech debt that forced us to use a static variable + * to hold the {@code JpaTransactionManager} instance, we must ensure that JpaTransactionManager + * is not changed or torn down while being used by some instance. + */ + private static class SqlBatchWriter extends DoFn>, Void> { + + private static int instanceCount = 0; + private static JpaTransactionManager originalJpa; + + private final SerializableSupplier jpaSupplier; + + private transient Ofy ofy; + private transient SystemSleeper sleeper; + + SqlBatchWriter(SerializableSupplier jpaSupplier) { + this.jpaSupplier = jpaSupplier; + } + + @Setup + public void setup() { + sleeper = new SystemSleeper(); + + ObjectifyService.initOfy(); + ofy = ObjectifyService.ofy(); + + synchronized (SqlBatchWriter.class) { + if (instanceCount == 0) { + originalJpa = jpaTm(); + setJpaTm(jpaSupplier); + } + instanceCount++; + } + } + + @Teardown + public void teardown() { + synchronized (SqlBatchWriter.class) { + instanceCount--; + if (instanceCount == 0) { + jpaTm().teardown(); + setJpaTm(() -> originalJpa); + } + } + } + + @ProcessElement + public void processElement(@Element KV> kv) { + try (AppEngineEnvironment env = new AppEngineEnvironment()) { + ImmutableList ofyEntities = + Streams.stream(kv.getValue()) + .map(VersionedEntity::getEntity) + .map(Optional::get) + .map(ofy::toPojo) + .collect(ImmutableList.toImmutableList()); + retry(() -> jpaTm().transact(() -> jpaTm().saveNewOrUpdateAll(ofyEntities))); + } + } + + // TODO(b/160632289): Enhance Retrier and use it here. + private void retry(Runnable runnable) { + int maxAttempts = 5; + int initialDelayMillis = 100; + double jitterRatio = 0.2; + + for (int attempt = 0; attempt < maxAttempts; attempt++) { + try { + runnable.run(); + return; + } catch (Throwable throwable) { + throwIfNotCausedBy(throwable, OptimisticLockException.class); + int sleepMillis = (1 << attempt) * initialDelayMillis; + int jitter = + ThreadLocalRandom.current().nextInt((int) (sleepMillis * jitterRatio)) + - (int) (sleepMillis * jitterRatio / 2); + sleeper.sleepUninterruptibly(Duration.millis(sleepMillis + jitter)); + } + } + } + + /** + * Rethrows {@code throwable} if it is not (and does not have a cause of) {@code causeType}; + * otherwise returns with no side effects. + */ + private void throwIfNotCausedBy(Throwable throwable, Class causeType) { + Throwable t = throwable; + while (t != null) { + if (causeType.isInstance(t)) { + return; + } + t = t.getCause(); + } + Throwables.throwIfUnchecked(t); + throw new RuntimeException(t); + } + } } diff --git a/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManager.java b/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManager.java index 0cf189ab0..281482a8a 100644 --- a/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManager.java +++ b/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManager.java @@ -25,4 +25,11 @@ public interface JpaTransactionManager extends TransactionManager { /** Deletes the entity by its id, throws exception if the entity is not deleted. */ public abstract void assertDelete(VKey key); + + /** + * Releases all resources and shuts down. + * + *

The errorprone check forbids injection of {@link java.io.Closeable} resources. + */ + void teardown(); } diff --git a/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java b/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java index d3aa5f19c..a1307210d 100644 --- a/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java +++ b/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java @@ -62,6 +62,11 @@ public class JpaTransactionManagerImpl implements JpaTransactionManager { this.clock = clock; } + @Override + public void teardown() { + emf.close(); + } + @Override public EntityManager getEntityManager() { if (transactionInfo.get().entityManager == null) { diff --git a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java index d03548a73..d43f0e1a7 100644 --- a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java +++ b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java @@ -75,7 +75,12 @@ public class TransactionManagerFactory { return tm; } - /** Returns {@link JpaTransactionManager} instance. */ + /** + * Returns {@link JpaTransactionManager} instance. + * + *

Between invocations of {@link TransactionManagerFactory#setJpaTm} every call to this method + * returns the same instance. + */ public static JpaTransactionManager jpaTm() { return jpaTm.get(); } @@ -93,7 +98,7 @@ public class TransactionManagerFactory { RegistryEnvironment.get().equals(RegistryEnvironment.UNITTEST) || RegistryToolEnvironment.get() != null, "setJpamTm() should only be called by tools and tests."); - jpaTm = jpaTmSupplier; + jpaTm = Suppliers.memoize(jpaTmSupplier::get); } /** Sets the return of {@link #tm()} to the given instance of {@link TransactionManager}. */ diff --git a/core/src/test/java/google/registry/beam/initsql/ExportloadingTransformsTest.java b/core/src/test/java/google/registry/beam/initsql/ExportloadingTransformsTest.java index 8a19ef600..b83a6cdb5 100644 --- a/core/src/test/java/google/registry/beam/initsql/ExportloadingTransformsTest.java +++ b/core/src/test/java/google/registry/beam/initsql/ExportloadingTransformsTest.java @@ -173,6 +173,7 @@ public class ExportloadingTransformsTest implements Serializable { } @Test + @Category(NeedsRunner.class) public void loadDataFromFiles() { PCollection entities = pipeline diff --git a/core/src/test/java/google/registry/beam/initsql/LoadDatastoreSnapshotTest.java b/core/src/test/java/google/registry/beam/initsql/LoadDatastoreSnapshotTest.java index 44599bcd9..8399d8dcd 100644 --- a/core/src/test/java/google/registry/beam/initsql/LoadDatastoreSnapshotTest.java +++ b/core/src/test/java/google/registry/beam/initsql/LoadDatastoreSnapshotTest.java @@ -31,6 +31,7 @@ import google.registry.model.registry.Registry; import google.registry.testing.FakeClock; import google.registry.testing.InjectRule; import java.io.File; +import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionTuple; @@ -38,6 +39,7 @@ import org.joda.time.DateTime; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.junit.experimental.categories.Category; import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -152,6 +154,7 @@ public class LoadDatastoreSnapshotTest { } @Test + @Category(NeedsRunner.class) public void loadDatastoreSnapshot() { PCollectionTuple snapshot = pipeline.apply( diff --git a/core/src/test/java/google/registry/beam/initsql/WriteToSqlTest.java b/core/src/test/java/google/registry/beam/initsql/WriteToSqlTest.java new file mode 100644 index 000000000..0dbb28aef --- /dev/null +++ b/core/src/test/java/google/registry/beam/initsql/WriteToSqlTest.java @@ -0,0 +1,125 @@ +// Copyright 2020 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.beam.initsql; + +import static com.google.common.truth.Truth.assertThat; +import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; + +import com.google.appengine.api.datastore.Entity; +import com.google.common.collect.ImmutableList; +import google.registry.backup.VersionedEntity; +import google.registry.model.contact.ContactResource; +import google.registry.model.ofy.Ofy; +import google.registry.model.registrar.Registrar; +import google.registry.persistence.transaction.JpaTestRules; +import google.registry.persistence.transaction.JpaTestRules.JpaIntegrationTestRule; +import google.registry.testing.AppEngineRule; +import google.registry.testing.DatastoreHelper; +import google.registry.testing.FakeClock; +import google.registry.testing.InjectRule; +import java.io.File; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.stream.Collectors; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.joda.time.DateTime; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit test for {@link Transforms#writeToSql}. */ +@RunWith(JUnit4.class) +public class WriteToSqlTest implements Serializable { + private static final DateTime START_TIME = DateTime.parse("2000-01-01T00:00:00.0Z"); + + private final FakeClock fakeClock = new FakeClock(START_TIME); + + @Rule public final transient InjectRule injectRule = new InjectRule(); + + @Rule + public transient JpaIntegrationTestRule jpaRule = + new JpaTestRules.Builder().withClock(fakeClock).buildIntegrationTestRule(); + + @Rule public transient TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Rule + public final transient TestPipeline pipeline = + TestPipeline.create().enableAbandonedNodeEnforcement(true); + + private ImmutableList contacts; + + private File credentialFile; + + @Before + public void beforeEach() throws Exception { + try (BackupTestStore store = new BackupTestStore(fakeClock)) { + injectRule.setStaticField(Ofy.class, "clock", fakeClock); + + // Required for contacts created below. + Registrar ofyRegistrar = AppEngineRule.makeRegistrar2(); + store.insertOrUpdate(ofyRegistrar); + jpaTm().transact(() -> jpaTm().saveNewOrUpdate(store.loadAsOfyEntity(ofyRegistrar))); + + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + + for (int i = 0; i < 3; i++) { + ContactResource contact = DatastoreHelper.newContactResource("contact_" + i); + store.insertOrUpdate(contact); + builder.add(store.loadAsDatastoreEntity(contact)); + } + contacts = builder.build(); + } + credentialFile = temporaryFolder.newFile(); + new PrintStream(credentialFile) + .printf( + "%s %s %s", + jpaRule.getDatabaseUrl(), jpaRule.getDatabaseUsername(), jpaRule.getDatabasePassword()) + .close(); + } + + @Test + @Category(NeedsRunner.class) + public void writeToSql_twoWriters() { + pipeline + .apply( + Create.of( + contacts.stream() + .map(InitSqlTestUtils::entityToBytes) + .map(bytes -> VersionedEntity.from(0L, bytes)) + .collect(Collectors.toList()))) + .apply( + Transforms.writeToSql( + "ContactResource", + 2, + 4, + () -> + DaggerBeamJpaModule_JpaTransactionManagerComponent.builder() + .beamJpaModule(new BeamJpaModule(credentialFile.getAbsolutePath())) + .build() + .localDbJpaTransactionManager())); + pipeline.run().waitUntilFinish(); + + ImmutableList sqlContacts = jpaTm().transact(() -> jpaTm().loadAll(ContactResource.class)); + // TODO(weiminyu): compare load entities with originals. Note: lastUpdateTimes won't match by + // design. Need an elegant way to deal with this.bbq + assertThat(sqlContacts).hasSize(3); + } +} diff --git a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java index c9972a5e7..33d01d446 100644 --- a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java +++ b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java @@ -209,6 +209,18 @@ abstract class JpaTransactionManagerRule extends ExternalResource { cachedTm = null; } + public String getDatabaseUrl() { + return database.getJdbcUrl(); + } + + public String getDatabaseUsername() { + return database.getUsername(); + } + + public String getDatabasePassword() { + return database.getPassword(); + } + private void resetTablesAndSequences() { try (Connection conn = createConnection(); Statement statement = conn.createStatement()) {