diff --git a/java/google/registry/dns/DnsQueue.java b/java/google/registry/dns/DnsQueue.java index 10f9eb7d6..f2a8ad5e4 100644 --- a/java/google/registry/dns/DnsQueue.java +++ b/java/google/registry/dns/DnsQueue.java @@ -103,16 +103,16 @@ public class DnsQueue { @VisibleForTesting long leaseTasksBatchSize = QueueConstants.maxLeaseCount(); - /** - * Enqueues the given task type with the given target name to the DNS queue. - */ - private TaskHandle addToQueue(TargetType targetType, String targetName, String tld) { + /** Enqueues the given task type with the given target name to the DNS queue. */ + private TaskHandle addToQueue( + TargetType targetType, String targetName, String tld, Duration countdown) { logger.atInfo().log( "Adding task type=%s, target=%s, tld=%s to pull queue %s (%d tasks currently on queue)", targetType, targetName, tld, DNS_PULL_QUEUE_NAME, queue.fetchStatistics().getNumTasks()); return queue.add( TaskOptions.Builder.withDefaults() .method(Method.PULL) + .countdownMillis(countdown.getMillis()) .param(DNS_TARGET_TYPE_PARAM, targetType.toString()) .param(DNS_TARGET_NAME_PARAM, targetName) .param(DNS_TARGET_CREATE_TIME_PARAM, clock.nowUtc().toString()) @@ -127,20 +127,27 @@ public class DnsQueue { Registries.findTldForName(InternetDomainName.from(fullyQualifiedHostName)); checkArgument(tld.isPresent(), String.format("%s is not a subordinate host to a known tld", fullyQualifiedHostName)); - return addToQueue(TargetType.HOST, fullyQualifiedHostName, tld.get().toString()); + return addToQueue(TargetType.HOST, fullyQualifiedHostName, tld.get().toString(), Duration.ZERO); } - /** Adds a task to the queue to refresh the DNS information for the specified domain. */ + /** Enqueues a task to refresh DNS for the specified domain now. */ public TaskHandle addDomainRefreshTask(String fullyQualifiedDomainName) { + return addDomainRefreshTask(fullyQualifiedDomainName, Duration.ZERO); + } + + /** Enqueues a task to refresh DNS for the specified domain at some point in the future. */ + public TaskHandle addDomainRefreshTask(String fullyQualifiedDomainName, Duration countdown) { return addToQueue( TargetType.DOMAIN, fullyQualifiedDomainName, - assertTldExists(getTldFromDomainName(fullyQualifiedDomainName))); + assertTldExists(getTldFromDomainName(fullyQualifiedDomainName)), + countdown); } /** Adds a task to the queue to refresh the DNS information for the specified zone. */ public TaskHandle addZoneRefreshTask(String fullyQualifiedZoneName) { - return addToQueue(TargetType.ZONE, fullyQualifiedZoneName, fullyQualifiedZoneName); + return addToQueue( + TargetType.ZONE, fullyQualifiedZoneName, fullyQualifiedZoneName, Duration.ZERO); } /** diff --git a/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java b/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java index 796a23a94..98021aa12 100644 --- a/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java +++ b/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java @@ -14,6 +14,7 @@ package google.registry.tools.server; +import static com.google.common.base.Preconditions.checkArgument; import static google.registry.mapreduce.inputs.EppResourceInputs.createEntityInput; import static google.registry.model.EppResourceUtils.isActive; import static google.registry.model.registry.Registries.assertTldsExist; @@ -32,9 +33,11 @@ import google.registry.request.Parameter; import google.registry.request.Response; import google.registry.request.auth.Auth; import google.registry.util.NonFinalForTesting; +import java.util.Random; import javax.inject.Inject; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; +import org.joda.time.Duration; /** * A mapreduce that enqueues DNS publish tasks on all active domains on the specified TLD(s). @@ -46,6 +49,11 @@ import org.joda.time.DateTimeZone; *

Because there are no auth settings in the {@link Action} annotation, this command can only be * run internally, or by pretending to be internal by setting the X-AppEngine-QueueName header, * which only admin users can do. + * + *

You must pass in a number of smearMinutes as a URL parameter so that the DNS + * queue doesn't get overloaded. A rough rule of thumb for Cloud DNS is 1 minute per every 1,000 + * domains. This smears the updates out over the next N minutes. For small TLDs consisting of fewer + * than 1,000 domains, passing in 1 is fine (which will execute all the updates immediately). */ @Action( service = Action.Service.TOOLS, @@ -57,18 +65,30 @@ public class RefreshDnsForAllDomainsAction implements Runnable { @Inject MapreduceRunner mrRunner; @Inject Response response; - @Inject @Parameter(PARAM_TLDS) ImmutableSet tlds; - @Inject RefreshDnsForAllDomainsAction() {} + + @Inject + @Parameter(PARAM_TLDS) + ImmutableSet tlds; + + @Inject + @Parameter("smearMinutes") + int smearMinutes; + + @Inject Random random; + + @Inject + RefreshDnsForAllDomainsAction() {} @Override public void run() { assertTldsExist(tlds); + checkArgument(smearMinutes > 0, "Must specify a positive number of smear minutes"); mrRunner .setJobName("Refresh DNS for all domains") .setModuleName("tools") .setDefaultMapShards(10) .runMapOnly( - new RefreshDnsForAllDomainsActionMapper(tlds), + new RefreshDnsForAllDomainsActionMapper(tlds, smearMinutes, random), ImmutableList.of(createEntityInput(DomainBase.class))) .sendLinkToMapreduceConsole(response); } @@ -77,14 +97,19 @@ public class RefreshDnsForAllDomainsAction implements Runnable { public static class RefreshDnsForAllDomainsActionMapper extends Mapper { - private static final long serialVersionUID = 1455544013508953083L; + private static final long serialVersionUID = -5103865047156795489L; @NonFinalForTesting private static DnsQueue dnsQueue = DnsQueue.create(); private final ImmutableSet tlds; + private final int smearMinutes; + private final Random random; - RefreshDnsForAllDomainsActionMapper(ImmutableSet tlds) { + RefreshDnsForAllDomainsActionMapper( + ImmutableSet tlds, int smearMinutes, Random random) { this.tlds = tlds; + this.smearMinutes = smearMinutes; + this.random = random; } @Override @@ -93,7 +118,9 @@ public class RefreshDnsForAllDomainsAction implements Runnable { if (tlds.contains(domain.getTld())) { if (isActive(domain, DateTime.now(DateTimeZone.UTC))) { try { - dnsQueue.addDomainRefreshTask(domainName); + // Smear the task execution time over the next N minutes. + dnsQueue.addDomainRefreshTask( + domainName, Duration.standardMinutes(random.nextInt(smearMinutes))); getContext().incrementCounter("active domains refreshed"); } catch (Throwable t) { logger.atSevere().withCause(t).log( diff --git a/java/google/registry/tools/server/ToolsServerModule.java b/java/google/registry/tools/server/ToolsServerModule.java index 4c42dd668..778bea0fa 100644 --- a/java/google/registry/tools/server/ToolsServerModule.java +++ b/java/google/registry/tools/server/ToolsServerModule.java @@ -108,4 +108,10 @@ public class ToolsServerModule { String provideJobId(HttpServletRequest req) { return extractRequiredParameter(req, "jobId"); } + + @Provides + @Parameter("smearMinutes") + static int provideSmearMinutes(HttpServletRequest req) { + return extractIntParameter(req, "smearMinutes"); + } } diff --git a/java/google/registry/util/UtilsModule.java b/java/google/registry/util/UtilsModule.java index 61bb14128..78d10b9b1 100644 --- a/java/google/registry/util/UtilsModule.java +++ b/java/google/registry/util/UtilsModule.java @@ -22,6 +22,7 @@ import dagger.Provides; import java.security.NoSuchAlgorithmException; import java.security.ProviderException; import java.security.SecureRandom; +import java.util.Random; import javax.inject.Named; import javax.inject.Singleton; @@ -48,7 +49,6 @@ public abstract class UtilsModule { abstract AppEngineServiceUtils provideAppEngineServiceUtils( AppEngineServiceUtilsImpl appEngineServiceUtilsImpl); - @Singleton @Provides public static SecureRandom provideSecureRandom() { @@ -59,6 +59,10 @@ public abstract class UtilsModule { } } + @Binds + @Singleton + abstract Random provideSecureRandomAsRandom(SecureRandom random); + @Singleton @Provides @Named("base58StringGenerator") diff --git a/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java b/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java index c9d2964f0..aa0ee7558 100644 --- a/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java +++ b/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java @@ -16,10 +16,12 @@ package google.registry.tools.server; import static com.google.common.truth.Truth.assertThat; import static google.registry.testing.DatastoreHelper.createTld; -import static google.registry.testing.DatastoreHelper.createTlds; import static google.registry.testing.DatastoreHelper.persistActiveDomain; import static google.registry.testing.DatastoreHelper.persistDeletedDomain; +import static google.registry.testing.JUnitBackports.assertThrows; import static org.joda.time.DateTimeZone.UTC; +import static org.joda.time.Duration.standardMinutes; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -30,13 +32,16 @@ import google.registry.testing.FakeResponse; import google.registry.testing.InjectRule; import google.registry.testing.mapreduce.MapreduceTestCase; import google.registry.tools.server.RefreshDnsForAllDomainsAction.RefreshDnsForAllDomainsActionMapper; +import java.util.Random; import org.joda.time.DateTime; +import org.joda.time.Duration; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; /** Unit tests for {@link RefreshDnsForAllDomainsAction}. */ @RunWith(JUnit4.class) @@ -53,8 +58,13 @@ public class RefreshDnsForAllDomainsActionTest origDnsQueue = RefreshDnsForAllDomainsActionMapper.setDnsQueueForTest(dnsQueue); action = new RefreshDnsForAllDomainsAction(); + action.smearMinutes = 1; + action.random = new Random(); + action.random.setSeed(123L); action.mrRunner = makeDefaultRunner(); action.response = new FakeResponse(); + + createTld("bar"); } @After @@ -70,36 +80,58 @@ public class RefreshDnsForAllDomainsActionTest @Test public void test_runAction_successfullyEnqueuesDnsRefreshes() throws Exception { - createTld("bar"); persistActiveDomain("foo.bar"); persistActiveDomain("low.bar"); action.tlds = ImmutableSet.of("bar"); runMapreduce(); - verify(dnsQueue).addDomainRefreshTask("foo.bar"); - verify(dnsQueue).addDomainRefreshTask("low.bar"); + verify(dnsQueue).addDomainRefreshTask("foo.bar", Duration.ZERO); + verify(dnsQueue).addDomainRefreshTask("low.bar", Duration.ZERO); + } + + @Test + public void test_runAction_smearsOutDnsRefreshes() throws Exception { + persistActiveDomain("foo.bar"); + persistActiveDomain("low.bar"); + action.tlds = ImmutableSet.of("bar"); + action.smearMinutes = 1000; + runMapreduce(); + ArgumentCaptor captor = ArgumentCaptor.forClass(Duration.class); + verify(dnsQueue).addDomainRefreshTask(eq("foo.bar"), captor.capture()); + verify(dnsQueue).addDomainRefreshTask(eq("low.bar"), captor.capture()); + assertThat(captor.getAllValues()).containsExactly(standardMinutes(450), standardMinutes(782)); } @Test public void test_runAction_doesntRefreshDeletedDomain() throws Exception { - createTld("bar"); persistActiveDomain("foo.bar"); persistDeletedDomain("deleted.bar", DateTime.now(UTC).minusYears(1)); action.tlds = ImmutableSet.of("bar"); runMapreduce(); - verify(dnsQueue).addDomainRefreshTask("foo.bar"); - verify(dnsQueue, never()).addDomainRefreshTask("deleted.bar"); + verify(dnsQueue).addDomainRefreshTask("foo.bar", Duration.ZERO); + verify(dnsQueue, never()).addDomainRefreshTask("deleted.bar", Duration.ZERO); } @Test public void test_runAction_ignoresDomainsOnOtherTlds() throws Exception { - createTlds("bar", "baz"); + createTld("baz"); persistActiveDomain("foo.bar"); persistActiveDomain("low.bar"); persistActiveDomain("ignore.baz"); action.tlds = ImmutableSet.of("bar"); runMapreduce(); - verify(dnsQueue).addDomainRefreshTask("foo.bar"); - verify(dnsQueue).addDomainRefreshTask("low.bar"); - verify(dnsQueue, never()).addDomainRefreshTask("ignore.baz"); + verify(dnsQueue).addDomainRefreshTask("foo.bar", Duration.ZERO); + verify(dnsQueue).addDomainRefreshTask("low.bar", Duration.ZERO); + verify(dnsQueue, never()).addDomainRefreshTask("ignore.baz", Duration.ZERO); + } + + @Test + public void test_smearMinutesMustBeSpecified() { + action.tlds = ImmutableSet.of("bar"); + action.smearMinutes = 0; + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, () -> action.run()); + assertThat(thrown) + .hasMessageThat() + .isEqualTo("Must specify a positive number of smear minutes"); } }