diff --git a/java/google/registry/dns/DnsQueue.java b/java/google/registry/dns/DnsQueue.java
index 10f9eb7d6..f2a8ad5e4 100644
--- a/java/google/registry/dns/DnsQueue.java
+++ b/java/google/registry/dns/DnsQueue.java
@@ -103,16 +103,16 @@ public class DnsQueue {
@VisibleForTesting
long leaseTasksBatchSize = QueueConstants.maxLeaseCount();
- /**
- * Enqueues the given task type with the given target name to the DNS queue.
- */
- private TaskHandle addToQueue(TargetType targetType, String targetName, String tld) {
+ /** Enqueues the given task type with the given target name to the DNS queue. */
+ private TaskHandle addToQueue(
+ TargetType targetType, String targetName, String tld, Duration countdown) {
logger.atInfo().log(
"Adding task type=%s, target=%s, tld=%s to pull queue %s (%d tasks currently on queue)",
targetType, targetName, tld, DNS_PULL_QUEUE_NAME, queue.fetchStatistics().getNumTasks());
return queue.add(
TaskOptions.Builder.withDefaults()
.method(Method.PULL)
+ .countdownMillis(countdown.getMillis())
.param(DNS_TARGET_TYPE_PARAM, targetType.toString())
.param(DNS_TARGET_NAME_PARAM, targetName)
.param(DNS_TARGET_CREATE_TIME_PARAM, clock.nowUtc().toString())
@@ -127,20 +127,27 @@ public class DnsQueue {
Registries.findTldForName(InternetDomainName.from(fullyQualifiedHostName));
checkArgument(tld.isPresent(),
String.format("%s is not a subordinate host to a known tld", fullyQualifiedHostName));
- return addToQueue(TargetType.HOST, fullyQualifiedHostName, tld.get().toString());
+ return addToQueue(TargetType.HOST, fullyQualifiedHostName, tld.get().toString(), Duration.ZERO);
}
- /** Adds a task to the queue to refresh the DNS information for the specified domain. */
+ /** Enqueues a task to refresh DNS for the specified domain now. */
public TaskHandle addDomainRefreshTask(String fullyQualifiedDomainName) {
+ return addDomainRefreshTask(fullyQualifiedDomainName, Duration.ZERO);
+ }
+
+ /** Enqueues a task to refresh DNS for the specified domain at some point in the future. */
+ public TaskHandle addDomainRefreshTask(String fullyQualifiedDomainName, Duration countdown) {
return addToQueue(
TargetType.DOMAIN,
fullyQualifiedDomainName,
- assertTldExists(getTldFromDomainName(fullyQualifiedDomainName)));
+ assertTldExists(getTldFromDomainName(fullyQualifiedDomainName)),
+ countdown);
}
/** Adds a task to the queue to refresh the DNS information for the specified zone. */
public TaskHandle addZoneRefreshTask(String fullyQualifiedZoneName) {
- return addToQueue(TargetType.ZONE, fullyQualifiedZoneName, fullyQualifiedZoneName);
+ return addToQueue(
+ TargetType.ZONE, fullyQualifiedZoneName, fullyQualifiedZoneName, Duration.ZERO);
}
/**
diff --git a/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java b/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java
index 796a23a94..98021aa12 100644
--- a/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java
+++ b/java/google/registry/tools/server/RefreshDnsForAllDomainsAction.java
@@ -14,6 +14,7 @@
package google.registry.tools.server;
+import static com.google.common.base.Preconditions.checkArgument;
import static google.registry.mapreduce.inputs.EppResourceInputs.createEntityInput;
import static google.registry.model.EppResourceUtils.isActive;
import static google.registry.model.registry.Registries.assertTldsExist;
@@ -32,9 +33,11 @@ import google.registry.request.Parameter;
import google.registry.request.Response;
import google.registry.request.auth.Auth;
import google.registry.util.NonFinalForTesting;
+import java.util.Random;
import javax.inject.Inject;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
+import org.joda.time.Duration;
/**
* A mapreduce that enqueues DNS publish tasks on all active domains on the specified TLD(s).
@@ -46,6 +49,11 @@ import org.joda.time.DateTimeZone;
*
Because there are no auth settings in the {@link Action} annotation, this command can only be
* run internally, or by pretending to be internal by setting the X-AppEngine-QueueName header,
* which only admin users can do.
+ *
+ *
You must pass in a number of smearMinutes
as a URL parameter so that the DNS
+ * queue doesn't get overloaded. A rough rule of thumb for Cloud DNS is 1 minute per every 1,000
+ * domains. This smears the updates out over the next N minutes. For small TLDs consisting of fewer
+ * than 1,000 domains, passing in 1 is fine (which will execute all the updates immediately).
*/
@Action(
service = Action.Service.TOOLS,
@@ -57,18 +65,30 @@ public class RefreshDnsForAllDomainsAction implements Runnable {
@Inject MapreduceRunner mrRunner;
@Inject Response response;
- @Inject @Parameter(PARAM_TLDS) ImmutableSet tlds;
- @Inject RefreshDnsForAllDomainsAction() {}
+
+ @Inject
+ @Parameter(PARAM_TLDS)
+ ImmutableSet tlds;
+
+ @Inject
+ @Parameter("smearMinutes")
+ int smearMinutes;
+
+ @Inject Random random;
+
+ @Inject
+ RefreshDnsForAllDomainsAction() {}
@Override
public void run() {
assertTldsExist(tlds);
+ checkArgument(smearMinutes > 0, "Must specify a positive number of smear minutes");
mrRunner
.setJobName("Refresh DNS for all domains")
.setModuleName("tools")
.setDefaultMapShards(10)
.runMapOnly(
- new RefreshDnsForAllDomainsActionMapper(tlds),
+ new RefreshDnsForAllDomainsActionMapper(tlds, smearMinutes, random),
ImmutableList.of(createEntityInput(DomainBase.class)))
.sendLinkToMapreduceConsole(response);
}
@@ -77,14 +97,19 @@ public class RefreshDnsForAllDomainsAction implements Runnable {
public static class RefreshDnsForAllDomainsActionMapper
extends Mapper {
- private static final long serialVersionUID = 1455544013508953083L;
+ private static final long serialVersionUID = -5103865047156795489L;
@NonFinalForTesting private static DnsQueue dnsQueue = DnsQueue.create();
private final ImmutableSet tlds;
+ private final int smearMinutes;
+ private final Random random;
- RefreshDnsForAllDomainsActionMapper(ImmutableSet tlds) {
+ RefreshDnsForAllDomainsActionMapper(
+ ImmutableSet tlds, int smearMinutes, Random random) {
this.tlds = tlds;
+ this.smearMinutes = smearMinutes;
+ this.random = random;
}
@Override
@@ -93,7 +118,9 @@ public class RefreshDnsForAllDomainsAction implements Runnable {
if (tlds.contains(domain.getTld())) {
if (isActive(domain, DateTime.now(DateTimeZone.UTC))) {
try {
- dnsQueue.addDomainRefreshTask(domainName);
+ // Smear the task execution time over the next N minutes.
+ dnsQueue.addDomainRefreshTask(
+ domainName, Duration.standardMinutes(random.nextInt(smearMinutes)));
getContext().incrementCounter("active domains refreshed");
} catch (Throwable t) {
logger.atSevere().withCause(t).log(
diff --git a/java/google/registry/tools/server/ToolsServerModule.java b/java/google/registry/tools/server/ToolsServerModule.java
index 4c42dd668..778bea0fa 100644
--- a/java/google/registry/tools/server/ToolsServerModule.java
+++ b/java/google/registry/tools/server/ToolsServerModule.java
@@ -108,4 +108,10 @@ public class ToolsServerModule {
String provideJobId(HttpServletRequest req) {
return extractRequiredParameter(req, "jobId");
}
+
+ @Provides
+ @Parameter("smearMinutes")
+ static int provideSmearMinutes(HttpServletRequest req) {
+ return extractIntParameter(req, "smearMinutes");
+ }
}
diff --git a/java/google/registry/util/UtilsModule.java b/java/google/registry/util/UtilsModule.java
index 61bb14128..78d10b9b1 100644
--- a/java/google/registry/util/UtilsModule.java
+++ b/java/google/registry/util/UtilsModule.java
@@ -22,6 +22,7 @@ import dagger.Provides;
import java.security.NoSuchAlgorithmException;
import java.security.ProviderException;
import java.security.SecureRandom;
+import java.util.Random;
import javax.inject.Named;
import javax.inject.Singleton;
@@ -48,7 +49,6 @@ public abstract class UtilsModule {
abstract AppEngineServiceUtils provideAppEngineServiceUtils(
AppEngineServiceUtilsImpl appEngineServiceUtilsImpl);
-
@Singleton
@Provides
public static SecureRandom provideSecureRandom() {
@@ -59,6 +59,10 @@ public abstract class UtilsModule {
}
}
+ @Binds
+ @Singleton
+ abstract Random provideSecureRandomAsRandom(SecureRandom random);
+
@Singleton
@Provides
@Named("base58StringGenerator")
diff --git a/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java b/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java
index c9d2964f0..aa0ee7558 100644
--- a/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java
+++ b/javatests/google/registry/tools/server/RefreshDnsForAllDomainsActionTest.java
@@ -16,10 +16,12 @@ package google.registry.tools.server;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.testing.DatastoreHelper.createTld;
-import static google.registry.testing.DatastoreHelper.createTlds;
import static google.registry.testing.DatastoreHelper.persistActiveDomain;
import static google.registry.testing.DatastoreHelper.persistDeletedDomain;
+import static google.registry.testing.JUnitBackports.assertThrows;
import static org.joda.time.DateTimeZone.UTC;
+import static org.joda.time.Duration.standardMinutes;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
@@ -30,13 +32,16 @@ import google.registry.testing.FakeResponse;
import google.registry.testing.InjectRule;
import google.registry.testing.mapreduce.MapreduceTestCase;
import google.registry.tools.server.RefreshDnsForAllDomainsAction.RefreshDnsForAllDomainsActionMapper;
+import java.util.Random;
import org.joda.time.DateTime;
+import org.joda.time.Duration;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
/** Unit tests for {@link RefreshDnsForAllDomainsAction}. */
@RunWith(JUnit4.class)
@@ -53,8 +58,13 @@ public class RefreshDnsForAllDomainsActionTest
origDnsQueue = RefreshDnsForAllDomainsActionMapper.setDnsQueueForTest(dnsQueue);
action = new RefreshDnsForAllDomainsAction();
+ action.smearMinutes = 1;
+ action.random = new Random();
+ action.random.setSeed(123L);
action.mrRunner = makeDefaultRunner();
action.response = new FakeResponse();
+
+ createTld("bar");
}
@After
@@ -70,36 +80,58 @@ public class RefreshDnsForAllDomainsActionTest
@Test
public void test_runAction_successfullyEnqueuesDnsRefreshes() throws Exception {
- createTld("bar");
persistActiveDomain("foo.bar");
persistActiveDomain("low.bar");
action.tlds = ImmutableSet.of("bar");
runMapreduce();
- verify(dnsQueue).addDomainRefreshTask("foo.bar");
- verify(dnsQueue).addDomainRefreshTask("low.bar");
+ verify(dnsQueue).addDomainRefreshTask("foo.bar", Duration.ZERO);
+ verify(dnsQueue).addDomainRefreshTask("low.bar", Duration.ZERO);
+ }
+
+ @Test
+ public void test_runAction_smearsOutDnsRefreshes() throws Exception {
+ persistActiveDomain("foo.bar");
+ persistActiveDomain("low.bar");
+ action.tlds = ImmutableSet.of("bar");
+ action.smearMinutes = 1000;
+ runMapreduce();
+ ArgumentCaptor captor = ArgumentCaptor.forClass(Duration.class);
+ verify(dnsQueue).addDomainRefreshTask(eq("foo.bar"), captor.capture());
+ verify(dnsQueue).addDomainRefreshTask(eq("low.bar"), captor.capture());
+ assertThat(captor.getAllValues()).containsExactly(standardMinutes(450), standardMinutes(782));
}
@Test
public void test_runAction_doesntRefreshDeletedDomain() throws Exception {
- createTld("bar");
persistActiveDomain("foo.bar");
persistDeletedDomain("deleted.bar", DateTime.now(UTC).minusYears(1));
action.tlds = ImmutableSet.of("bar");
runMapreduce();
- verify(dnsQueue).addDomainRefreshTask("foo.bar");
- verify(dnsQueue, never()).addDomainRefreshTask("deleted.bar");
+ verify(dnsQueue).addDomainRefreshTask("foo.bar", Duration.ZERO);
+ verify(dnsQueue, never()).addDomainRefreshTask("deleted.bar", Duration.ZERO);
}
@Test
public void test_runAction_ignoresDomainsOnOtherTlds() throws Exception {
- createTlds("bar", "baz");
+ createTld("baz");
persistActiveDomain("foo.bar");
persistActiveDomain("low.bar");
persistActiveDomain("ignore.baz");
action.tlds = ImmutableSet.of("bar");
runMapreduce();
- verify(dnsQueue).addDomainRefreshTask("foo.bar");
- verify(dnsQueue).addDomainRefreshTask("low.bar");
- verify(dnsQueue, never()).addDomainRefreshTask("ignore.baz");
+ verify(dnsQueue).addDomainRefreshTask("foo.bar", Duration.ZERO);
+ verify(dnsQueue).addDomainRefreshTask("low.bar", Duration.ZERO);
+ verify(dnsQueue, never()).addDomainRefreshTask("ignore.baz", Duration.ZERO);
+ }
+
+ @Test
+ public void test_smearMinutesMustBeSpecified() {
+ action.tlds = ImmutableSet.of("bar");
+ action.smearMinutes = 0;
+ IllegalArgumentException thrown =
+ assertThrows(IllegalArgumentException.class, () -> action.run());
+ assertThat(thrown)
+ .hasMessageThat()
+ .isEqualTo("Must specify a positive number of smear minutes");
}
}