Parameterize the serialization of objects being written to SQL (#892)

* Parameterize the serialization of objects being written to SQL

We shouldn't require that objects written to SQL during a Beam pipeline
be VersionedEntity objects -- they may be non-Objectify entities.
Instead, we should let the caller specify which objects are written to
SQL and how they are converted to JPA entities (a usage sketch follows
this change list).

Note: Spec11PipelineTest still needs further cleanup, but that is out
of scope for this PR.

* Overload the method and add a bit of javadoc

* Actually use the overloaded function
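
For orientation, here is a rough sketch of a call to the new generic overload, modeled
directly on the Spec11Pipeline change in this commit. The helper method name, the
checkDate parameter, and the step name are illustrative only; writeToSql, its arguments,
and the entity types (Subdomain, ThreatMatch, Spec11ThreatMatch) are the ones that appear
in the diff below, and imports mirror those in Spec11Pipeline.java.

// Illustrative helper (not part of this commit) showing the new generic
// writeToSql overload used in place of a hand-rolled DoFn.
private static PCollection<Void> writeThreatMatchesToSql(
    PCollection<KV<Subdomain, ThreatMatch>> matches,
    SerializableSupplier<JpaTransactionManager> jpaSupplier,
    String checkDate) {
  return matches.apply(
      "Write Spec11ThreatMatch to SQL",
      Transforms.writeToSql(
          "Spec11ThreatMatch", // transformId, used to name the Beam steps
          4, // maxWriters: max concurrent SQL writers (and connection pools)
          4, // batchSize: entities written per putAll() transaction
          jpaSupplier, // supplies a JpaTransactionManager on each worker
          // jpaConverter: turns each pipeline element into a JPA entity
          kv ->
              new Spec11ThreatMatch.Builder()
                  .setThreatTypes(
                      ImmutableSet.of(ThreatType.valueOf(kv.getValue().threatType())))
                  .setCheckDate(LocalDate.parse(checkDate, ISODateTimeFormat.date()))
                  .setDomainName(kv.getKey().domainName())
                  .setDomainRepoId(kv.getKey().domainRepoId())
                  .setRegistrarId(kv.getKey().registrarId())
                  .build(),
          // objectDescriptor: describes the element type for the sharding MapElements step
          new TypeDescriptor<KV<Subdomain, ThreatMatch>>() {}));
}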
gbrodman, 2020-12-09 15:52:56 -05:00 (committed by GitHub)
parent a181d6a720
commit 2c6ee6dae9
3 changed files with 74 additions and 43 deletions

google/registry/beam/initsql/Transforms.java

@@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static google.registry.beam.initsql.BackupPaths.getCommitLogTimestamp;
import static google.registry.beam.initsql.BackupPaths.getExportFilePatterns;
import static google.registry.model.ofy.ObjectifyService.ofy;
import static google.registry.persistence.JpaRetries.isFailedTxnRetriable;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.setJpaTm;
@@ -68,6 +69,7 @@ import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ProcessFunction;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
@@ -264,9 +266,9 @@ public final class Transforms {
}
/**
* Returns a {@link PTransform} that writes a {@link PCollection} of entities to a SQL database.
* and outputs an empty {@code PCollection<Void>}. This allows other operations to {@link
* org.apache.beam.sdk.transforms.Wait wait} for the completion of this transform.
* Returns a {@link PTransform} that writes a {@link PCollection} of {@link VersionedEntity}s to a
* SQL database and outputs an empty {@code PCollection<Void>}. This allows other operations to
* {@link org.apache.beam.sdk.transforms.Wait wait} for the completion of this transform.
*
* <p>Errors are handled according to the pipeline runner's default policy. As part of a one-time
* job, we will not add features unless proven necessary.
@@ -282,16 +284,53 @@
int maxWriters,
int batchSize,
SerializableSupplier<JpaTransactionManager> jpaSupplier) {
return new PTransform<PCollection<VersionedEntity>, PCollection<Void>>() {
return writeToSql(
transformId,
maxWriters,
batchSize,
jpaSupplier,
(e) -> ofy().toPojo(e.getEntity().get()),
TypeDescriptor.of(VersionedEntity.class));
}
/**
* Returns a {@link PTransform} that writes a {@link PCollection} of entities to a SQL database
* and outputs an empty {@code PCollection<Void>}. This allows other operations to {@link
* org.apache.beam.sdk.transforms.Wait wait} for the completion of this transform.
*
* <p>The converter and type descriptor are generic so that we can convert any type of entity to
* an object to be placed in SQL.
*
* <p>Errors are handled according to the pipeline runner's default policy. As part of a one-time
* job, we will not add features unless proven necessary.
*
* @param transformId a unique ID for an instance of the returned transform
* @param maxWriters the max number of concurrent writes to SQL, which also determines the max
* number of connection pools created
* @param batchSize the number of entities to write in each operation
* @param jpaSupplier supplier of a {@link JpaTransactionManager}
* @param jpaConverter the function that converts the input object to a JPA entity
* @param objectDescriptor the type descriptor of the input object
*/
public static <T> PTransform<PCollection<T>, PCollection<Void>> writeToSql(
String transformId,
int maxWriters,
int batchSize,
SerializableSupplier<JpaTransactionManager> jpaSupplier,
SerializableFunction<T, Object> jpaConverter,
TypeDescriptor<T> objectDescriptor) {
return new PTransform<PCollection<T>, PCollection<Void>>() {
@Override
public PCollection<Void> expand(PCollection<VersionedEntity> input) {
public PCollection<Void> expand(PCollection<T> input) {
return input
.apply(
"Shard data for " + transformId,
MapElements.into(kvs(integers(), TypeDescriptor.of(VersionedEntity.class)))
MapElements.into(kvs(integers(), objectDescriptor))
.via(ve -> KV.of(ThreadLocalRandom.current().nextInt(maxWriters), ve)))
.apply("Batch output by shard " + transformId, GroupIntoBatches.ofSize(batchSize))
.apply("Write in batch for " + transformId, ParDo.of(new SqlBatchWriter(jpaSupplier)));
.apply(
"Write in batch for " + transformId,
ParDo.of(new SqlBatchWriter<T>(jpaSupplier, jpaConverter)));
}
};
}
@@ -385,18 +424,22 @@
* to hold the {@code JpaTransactionManager} instance, we must ensure that JpaTransactionManager
* is not changed or torn down while being used by some instance.
*/
private static class SqlBatchWriter extends DoFn<KV<Integer, Iterable<VersionedEntity>>, Void> {
private static class SqlBatchWriter<T> extends DoFn<KV<Integer, Iterable<T>>, Void> {
private static int instanceCount = 0;
private static JpaTransactionManager originalJpa;
private final SerializableSupplier<JpaTransactionManager> jpaSupplier;
private final SerializableFunction<T, Object> jpaConverter;
private transient Ofy ofy;
private transient SystemSleeper sleeper;
SqlBatchWriter(SerializableSupplier<JpaTransactionManager> jpaSupplier) {
SqlBatchWriter(
SerializableSupplier<JpaTransactionManager> jpaSupplier,
SerializableFunction<T, Object> jpaConverter) {
this.jpaSupplier = jpaSupplier;
this.jpaConverter = jpaConverter;
}
@Setup
@@ -429,13 +472,11 @@
}
@ProcessElement
public void processElement(@Element KV<Integer, Iterable<VersionedEntity>> kv) {
public void processElement(@Element KV<Integer, Iterable<T>> kv) {
try (AppEngineEnvironment env = new AppEngineEnvironment()) {
ImmutableList<Object> ofyEntities =
Streams.stream(kv.getValue())
.map(VersionedEntity::getEntity)
.map(Optional::get)
.map(ofy::toPojo)
.map(this.jpaConverter::apply)
.collect(ImmutableList.toImmutableList());
retry(() -> jpaTm().transact(() -> jpaTm().putAll(ofyEntities)));
}

google/registry/beam/spec11/Spec11Pipeline.java

@@ -20,7 +20,7 @@ import static google.registry.beam.BeamUtils.getQueryFromFile;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableSet;
import google.registry.backup.AppEngineEnvironment;
import google.registry.beam.initsql.Transforms;
import google.registry.beam.initsql.Transforms.SerializableSupplier;
import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
import google.registry.config.CredentialModule.LocalCredential;
@@ -43,7 +43,6 @@ import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo;
@@ -191,34 +190,27 @@ public class Spec11Pipeline implements Serializable {
PCollection<Subdomain> domains,
EvaluateSafeBrowsingFn evaluateSafeBrowsingFn,
ValueProvider<String> dateProvider) {
PCollection<KV<Subdomain, ThreatMatch>> subdomainsSql =
domains.apply("Run through SafeBrowsing API", ParDo.of(evaluateSafeBrowsingFn));
/* Store ThreatMatch objects in SQL. */
TypeDescriptor<KV<Subdomain, ThreatMatch>> descriptor =
new TypeDescriptor<KV<Subdomain, ThreatMatch>>() {};
subdomainsSql.apply(
ParDo.of(
new DoFn<KV<Subdomain, ThreatMatch>, Void>() {
@ProcessElement
public void processElement(ProcessContext context) {
// create the Spec11ThreatMatch from Subdomain and ThreatMatch
try (AppEngineEnvironment env = new AppEngineEnvironment()) {
Subdomain subdomain = context.element().getKey();
Spec11ThreatMatch threatMatch =
new Spec11ThreatMatch.Builder()
.setThreatTypes(
ImmutableSet.of(
ThreatType.valueOf(context.element().getValue().threatType())))
.setCheckDate(
LocalDate.parse(dateProvider.get(), ISODateTimeFormat.date()))
.setDomainName(subdomain.domainName())
.setDomainRepoId(subdomain.domainRepoId())
.setRegistrarId(subdomain.registrarId())
.build();
JpaTransactionManager jpaTransactionManager = jpaSupplierFactory.get();
jpaTransactionManager.transact(() -> jpaTransactionManager.insert(threatMatch));
}
}
}));
Transforms.writeToSql(
"Spec11ThreatMatch",
4,
4,
jpaSupplierFactory,
(kv) -> {
Subdomain subdomain = kv.getKey();
return new Spec11ThreatMatch.Builder()
.setThreatTypes(ImmutableSet.of(ThreatType.valueOf(kv.getValue().threatType())))
.setCheckDate(LocalDate.parse(dateProvider.get(), ISODateTimeFormat.date()))
.setDomainName(subdomain.domainName())
.setDomainRepoId(subdomain.domainRepoId())
.setRegistrarId(subdomain.registrarId())
.build();
},
descriptor));
/* Store ThreatMatch objects in JSON. */
PCollection<KV<Subdomain, ThreatMatch>> subdomainsJson =

google/registry/beam/spec11/Spec11PipelineTest.java

@@ -20,7 +20,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.withSettings;
@@ -285,8 +284,7 @@ class Spec11PipelineTest {
.build();
verify(mockJpaTm).transact(any(Runnable.class));
verify(mockJpaTm).insert(expected);
verifyNoMoreInteractions(mockJpaTm);
verify(mockJpaTm).putAll(ImmutableList.of(expected));
}
/**
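
For clarity, since the final hunk above shows the removed and added assertions side by
side: the removed lines verified a single insert() plus verifyNoMoreInteractions, while
the added line verifies a batched putAll(). A minimal sketch of the resulting
verification, with mockJpaTm and expected as defined in the surrounding test:

// Post-change expectation: SqlBatchWriter writes the batch via putAll()
// inside a single transact() call.
verify(mockJpaTm).transact(any(Runnable.class));
verify(mockJpaTm).putAll(ImmutableList.of(expected));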