diff --git a/core/src/main/java/google/registry/beam/resave/ResaveAllEppResourcesPipeline.java b/core/src/main/java/google/registry/beam/resave/ResaveAllEppResourcesPipeline.java index 54529a4b7..d67fc64a3 100644 --- a/core/src/main/java/google/registry/beam/resave/ResaveAllEppResourcesPipeline.java +++ b/core/src/main/java/google/registry/beam/resave/ResaveAllEppResourcesPipeline.java @@ -14,11 +14,14 @@ package google.registry.beam.resave; +import static com.google.common.collect.ImmutableList.toImmutableList; import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; import static org.apache.beam.sdk.values.TypeDescriptors.integers; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; import google.registry.beam.common.RegistryJpaIO; import google.registry.beam.common.RegistryJpaIO.Read; import google.registry.model.EppResource; @@ -27,7 +30,7 @@ import google.registry.model.domain.Domain; import google.registry.model.domain.DomainBase; import google.registry.model.host.Host; import google.registry.persistence.PersistenceModule.TransactionIsolationLevel; -import google.registry.persistence.transaction.CriteriaQueryBuilder; +import google.registry.persistence.VKey; import google.registry.util.DateTimeUtils; import java.io.Serializable; import java.util.concurrent.ThreadLocalRandom; @@ -69,7 +72,7 @@ public class ResaveAllEppResourcesPipeline implements Serializable { * multiple times, and to avoid projecting and resaving the same domain multiple times. */ private static final String DOMAINS_TO_PROJECT_QUERY = - "FROM Domain d WHERE (d.transferData.transferStatus = 'PENDING' AND" + "SELECT repoId FROM Domain d WHERE (d.transferData.transferStatus = 'PENDING' AND" + " d.transferData.pendingTransferExpirationTime < current_timestamp()) OR" + " (d.registrationExpirationTime < current_timestamp() AND d.deletionTime =" + " (:END_OF_TIME)) OR (EXISTS (SELECT 1 FROM GracePeriod gp WHERE gp.domainRepoId =" @@ -99,13 +102,13 @@ public class ResaveAllEppResourcesPipeline implements Serializable { /** Projects to the current time and saves any contacts with expired transfers. */ private void fastResaveContacts(Pipeline pipeline) { - Read read = + Read repoIdRead = RegistryJpaIO.read( - "FROM Contact WHERE transferData.transferStatus = 'PENDING' AND" + "SELECT repoId FROM Contact WHERE transferData.transferStatus = 'PENDING' AND" + " transferData.pendingTransferExpirationTime < current_timestamp()", - Contact.class, - c -> c); - projectAndResaveResources(pipeline, Contact.class, read); + String.class, + r -> r); + projectAndResaveResources(pipeline, Contact.class, repoIdRead); } /** @@ -116,61 +119,72 @@ public class ResaveAllEppResourcesPipeline implements Serializable { * DomainBase#cloneProjectedAtTime(DateTime)}. */ private void fastResaveDomains(Pipeline pipeline) { - Read read = + Read repoIdRead = RegistryJpaIO.read( DOMAINS_TO_PROJECT_QUERY, ImmutableMap.of("END_OF_TIME", DateTimeUtils.END_OF_TIME), - Domain.class, - d -> d); - projectAndResaveResources(pipeline, Domain.class, read); + String.class, + r -> r); + projectAndResaveResources(pipeline, Domain.class, repoIdRead); } /** Projects all resources to the current time and saves them. */ private void forceResaveAllResources(Pipeline pipeline, Class clazz) { - Read read = RegistryJpaIO.read(() -> CriteriaQueryBuilder.create(clazz).build()); - projectAndResaveResources(pipeline, clazz, read); + Read repoIdRead = + RegistryJpaIO.read( + // Note: cannot use SQL parameters for the table name + String.format("SELECT repoId FROM %s", clazz.getSimpleName()), String.class, r -> r); + projectAndResaveResources(pipeline, clazz, repoIdRead); } - /** Projects and re-saves the result of the provided {@link Read}. */ + /** Projects and re-saves all resources with repo IDs provided by the {@link Read}. */ private void projectAndResaveResources( - Pipeline pipeline, Class clazz, Read read) { + Pipeline pipeline, Class clazz, Read repoIdRead) { int numShards = options.getSqlWriteShards(); int batchSize = options.getSqlWriteBatchSize(); String className = clazz.getSimpleName(); pipeline - .apply("Read " + className, read) + .apply("Read " + className, repoIdRead) .apply( "Shard data for class" + className, - WithKeys.of(e -> ThreadLocalRandom.current().nextInt(numShards)) + WithKeys.of(e -> ThreadLocalRandom.current().nextInt(numShards)) .withKeyType(integers())) .apply( "Group into batches for class" + className, - GroupIntoBatches.ofSize(batchSize).withShardedKey()) - .apply("Map " + className + " to now", ParDo.of(new BatchedProjectionFunction<>())) + GroupIntoBatches.ofSize(batchSize).withShardedKey()) .apply( - "Write transformed " + className, - RegistryJpaIO.write() - .withName("Write transformed " + className) - .withBatchSize(batchSize) - .withShards(numShards)); + "Load, map, and save " + className, + ParDo.of(new BatchedLoadProjectAndSaveFunction(clazz))); } - private static class BatchedProjectionFunction - extends DoFn, Iterable>, EppResource> { + /** Function that loads, projects, and saves resources all in the same transaction. */ + private static class BatchedLoadProjectAndSaveFunction + extends DoFn, Iterable>, Void> { + + private final Class clazz; + + private BatchedLoadProjectAndSaveFunction(Class clazz) { + this.clazz = clazz; + } @ProcessElement public void processElement( - @Element KV, Iterable> element, - OutputReceiver outputReceiver) { + @Element KV, Iterable> element, + OutputReceiver outputReceiver) { jpaTm() .transact( - () -> - element - .getValue() - .forEach( - resource -> - outputReceiver.output( - resource.cloneProjectedAtTime(jpaTm().getTransactionTime())))); + () -> { + DateTime now = jpaTm().getTransactionTime(); + ImmutableList> keys = + Streams.stream(element.getValue()) + .map(repoId -> VKey.create(clazz, repoId)) + .collect(toImmutableList()); + ImmutableList mappedResources = + jpaTm().loadByKeys(keys).values().stream() + .map(r -> r.cloneProjectedAtTime(now)) + .collect(toImmutableList()); + jpaTm().putAll(mappedResources); + }); } }