From 2c6ee6dae90769a836748a4bffee454a5df54428 Mon Sep 17 00:00:00 2001 From: gbrodman Date: Wed, 9 Dec 2020 15:52:56 -0500 Subject: [PATCH] Parameterize the serialization of objects being written to SQL (#892) * Parameterize the serialization of objects being written to SQL We shouldn't require that objects written to SQL during a Beam pipeline be VersionedEntity objects -- they may be non-Objectify entities. As a result, we should allow the user to specify what the objects are that should be written to SQL. Note: we will need to clean up the Spec11PipelineTest more but that can be out of the scope of this PR. * Overload the method and add a bit of javadoc * Actually use the overloaded function --- .../registry/beam/initsql/Transforms.java | 67 +++++++++++++++---- .../registry/beam/spec11/Spec11Pipeline.java | 46 ++++++------- .../beam/spec11/Spec11PipelineTest.java | 4 +- 3 files changed, 74 insertions(+), 43 deletions(-) diff --git a/core/src/main/java/google/registry/beam/initsql/Transforms.java b/core/src/main/java/google/registry/beam/initsql/Transforms.java index 78e61cf55..0a98d9abb 100644 --- a/core/src/main/java/google/registry/beam/initsql/Transforms.java +++ b/core/src/main/java/google/registry/beam/initsql/Transforms.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.throwIfUnchecked; import static google.registry.beam.initsql.BackupPaths.getCommitLogTimestamp; import static google.registry.beam.initsql.BackupPaths.getExportFilePatterns; +import static google.registry.model.ofy.ObjectifyService.ofy; import static google.registry.persistence.JpaRetries.isFailedTxnRetriable; import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; import static google.registry.persistence.transaction.TransactionManagerFactory.setJpaTm; @@ -68,6 +69,7 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import 
org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ProcessFunction; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -264,9 +266,9 @@ public final class Transforms { } /** - * Returns a {@link PTransform} that writes a {@link PCollection} of entities to a SQL database. - * and outputs an empty {@code PCollection}. This allows other operations to {@link - * org.apache.beam.sdk.transforms.Wait wait} for the completion of this transform. + * Returns a {@link PTransform} that writes a {@link PCollection} of {@link VersionedEntity}s to a + * SQL database and outputs an empty {@code PCollection}. This allows other operations to + * {@link org.apache.beam.sdk.transforms.Wait wait} for the completion of this transform. * *

Errors are handled according to the pipeline runner's default policy. As part of a one-time * job, we will not add features unless proven necessary. @@ -282,16 +284,53 @@ public final class Transforms { int maxWriters, int batchSize, SerializableSupplier jpaSupplier) { - return new PTransform, PCollection>() { + return writeToSql( + transformId, + maxWriters, + batchSize, + jpaSupplier, + (e) -> ofy().toPojo(e.getEntity().get()), + TypeDescriptor.of(VersionedEntity.class)); + } + + /** + * Returns a {@link PTransform} that writes a {@link PCollection} of entities to a SQL database + * and outputs an empty {@code PCollection}. This allows other operations to {@link + * org.apache.beam.sdk.transforms.Wait wait} for the completion of this transform. + * + *

The converter and type descriptor are generic so that we can convert any type of entity to + * an object to be placed in SQL. + * + *

Errors are handled according to the pipeline runner's default policy. As part of a one-time + * job, we will not add features unless proven necessary. + * + * @param transformId a unique ID for an instance of the returned transform + * @param maxWriters the max number of concurrent writes to SQL, which also determines the max + * number of connection pools created + * @param batchSize the number of entities to write in each operation + * @param jpaSupplier supplier of a {@link JpaTransactionManager} + * @param jpaConverter the function that converts the input object to a JPA entity + * @param objectDescriptor the type descriptor of the input object + */ + public static PTransform, PCollection> writeToSql( + String transformId, + int maxWriters, + int batchSize, + SerializableSupplier jpaSupplier, + SerializableFunction jpaConverter, + TypeDescriptor objectDescriptor) { + return new PTransform, PCollection>() { @Override - public PCollection expand(PCollection input) { + public PCollection expand(PCollection input) { return input .apply( "Shard data for " + transformId, - MapElements.into(kvs(integers(), TypeDescriptor.of(VersionedEntity.class))) + MapElements.into(kvs(integers(), objectDescriptor)) .via(ve -> KV.of(ThreadLocalRandom.current().nextInt(maxWriters), ve))) .apply("Batch output by shard " + transformId, GroupIntoBatches.ofSize(batchSize)) - .apply("Write in batch for " + transformId, ParDo.of(new SqlBatchWriter(jpaSupplier))); + .apply( + "Write in batch for " + transformId, + ParDo.of(new SqlBatchWriter(jpaSupplier, jpaConverter))); } }; } @@ -385,18 +424,22 @@ public final class Transforms { * to hold the {@code JpaTransactionManager} instance, we must ensure that JpaTransactionManager * is not changed or torn down while being used by some instance. 
*/ - private static class SqlBatchWriter extends DoFn>, Void> { + private static class SqlBatchWriter extends DoFn>, Void> { private static int instanceCount = 0; private static JpaTransactionManager originalJpa; private final SerializableSupplier jpaSupplier; + private final SerializableFunction jpaConverter; private transient Ofy ofy; private transient SystemSleeper sleeper; - SqlBatchWriter(SerializableSupplier jpaSupplier) { + SqlBatchWriter( + SerializableSupplier jpaSupplier, + SerializableFunction jpaConverter) { this.jpaSupplier = jpaSupplier; + this.jpaConverter = jpaConverter; } @Setup @@ -429,13 +472,11 @@ public final class Transforms { } @ProcessElement - public void processElement(@Element KV> kv) { + public void processElement(@Element KV> kv) { try (AppEngineEnvironment env = new AppEngineEnvironment()) { ImmutableList ofyEntities = Streams.stream(kv.getValue()) - .map(VersionedEntity::getEntity) - .map(Optional::get) - .map(ofy::toPojo) + .map(this.jpaConverter::apply) .collect(ImmutableList.toImmutableList()); retry(() -> jpaTm().transact(() -> jpaTm().putAll(ofyEntities))); } diff --git a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java index 7f06d82e8..1b061ad24 100644 --- a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java +++ b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java @@ -20,7 +20,7 @@ import static google.registry.beam.BeamUtils.getQueryFromFile; import com.google.auth.oauth2.GoogleCredentials; import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableSet; -import google.registry.backup.AppEngineEnvironment; +import google.registry.beam.initsql.Transforms; import google.registry.beam.initsql.Transforms.SerializableSupplier; import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; import google.registry.config.CredentialModule.LocalCredential; @@ -43,7 +43,6 @@ import 
org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; -import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; @@ -191,34 +190,27 @@ public class Spec11Pipeline implements Serializable { PCollection domains, EvaluateSafeBrowsingFn evaluateSafeBrowsingFn, ValueProvider dateProvider) { - PCollection> subdomainsSql = domains.apply("Run through SafeBrowsing API", ParDo.of(evaluateSafeBrowsingFn)); - /* Store ThreatMatch objects in SQL. */ + TypeDescriptor> descriptor = + new TypeDescriptor>() {}; subdomainsSql.apply( - ParDo.of( - new DoFn, Void>() { - @ProcessElement - public void processElement(ProcessContext context) { - // create the Spec11ThreatMatch from Subdomain and ThreatMatch - try (AppEngineEnvironment env = new AppEngineEnvironment()) { - Subdomain subdomain = context.element().getKey(); - Spec11ThreatMatch threatMatch = - new Spec11ThreatMatch.Builder() - .setThreatTypes( - ImmutableSet.of( - ThreatType.valueOf(context.element().getValue().threatType()))) - .setCheckDate( - LocalDate.parse(dateProvider.get(), ISODateTimeFormat.date())) - .setDomainName(subdomain.domainName()) - .setDomainRepoId(subdomain.domainRepoId()) - .setRegistrarId(subdomain.registrarId()) - .build(); - JpaTransactionManager jpaTransactionManager = jpaSupplierFactory.get(); - jpaTransactionManager.transact(() -> jpaTransactionManager.insert(threatMatch)); - } - } - })); + Transforms.writeToSql( + "Spec11ThreatMatch", + 4, + 4, + jpaSupplierFactory, + (kv) -> { + Subdomain subdomain = kv.getKey(); + return new Spec11ThreatMatch.Builder() + .setThreatTypes(ImmutableSet.of(ThreatType.valueOf(kv.getValue().threatType()))) + .setCheckDate(LocalDate.parse(dateProvider.get(), 
ISODateTimeFormat.date())) + .setDomainName(subdomain.domainName()) + .setDomainRepoId(subdomain.domainRepoId()) + .setRegistrarId(subdomain.registrarId()) + .build(); + }, + descriptor)); /* Store ThreatMatch objects in JSON. */ PCollection> subdomainsJson = diff --git a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java index 068581a63..20db45935 100644 --- a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java +++ b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java @@ -20,7 +20,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.mockito.Mockito.withSettings; @@ -285,8 +284,7 @@ class Spec11PipelineTest { .build(); verify(mockJpaTm).transact(any(Runnable.class)); - verify(mockJpaTm).insert(expected); - verifyNoMoreInteractions(mockJpaTm); + verify(mockJpaTm).putAll(ImmutableList.of(expected)); } /**