From 74c0836fe626825a578f14135be0292d550717f8 Mon Sep 17 00:00:00 2001 From: Ben McIlwain Date: Thu, 19 May 2022 09:13:37 -0400 Subject: [PATCH] Add batching to ExpandRecurringBillingEventsAction (#1636) * Add batching to ExpandRecurringBillingEventsAction It's OOMing on trying to load every single BillingRecurrence that needs to be expanded simultaneously (which is to be expected). So this processes them in transactional batches of 50. --- .../ExpandRecurringBillingEventsAction.java | 170 ++++++++++++------ .../registry/config/RegistryConfig.java | 8 +- .../config/RegistryConfigSettings.java | 2 +- .../persistence/PersistenceModule.java | 2 +- ...xpandRecurringBillingEventsActionTest.java | 94 +++++++--- 5 files changed, 200 insertions(+), 76 deletions(-) diff --git a/core/src/main/java/google/registry/batch/ExpandRecurringBillingEventsAction.java b/core/src/main/java/google/registry/batch/ExpandRecurringBillingEventsAction.java index 26db4b414..863b57231 100644 --- a/core/src/main/java/google/registry/batch/ExpandRecurringBillingEventsAction.java +++ b/core/src/main/java/google/registry/batch/ExpandRecurringBillingEventsAction.java @@ -17,6 +17,7 @@ package google.registry.batch; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Sets.difference; +import static com.google.common.collect.Sets.newHashSet; import static google.registry.mapreduce.MapreduceRunner.PARAM_DRY_RUN; import static google.registry.mapreduce.inputs.EppResourceInputs.createChildEntityInput; import static google.registry.model.common.Cursor.CursorType.RECURRING_BILLING; @@ -36,11 +37,13 @@ import static google.registry.util.DomainNameUtils.getTldFromDomainName; import com.google.appengine.tools.mapreduce.Mapper; import com.google.appengine.tools.mapreduce.Reducer; import com.google.appengine.tools.mapreduce.ReducerInput; +import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Range; import com.google.common.collect.Streams; import com.google.common.flogger.FluentLogger; +import google.registry.config.RegistryConfig.Config; import google.registry.mapreduce.MapreduceRunner; import google.registry.mapreduce.inputs.NullInput; import google.registry.model.ImmutableObject; @@ -61,6 +64,7 @@ import google.registry.request.Parameter; import google.registry.request.Response; import google.registry.request.auth.Auth; import google.registry.util.Clock; +import java.util.List; import java.util.Optional; import java.util.Set; import javax.inject.Inject; @@ -86,6 +90,11 @@ public class ExpandRecurringBillingEventsAction implements Runnable { @Inject Clock clock; @Inject MapreduceRunner mrRunner; + + @Inject + @Config("jdbcBatchSize") + int batchSize; + @Inject @Parameter(PARAM_DRY_RUN) boolean isDryRun; @Inject @Parameter(PARAM_CURSOR_TIME) Optional cursorTimeParam; @Inject Response response; @@ -120,61 +129,116 @@ public class ExpandRecurringBillingEventsAction implements Runnable { ImmutableSet.of(DomainBase.class), ImmutableSet.of(Recurring.class)))) .sendLinkToMapreduceConsole(response); } else { - int numBillingEventsSaved = - jpaTm() - .transact( - () -> - jpaTm() - .query( - "FROM BillingRecurrence " - + "WHERE eventTime <= :executeTime " - + "AND eventTime < recurrenceEndTime " - + "ORDER BY id ASC", - Recurring.class) - .setParameter("executeTime", executeTime) - // Need to get a list from the transaction and then convert it to a stream - // for further processing. If we get a stream directly, each elements gets - // processed downstream eagerly but Hibernate returns a - // ScrollableResultsIterator that cannot be advanced outside the - // transaction, resulting in an exception. - .getResultList()) - .stream() - .map( - recurring -> - jpaTm() - .transact( - () -> - expandBillingEvent(recurring, executeTime, cursorTime, isDryRun))) - .reduce(0, Integer::sum); - - if (!isDryRun) { - logger.atInfo().log("Saved OneTime billing events.", numBillingEventsSaved); - } else { - logger.atInfo().log("Generated OneTime billing events (dry run).", numBillingEventsSaved); - } - logger.atInfo().log( - "Recurring event expansion %s complete for billing event range [%s, %s).", - isDryRun ? "(dry run) " : "", cursorTime, executeTime); - tm().transact( - () -> { - // Check for the unlikely scenario where the cursor has been altered during the - // expansion. - DateTime currentCursorTime = - tm().loadByKeyIfPresent(Cursor.createGlobalVKey(RECURRING_BILLING)) - .orElse(Cursor.createGlobal(RECURRING_BILLING, START_OF_TIME)) - .getCursorTime(); - if (!currentCursorTime.equals(persistedCursorTime)) { - throw new IllegalStateException( - String.format( - "Current cursor position %s does not match persisted cursor position %s.", - currentCursorTime, persistedCursorTime)); - } - if (!isDryRun) { - tm().put(Cursor.createGlobal(RECURRING_BILLING, executeTime)); - } - }); + expandSqlBillingEventsInBatches(executeTime, cursorTime, persistedCursorTime); } } + + private void expandSqlBillingEventsInBatches( + DateTime executeTime, DateTime cursorTime, DateTime persistedCursorTime) { + int totalBillingEventsSaved = 0; + long maxProcessedRecurrenceId = 0; + SqlBatchResults sqlBatchResults; + + do { + final long prevMaxProcessedRecurrenceId = maxProcessedRecurrenceId; + sqlBatchResults = + jpaTm() + .transact( + () -> { + Set expandedDomains = newHashSet(); + int batchBillingEventsSaved = 0; + long maxRecurrenceId = prevMaxProcessedRecurrenceId; + List recurrings = + jpaTm() + .query( + "FROM BillingRecurrence " + + "WHERE eventTime <= :executeTime " + + "AND eventTime < recurrenceEndTime " + + "AND id > :maxProcessedRecurrenceId " + + "ORDER BY id ASC", + Recurring.class) + .setParameter("executeTime", executeTime) + .setParameter("maxProcessedRecurrenceId", prevMaxProcessedRecurrenceId) + .setMaxResults(batchSize) + .getResultList(); + for (Recurring recurring : recurrings) { + if (expandedDomains.contains(recurring.getTargetId())) { + // On the off chance this batch contains multiple recurrences for the same + // domain (which is actually possible if a given domain is quickly renewed + // multiple times in a row), then short-circuit after the first one is + // processed that involves actually expanding a billing event. This is + // necessary because otherwise we get an "Inserted/updated object reloaded" + // error from Hibernate when those billing events would be loaded + // inside a transaction where they were already written. Note, there is no + // actual further work to be done in this case anyway, not unless it has + // somehow been over a year since this action last ran successfully (and if + // that were somehow true, the remaining billing events would still be + // expanded on subsequent runs). + continue; + } + int billingEventsSaved = + expandBillingEvent(recurring, executeTime, cursorTime, isDryRun); + batchBillingEventsSaved += billingEventsSaved; + if (billingEventsSaved > 0) { + expandedDomains.add(recurring.getTargetId()); + } + maxRecurrenceId = Math.max(maxRecurrenceId, recurring.getId()); + } + return SqlBatchResults.create( + batchBillingEventsSaved, + maxRecurrenceId, + maxRecurrenceId > prevMaxProcessedRecurrenceId); + }); + totalBillingEventsSaved += sqlBatchResults.batchBillingEventsSaved(); + maxProcessedRecurrenceId = sqlBatchResults.maxProcessedRecurrenceId(); + logger.atInfo().log( + "Saved %d billing events in batch with max recurrence id %d.", + sqlBatchResults.batchBillingEventsSaved(), maxProcessedRecurrenceId); + } while (sqlBatchResults.shouldContinue()); + + if (!isDryRun) { + logger.atInfo().log("Saved OneTime billing events.", totalBillingEventsSaved); + } else { + logger.atInfo().log("Generated OneTime billing events (dry run).", totalBillingEventsSaved); + } + logger.atInfo().log( + "Recurring event expansion %s complete for billing event range [%s, %s).", + isDryRun ? "(dry run) " : "", cursorTime, executeTime); + tm().transact( + () -> { + // Check for the unlikely scenario where the cursor has been altered during the + // expansion. + DateTime currentCursorTime = + tm().loadByKeyIfPresent(Cursor.createGlobalVKey(RECURRING_BILLING)) + .orElse(Cursor.createGlobal(RECURRING_BILLING, START_OF_TIME)) + .getCursorTime(); + if (!currentCursorTime.equals(persistedCursorTime)) { + throw new IllegalStateException( + String.format( + "Current cursor position %s does not match persisted cursor position %s.", + currentCursorTime, persistedCursorTime)); + } + if (!isDryRun) { + tm().put(Cursor.createGlobal(RECURRING_BILLING, executeTime)); + } + }); + } + + @AutoValue + abstract static class SqlBatchResults { + abstract int batchBillingEventsSaved(); + + abstract long maxProcessedRecurrenceId(); + + abstract boolean shouldContinue(); + + static SqlBatchResults create( + int batchBillingEventsSaved, long maxProcessedRecurrenceId, boolean shouldContinue) { + return new AutoValue_ExpandRecurringBillingEventsAction_SqlBatchResults( + batchBillingEventsSaved, maxProcessedRecurrenceId, shouldContinue); + } + } + /** Mapper to expand {@link Recurring} billing events into synthetic {@link OneTime} events. */ public static class ExpandRecurringBillingEventsMapper extends Mapper { diff --git a/core/src/main/java/google/registry/config/RegistryConfig.java b/core/src/main/java/google/registry/config/RegistryConfig.java index 39020954b..bfa4f09bf 100644 --- a/core/src/main/java/google/registry/config/RegistryConfig.java +++ b/core/src/main/java/google/registry/config/RegistryConfig.java @@ -1351,6 +1351,12 @@ public final class RegistryConfig { public static int provideWipeOutQueryBatchSize(RegistryConfigSettings config) { return config.contactHistory.wipeOutQueryBatchSize; } + + @Provides + @Config("jdbcBatchSize") + public static int provideHibernateJdbcBatchSize(RegistryConfigSettings config) { + return config.hibernate.jdbcBatchSize; + } } /** Returns the App Engine project ID, which is based off the environment name. */ @@ -1555,7 +1561,7 @@ public final class RegistryConfig { * https://docs.jboss.org/hibernate/orm/5.6/userguide/html_single/Hibernate_User_Guide.html, * recommend between 10 and 50. */ - public static String getHibernateJdbcBatchSize() { + public static int getHibernateJdbcBatchSize() { return CONFIG_SETTINGS.get().hibernate.jdbcBatchSize; } diff --git a/core/src/main/java/google/registry/config/RegistryConfigSettings.java b/core/src/main/java/google/registry/config/RegistryConfigSettings.java index 098250707..80cc66c2a 100644 --- a/core/src/main/java/google/registry/config/RegistryConfigSettings.java +++ b/core/src/main/java/google/registry/config/RegistryConfigSettings.java @@ -120,7 +120,7 @@ public class RegistryConfigSettings { public String hikariMinimumIdle; public String hikariMaximumPoolSize; public String hikariIdleTimeout; - public String jdbcBatchSize; + public int jdbcBatchSize; public String jdbcFetchSize; } diff --git a/core/src/main/java/google/registry/persistence/PersistenceModule.java b/core/src/main/java/google/registry/persistence/PersistenceModule.java index 8c25dbd70..ac42f5310 100644 --- a/core/src/main/java/google/registry/persistence/PersistenceModule.java +++ b/core/src/main/java/google/registry/persistence/PersistenceModule.java @@ -107,7 +107,7 @@ public abstract class PersistenceModule { properties.put(HIKARI_MAXIMUM_POOL_SIZE, getHibernateHikariMaximumPoolSize()); properties.put(HIKARI_IDLE_TIMEOUT, getHibernateHikariIdleTimeout()); properties.put(Environment.DIALECT, NomulusPostgreSQLDialect.class.getName()); - properties.put(JDBC_BATCH_SIZE, getHibernateJdbcBatchSize()); + properties.put(JDBC_BATCH_SIZE, Integer.toString(getHibernateJdbcBatchSize())); properties.put(JDBC_FETCH_SIZE, getHibernateJdbcFetchSize()); return properties.build(); } diff --git a/core/src/test/java/google/registry/batch/ExpandRecurringBillingEventsActionTest.java b/core/src/test/java/google/registry/batch/ExpandRecurringBillingEventsActionTest.java index 29ac23f75..220be52e3 100644 --- a/core/src/test/java/google/registry/batch/ExpandRecurringBillingEventsActionTest.java +++ b/core/src/test/java/google/registry/batch/ExpandRecurringBillingEventsActionTest.java @@ -90,6 +90,7 @@ public class ExpandRecurringBillingEventsActionTest action.mrRunner = makeDefaultRunner(); action.clock = clock; action.cursorTimeParam = Optional.empty(); + action.batchSize = 2; createTld("tld"); domain = persistResource( @@ -279,11 +280,12 @@ public class ExpandRecurringBillingEventsActionTest assertHistoryEntryMatches( domain, persistedEntry, "TheRegistrar", DateTime.parse("2000-02-19T00:00:00Z"), true); BillingEvent.OneTime expected = defaultOneTimeBuilder().setParent(persistedEntry).build(); - // Persist an otherwise identical billing event that differs only in billing time. + // Persist an otherwise identical billing event that differs only in billing time (and ID). BillingEvent.OneTime persisted = persistResource( expected .asBuilder() + .setId(15891L) .setBillingTime(DateTime.parse("1999-02-19T00:00:00Z")) .setEventTime(DateTime.parse("1999-01-05T00:00:00Z")) .build()); @@ -639,43 +641,95 @@ public class ExpandRecurringBillingEventsActionTest @TestOfyAndSql void testSuccess_expandMultipleEvents() throws Exception { persistResource(recurring); + DomainBase domain2 = + persistResource( + newDomainBase("example2.tld") + .asBuilder() + .setCreationTimeForTest(DateTime.parse("1999-04-05T00:00:00Z")) + .build()); + DomainHistory historyEntry2 = + persistResource( + new DomainHistory.Builder() + .setRegistrarId(domain2.getCreationRegistrarId()) + .setType(HistoryEntry.Type.DOMAIN_CREATE) + .setModificationTime(DateTime.parse("1999-04-05T00:00:00Z")) + .setDomain(domain2) + .build()); BillingEvent.Recurring recurring2 = persistResource( - recurring + new BillingEvent.Recurring.Builder() + .setParent(historyEntry2) + .setRegistrarId(domain2.getCreationRegistrarId()) + .setEventTime(DateTime.parse("2000-04-05T00:00:00Z")) + .setFlags(ImmutableSet.of(Flag.AUTO_RENEW)) + .setReason(Reason.RENEW) + .setRecurrenceEndTime(END_OF_TIME) + .setTargetId(domain2.getDomainName()) + .build()); + DomainBase domain3 = + persistResource( + newDomainBase("example3.tld") .asBuilder() - .setEventTime(recurring.getEventTime().plusMonths(3)) - .setId(3L) + .setCreationTimeForTest(DateTime.parse("1999-06-05T00:00:00Z")) + .build()); + DomainHistory historyEntry3 = + persistResource( + new DomainHistory.Builder() + .setRegistrarId(domain3.getCreationRegistrarId()) + .setType(HistoryEntry.Type.DOMAIN_CREATE) + .setModificationTime(DateTime.parse("1999-06-05T00:00:00Z")) + .setDomain(domain3) + .build()); + BillingEvent.Recurring recurring3 = + persistResource( + new BillingEvent.Recurring.Builder() + .setParent(historyEntry3) + .setRegistrarId(domain3.getCreationRegistrarId()) + .setEventTime(DateTime.parse("2000-06-05T00:00:00Z")) + .setFlags(ImmutableSet.of(Flag.AUTO_RENEW)) + .setReason(Reason.RENEW) + .setRecurrenceEndTime(END_OF_TIME) + .setTargetId(domain3.getDomainName()) .build()); action.cursorTimeParam = Optional.of(START_OF_TIME); runAction(); - List persistedEntries = - getHistoryEntriesOfType(domain, DOMAIN_AUTORENEW, DomainHistory.class); - assertThat(persistedEntries).hasSize(2); + + DomainHistory persistedHistory1 = + getOnlyHistoryEntryOfType(domain, DOMAIN_AUTORENEW, DomainHistory.class); assertHistoryEntryMatches( - domain, - persistedEntries.get(0), - "TheRegistrar", - DateTime.parse("2000-02-19T00:00:00Z"), - true); + domain, persistedHistory1, "TheRegistrar", DateTime.parse("2000-02-19T00:00:00Z"), true); BillingEvent.OneTime expected = defaultOneTimeBuilder() - .setParent(persistedEntries.get(0)) + .setParent(persistedHistory1) .setCancellationMatchingBillingEvent(recurring.createVKey()) .build(); + DomainHistory persistedHistory2 = + getOnlyHistoryEntryOfType(domain2, DOMAIN_AUTORENEW, DomainHistory.class); assertHistoryEntryMatches( - domain, - persistedEntries.get(1), - "TheRegistrar", - DateTime.parse("2000-05-20T00:00:00Z"), - true); + domain2, persistedHistory2, "TheRegistrar", DateTime.parse("2000-05-20T00:00:00Z"), true); BillingEvent.OneTime expected2 = defaultOneTimeBuilder() .setBillingTime(DateTime.parse("2000-05-20T00:00:00Z")) .setEventTime(DateTime.parse("2000-04-05T00:00:00Z")) - .setParent(persistedEntries.get(1)) + .setParent(persistedHistory2) + .setTargetId(domain2.getDomainName()) .setCancellationMatchingBillingEvent(recurring2.createVKey()) .build(); - assertBillingEventsForResource(domain, expected, expected2, recurring, recurring2); + DomainHistory persistedHistory3 = + getOnlyHistoryEntryOfType(domain3, DOMAIN_AUTORENEW, DomainHistory.class); + assertHistoryEntryMatches( + domain3, persistedHistory3, "TheRegistrar", DateTime.parse("2000-07-20T00:00:00Z"), true); + BillingEvent.OneTime expected3 = + defaultOneTimeBuilder() + .setBillingTime(DateTime.parse("2000-07-20T00:00:00Z")) + .setEventTime(DateTime.parse("2000-06-05T00:00:00Z")) + .setTargetId(domain3.getDomainName()) + .setParent(persistedHistory3) + .setCancellationMatchingBillingEvent(recurring3.createVKey()) + .build(); + assertBillingEventsForResource(domain, expected, recurring); + assertBillingEventsForResource(domain2, expected2, recurring2); + assertBillingEventsForResource(domain3, expected3, recurring3); assertCursorAt(currentTestTime); }