Write one PCollection to SQL (#664)

* Write one PCollection to SQL

Defined a transform that writes a PCollection of entities to SQL using
JPA. Allows configuring parallelism level and batch size.
Weimin Yu 2020-07-13 13:34:01 -04:00 committed by GitHub
parent 58618a274e
commit ba1915e271
9 changed files with 309 additions and 17 deletions
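
For context, the new transform is invoked roughly as follows. This is a minimal sketch distilled from the WriteToSqlTest added below; the entity source and the credential path are placeholders, not part of this commit.

    // Hypothetical caller; mirrors the wiring in WriteToSqlTest below.
    PCollection<VersionedEntity> entities = ...; // e.g. entities decoded from a Datastore export
    entities.apply(
        Transforms.writeToSql(
            "ContactResource", // transformId, embedded in the Beam step names
            2,                 // maxWriters: number of shards, i.e. concurrent writers
            4,                 // batchSize: entities written per transaction
            () ->
                DaggerBeamJpaModule_JpaTransactionManagerComponent.builder()
                    .beamJpaModule(new BeamJpaModule("/path/to/credential"))
                    .build()
                    .localDbJpaTransactionManager()));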

File: BeamJpaModule.java

@@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Strings.isNullOrEmpty;
 import com.google.common.base.Splitter;
-import dagger.Binds;
 import dagger.Component;
 import dagger.Lazy;
 import dagger.Module;
@@ -32,10 +31,6 @@ import google.registry.persistence.PersistenceModule;
 import google.registry.persistence.PersistenceModule.JdbcJpaTm;
 import google.registry.persistence.PersistenceModule.SocketFactoryJpaTm;
 import google.registry.persistence.transaction.JpaTransactionManager;
-import google.registry.util.Clock;
-import google.registry.util.Sleeper;
-import google.registry.util.SystemClock;
-import google.registry.util.SystemSleeper;
 import google.registry.util.UtilsModule;
 import java.io.BufferedReader;
 import java.io.IOException;
@@ -159,19 +154,10 @@ public class BeamJpaModule {
   @Provides
   @Config("beamHibernateHikariMaximumPoolSize")
   static int getBeamHibernateHikariMaximumPoolSize() {
-    // TODO(weiminyu): make this configurable. Should be equal to number of cores.
     return 4;
   }
-
-  @Module
-  interface BindModule {
-    @Binds
-    Sleeper sleeper(SystemSleeper sleeper);
-    @Binds
-    Clock clock(SystemClock clock);
-  }

   @Singleton
   @Component(
       modules = {

File: Transforms.java

@@ -19,24 +19,37 @@ import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 import static google.registry.beam.initsql.BackupPaths.getCommitLogTimestamp;
 import static google.registry.beam.initsql.BackupPaths.getExportFilePatterns;
+import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
+import static google.registry.persistence.transaction.TransactionManagerFactory.setJpaTm;
 import static google.registry.util.DateTimeUtils.START_OF_TIME;
 import static google.registry.util.DateTimeUtils.isBeforeOrAt;
 import static java.util.Comparator.comparing;
+import static org.apache.beam.sdk.values.TypeDescriptors.integers;
 import static org.apache.beam.sdk.values.TypeDescriptors.kvs;
 import static org.apache.beam.sdk.values.TypeDescriptors.strings;
 import avro.shaded.com.google.common.collect.Iterators;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Streams;
+import google.registry.backup.AppEngineEnvironment;
 import google.registry.backup.CommitLogImports;
 import google.registry.backup.VersionedEntity;
+import google.registry.model.ofy.ObjectifyService;
+import google.registry.model.ofy.Ofy;
+import google.registry.persistence.transaction.JpaTransactionManager;
 import google.registry.tools.LevelDbLogReader;
+import google.registry.util.SystemSleeper;
+import java.io.Serializable;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.function.Supplier;
+import javax.persistence.OptimisticLockException;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.Compression;
 import org.apache.beam.sdk.io.FileIO;
@@ -47,6 +60,7 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.GroupIntoBatches;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
@@ -56,10 +70,12 @@ import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.joda.time.DateTime;
+import org.joda.time.Duration;

 /**
  * {@link PTransform Pipeline transforms} used in pipelines that load from both Datastore export
@@ -245,6 +261,38 @@ public final class Transforms {
             .iterator()));
   }

+  /**
+   * Returns a {@link PTransform} that writes a {@link PCollection} of entities to a SQL database.
+   *
+   * @param transformId a unique ID for an instance of the returned transform
+   * @param maxWriters the max number of concurrent writes to SQL, which also determines the max
+   *     number of connection pools created
+   * @param batchSize the number of entities to write in each operation
+   * @param jpaSupplier supplier of a {@link JpaTransactionManager}
+   */
+  public static PTransform<PCollection<VersionedEntity>, PDone> writeToSql(
+      String transformId,
+      int maxWriters,
+      int batchSize,
+      SerializableSupplier<JpaTransactionManager> jpaSupplier) {
+    return new PTransform<PCollection<VersionedEntity>, PDone>() {
+      @Override
+      public PDone expand(PCollection<VersionedEntity> input) {
+        input
+            .apply(
+                "Shard data for " + transformId,
+                MapElements.into(kvs(integers(), TypeDescriptor.of(VersionedEntity.class)))
+                    .via(ve -> KV.of(ThreadLocalRandom.current().nextInt(maxWriters), ve)))
+            .apply("Batch output by shard " + transformId, GroupIntoBatches.ofSize(batchSize))
+            .apply("Write in batch for " + transformId, ParDo.of(new SqlBatchWriter(jpaSupplier)));
+        return PDone.in(input.getPipeline());
+      }
+    };
+  }
+
+  /** Interface for serializable {@link Supplier suppliers}. */
+  public interface SerializableSupplier<T> extends Supplier<T>, Serializable {}
+
  /**
   * Returns a {@link PTransform} that produces a {@link PCollection} containing all elements in the
   * given {@link Iterable}.
@@ -322,4 +370,104 @@ public final class Transforms {
       }
     }
   }
+
+  /**
+   * Writes a batch of entities to a SQL database.
+   *
+   * <p>Note that an arbitrary number of instances of this class may be created and freed in
+   * arbitrary order in a single JVM. Due to the tech debt that forced us to use a static variable
+   * to hold the {@code JpaTransactionManager} instance, we must ensure that JpaTransactionManager
+   * is not changed or torn down while being used by some instance.
+   */
+  private static class SqlBatchWriter extends DoFn<KV<Integer, Iterable<VersionedEntity>>, Void> {
+    private static int instanceCount = 0;
+    private static JpaTransactionManager originalJpa;
+
+    private final SerializableSupplier<JpaTransactionManager> jpaSupplier;
+
+    private transient Ofy ofy;
+    private transient SystemSleeper sleeper;
+
+    SqlBatchWriter(SerializableSupplier<JpaTransactionManager> jpaSupplier) {
+      this.jpaSupplier = jpaSupplier;
+    }
+
+    @Setup
+    public void setup() {
+      sleeper = new SystemSleeper();
+      ObjectifyService.initOfy();
+      ofy = ObjectifyService.ofy();
+
+      synchronized (SqlBatchWriter.class) {
+        if (instanceCount == 0) {
+          originalJpa = jpaTm();
+          setJpaTm(jpaSupplier);
+        }
+        instanceCount++;
+      }
+    }
+
+    @Teardown
+    public void teardown() {
+      synchronized (SqlBatchWriter.class) {
+        instanceCount--;
+        if (instanceCount == 0) {
+          jpaTm().teardown();
+          setJpaTm(() -> originalJpa);
+        }
+      }
+    }
+
+    @ProcessElement
+    public void processElement(@Element KV<Integer, Iterable<VersionedEntity>> kv) {
+      try (AppEngineEnvironment env = new AppEngineEnvironment()) {
+        ImmutableList<Object> ofyEntities =
+            Streams.stream(kv.getValue())
+                .map(VersionedEntity::getEntity)
+                .map(Optional::get)
+                .map(ofy::toPojo)
+                .collect(ImmutableList.toImmutableList());
+        retry(() -> jpaTm().transact(() -> jpaTm().saveNewOrUpdateAll(ofyEntities)));
+      }
+    }
+
+    // TODO(b/160632289): Enhance Retrier and use it here.
+    private void retry(Runnable runnable) {
+      int maxAttempts = 5;
+      int initialDelayMillis = 100;
+      double jitterRatio = 0.2;
+
+      for (int attempt = 0; attempt < maxAttempts; attempt++) {
+        try {
+          runnable.run();
+          return;
+        } catch (Throwable throwable) {
+          throwIfNotCausedBy(throwable, OptimisticLockException.class);
+          int sleepMillis = (1 << attempt) * initialDelayMillis;
+          int jitter =
+              ThreadLocalRandom.current().nextInt((int) (sleepMillis * jitterRatio))
                  - (int) (sleepMillis * jitterRatio / 2);
+          sleeper.sleepUninterruptibly(Duration.millis(sleepMillis + jitter));
+        }
+      }
+    }
+
+    /**
+     * Rethrows {@code throwable} if it is not (and does not have a cause of) {@code causeType};
+     * otherwise returns with no side effects.
+     */
+    private void throwIfNotCausedBy(Throwable throwable, Class<? extends Throwable> causeType) {
+      Throwable t = throwable;
+      while (t != null) {
+        if (causeType.isInstance(t)) {
+          return;
+        }
+        t = t.getCause();
+      }
+      // Rethrow the original throwable; the cursor t is null once the loop exits.
+      Throwables.throwIfUnchecked(throwable);
+      throw new RuntimeException(throwable);
+    }
+  }
 }
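
A note on the retry loop in SqlBatchWriter above: the base delay doubles on each attempt, and the jitter term shifts it by up to half the jitter window in either direction. The following standalone sketch (a hypothetical helper, not part of the commit) prints the schedule implied by the constants maxAttempts = 5, initialDelayMillis = 100, and jitterRatio = 0.2:

    // Prints: attempt 0 sleeps in [90, 110) ms, attempt 1 in [180, 220) ms,
    // then [360, 440), [720, 880), and [1440, 1760) ms.
    public class BackoffScheduleSketch {
      public static void main(String[] args) {
        int initialDelayMillis = 100;
        double jitterRatio = 0.2;
        for (int attempt = 0; attempt < 5; attempt++) {
          int sleepMillis = (1 << attempt) * initialDelayMillis;
          int halfWindow = (int) (sleepMillis * jitterRatio / 2);
          System.out.printf(
              "attempt %d: sleep in [%d, %d) ms%n",
              attempt, sleepMillis - halfWindow, sleepMillis + halfWindow);
        }
      }
    }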

File: JpaTransactionManager.java

@@ -25,4 +25,11 @@ public interface JpaTransactionManager extends TransactionManager {
   /** Deletes the entity by its id, throws exception if the entity is not deleted. */
   public abstract <T> void assertDelete(VKey<T> key);
+
+  /**
+   * Releases all resources and shuts down.
+   *
+   * <p>The errorprone check forbids injection of {@link java.io.Closeable} resources.
+   */
+  void teardown();
 }
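
The pairing intended for setJpaTm() and the new teardown() is visible in SqlBatchWriter above. As a standalone sketch (hypothetical code; myJpaSupplier is a placeholder):

    // Mirrors SqlBatchWriter's setup()/teardown() sequence.
    JpaTransactionManager original = jpaTm();
    setJpaTm(myJpaSupplier); // swap in a purpose-built manager
    try {
      jpaTm().transact(() -> { /* SQL work */ });
    } finally {
      jpaTm().teardown();       // close the underlying EntityManagerFactory
      setJpaTm(() -> original); // restore the previous manager
    }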

File: JpaTransactionManagerImpl.java

@@ -62,6 +62,11 @@ public class JpaTransactionManagerImpl implements JpaTransactionManager {
     this.clock = clock;
   }

+  @Override
+  public void teardown() {
+    emf.close();
+  }
+
   @Override
   public EntityManager getEntityManager() {
     if (transactionInfo.get().entityManager == null) {

File: TransactionManagerFactory.java

@@ -75,7 +75,12 @@ public class TransactionManagerFactory {
     return tm;
   }

-  /** Returns {@link JpaTransactionManager} instance. */
+  /**
+   * Returns {@link JpaTransactionManager} instance.
+   *
+   * <p>Between invocations of {@link TransactionManagerFactory#setJpaTm} every call to this method
+   * returns the same instance.
+   */
   public static JpaTransactionManager jpaTm() {
     return jpaTm.get();
   }
@@ -93,7 +98,7 @@ public class TransactionManagerFactory {
         RegistryEnvironment.get().equals(RegistryEnvironment.UNITTEST)
             || RegistryToolEnvironment.get() != null,
         "setJpamTm() should only be called by tools and tests.");
-    jpaTm = jpaTmSupplier;
+    jpaTm = Suppliers.memoize(jpaTmSupplier::get);
   }

   /** Sets the return of {@link #tm()} to the given instance of {@link TransactionManager}. */
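
The point of the Suppliers.memoize wrapper above: the delegate supplier runs at most once, so every jpaTm() call between setJpaTm() invocations shares one lazily created JpaTransactionManager instead of minting a new one per call. A standalone demonstration of the semantics (hypothetical snippet, not part of the commit):

    import com.google.common.base.Suppliers;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.function.Supplier;

    public class MemoizeDemo {
      public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        Supplier<String> delegate = () -> "instance-" + calls.incrementAndGet();
        // Same adaptation as setJpaTm(): a java.util.function.Supplier bridged
        // into Guava's Supplier via a method reference.
        com.google.common.base.Supplier<String> memoized = Suppliers.memoize(delegate::get);
        System.out.println(memoized.get()); // instance-1
        System.out.println(memoized.get()); // instance-1 again; delegate ran only once
      }
    }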

File: ExportloadingTransformsTest.java

@@ -173,6 +173,7 @@ public class ExportloadingTransformsTest implements Serializable {
   }

   @Test
+  @Category(NeedsRunner.class)
   public void loadDataFromFiles() {
     PCollection<VersionedEntity> entities =
         pipeline

File: LoadDatastoreSnapshotTest.java

@@ -31,6 +31,7 @@ import google.registry.model.registry.Registry;
 import google.registry.testing.FakeClock;
 import google.registry.testing.InjectRule;
 import java.io.File;
+import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionTuple;
@@ -38,6 +39,7 @@ import org.joda.time.DateTime;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.experimental.categories.Category;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -152,6 +154,7 @@ public class LoadDatastoreSnapshotTest {
   }

   @Test
+  @Category(NeedsRunner.class)
   public void loadDatastoreSnapshot() {
     PCollectionTuple snapshot =
         pipeline.apply(

File: WriteToSqlTest.java (new file)

@@ -0,0 +1,125 @@
+// Copyright 2020 The Nomulus Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package google.registry.beam.initsql;
+
+import static com.google.common.truth.Truth.assertThat;
+import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
+
+import com.google.appengine.api.datastore.Entity;
+import com.google.common.collect.ImmutableList;
+import google.registry.backup.VersionedEntity;
+import google.registry.model.contact.ContactResource;
+import google.registry.model.ofy.Ofy;
+import google.registry.model.registrar.Registrar;
+import google.registry.persistence.transaction.JpaTestRules;
+import google.registry.persistence.transaction.JpaTestRules.JpaIntegrationTestRule;
+import google.registry.testing.AppEngineRule;
+import google.registry.testing.DatastoreHelper;
+import google.registry.testing.FakeClock;
+import google.registry.testing.InjectRule;
+import java.io.File;
+import java.io.PrintStream;
+import java.io.Serializable;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.joda.time.DateTime;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit test for {@link Transforms#writeToSql}. */
+@RunWith(JUnit4.class)
+public class WriteToSqlTest implements Serializable {
+
+  private static final DateTime START_TIME = DateTime.parse("2000-01-01T00:00:00.0Z");
+
+  private final FakeClock fakeClock = new FakeClock(START_TIME);
+
+  @Rule public final transient InjectRule injectRule = new InjectRule();
+
+  @Rule
+  public transient JpaIntegrationTestRule jpaRule =
+      new JpaTestRules.Builder().withClock(fakeClock).buildIntegrationTestRule();
+
+  @Rule public transient TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+  @Rule
+  public final transient TestPipeline pipeline =
+      TestPipeline.create().enableAbandonedNodeEnforcement(true);
+
+  private ImmutableList<Entity> contacts;
+  private File credentialFile;
+
+  @Before
+  public void beforeEach() throws Exception {
+    try (BackupTestStore store = new BackupTestStore(fakeClock)) {
+      injectRule.setStaticField(Ofy.class, "clock", fakeClock);
+
+      // Required for contacts created below.
+      Registrar ofyRegistrar = AppEngineRule.makeRegistrar2();
+      store.insertOrUpdate(ofyRegistrar);
+      jpaTm().transact(() -> jpaTm().saveNewOrUpdate(store.loadAsOfyEntity(ofyRegistrar)));
+
+      ImmutableList.Builder<Entity> builder = new ImmutableList.Builder<>();
+      for (int i = 0; i < 3; i++) {
+        ContactResource contact = DatastoreHelper.newContactResource("contact_" + i);
+        store.insertOrUpdate(contact);
+        builder.add(store.loadAsDatastoreEntity(contact));
+      }
+      contacts = builder.build();
+    }
+    credentialFile = temporaryFolder.newFile();
+    new PrintStream(credentialFile)
+        .printf(
+            "%s %s %s",
+            jpaRule.getDatabaseUrl(), jpaRule.getDatabaseUsername(), jpaRule.getDatabasePassword())
+        .close();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void writeToSql_twoWriters() {
+    pipeline
+        .apply(
+            Create.of(
+                contacts.stream()
+                    .map(InitSqlTestUtils::entityToBytes)
+                    .map(bytes -> VersionedEntity.from(0L, bytes))
+                    .collect(Collectors.toList())))
+        .apply(
+            Transforms.writeToSql(
+                "ContactResource",
+                2,
+                4,
+                () ->
+                    DaggerBeamJpaModule_JpaTransactionManagerComponent.builder()
+                        .beamJpaModule(new BeamJpaModule(credentialFile.getAbsolutePath()))
+                        .build()
+                        .localDbJpaTransactionManager()));
+    pipeline.run().waitUntilFinish();
+
+    ImmutableList<?> sqlContacts = jpaTm().transact(() -> jpaTm().loadAll(ContactResource.class));
+    // TODO(weiminyu): compare loaded entities with originals. Note: lastUpdateTimes won't match by
+    // design. Need an elegant way to deal with this.
+    assertThat(sqlContacts).hasSize(3);
+  }
+}

File: JpaTransactionManagerRule.java

@@ -209,6 +209,18 @@ abstract class JpaTransactionManagerRule extends ExternalResource {
     cachedTm = null;
   }

+  public String getDatabaseUrl() {
+    return database.getJdbcUrl();
+  }
+
+  public String getDatabaseUsername() {
+    return database.getUsername();
+  }
+
+  public String getDatabasePassword() {
+    return database.getPassword();
+  }
+
   private void resetTablesAndSequences() {
     try (Connection conn = createConnection();
         Statement statement = conn.createStatement()) {