Migrate ReadDnsQueueAction to use CloudTasksUtils (#1669)

* Migrate ReadDnsQueueAction to use CloudTasksUtils

Also marked TaskQueueUtils as deprecated and fixed a few linter errors.

Note that DNS pull queue still requires the use of the GAE Task Queue API.

* Fix a test failure

* Remove TaskQueueUtils from VKeyTest

* Remove the @error exception that was inadvertently pulled in
This commit is contained in:
Lai Jiang 2022-06-15 13:48:28 -04:00 committed by GitHub
parent 471205ad77
commit fbc37485f5
9 changed files with 178 additions and 245 deletions

View file

@ -34,7 +34,6 @@ import com.google.common.collect.ImmutableSortedMap;
import dagger.Module; import dagger.Module;
import dagger.Provides; import dagger.Provides;
import google.registry.persistence.transaction.JpaTransactionManager; import google.registry.persistence.transaction.JpaTransactionManager;
import google.registry.util.TaskQueueUtils;
import google.registry.util.YamlUtils; import google.registry.util.YamlUtils;
import java.lang.annotation.Documented; import java.lang.annotation.Documented;
import java.lang.annotation.Retention; import java.lang.annotation.Retention;
@ -952,7 +951,7 @@ public final class RegistryConfig {
* <p>Note that this uses {@code @Named} instead of {@code @Config} so that it can be used from * <p>Note that this uses {@code @Named} instead of {@code @Config} so that it can be used from
* the low-level util package, which cannot have a dependency on the config package. * the low-level util package, which cannot have a dependency on the config package.
* *
* @see TaskQueueUtils * @see google.registry.util.CloudTasksUtils
*/ */
@Provides @Provides
@Named("transientFailureRetries") @Named("transientFailureRetries")

View file

@ -30,14 +30,13 @@ import static google.registry.dns.DnsModule.PARAM_REFRESH_REQUEST_CREATED;
import static google.registry.request.RequestParameters.PARAM_TLD; import static google.registry.request.RequestParameters.PARAM_TLD;
import static google.registry.util.DomainNameUtils.getSecondLevelDomain; import static google.registry.util.DomainNameUtils.getSecondLevelDomain;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;
import com.google.appengine.api.taskqueue.Queue;
import com.google.appengine.api.taskqueue.TaskHandle; import com.google.appengine.api.taskqueue.TaskHandle;
import com.google.appengine.api.taskqueue.TaskOptions;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.cloud.tasks.v2.Task;
import com.google.common.collect.ComparisonChain; import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
@ -50,20 +49,19 @@ import google.registry.dns.DnsConstants.TargetType;
import google.registry.model.tld.Registries; import google.registry.model.tld.Registries;
import google.registry.model.tld.Registry; import google.registry.model.tld.Registry;
import google.registry.request.Action; import google.registry.request.Action;
import google.registry.request.Action.Service;
import google.registry.request.Parameter; import google.registry.request.Parameter;
import google.registry.request.auth.Auth; import google.registry.request.auth.Auth;
import google.registry.util.Clock; import google.registry.util.Clock;
import google.registry.util.TaskQueueUtils; import google.registry.util.CloudTasksUtils;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.util.Collection; import java.util.Collection;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.inject.Inject; import javax.inject.Inject;
import javax.inject.Named;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.joda.time.Duration; import org.joda.time.Duration;
@ -84,7 +82,6 @@ import org.joda.time.Duration;
public final class ReadDnsQueueAction implements Runnable { public final class ReadDnsQueueAction implements Runnable {
private static final String PARAM_JITTER_SECONDS = "jitterSeconds"; private static final String PARAM_JITTER_SECONDS = "jitterSeconds";
private static final Random random = new Random();
private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private static final FluentLogger logger = FluentLogger.forEnclosingClass();
/** /**
@ -101,15 +98,31 @@ public final class ReadDnsQueueAction implements Runnable {
*/ */
private static final Duration LEASE_PADDING = Duration.standardMinutes(1); private static final Duration LEASE_PADDING = Duration.standardMinutes(1);
@Inject @Config("dnsTldUpdateBatchSize") int tldUpdateBatchSize; private final int tldUpdateBatchSize;
@Inject @Config("readDnsQueueActionRuntime") Duration requestedMaximumDuration; private final Duration requestedMaximumDuration;
@Inject @Named(DNS_PUBLISH_PUSH_QUEUE_NAME) Queue dnsPublishPushQueue; private final Optional<Integer> jitterSeconds;
@Inject @Parameter(PARAM_JITTER_SECONDS) Optional<Integer> jitterSeconds; private final Clock clock;
@Inject Clock clock; private final DnsQueue dnsQueue;
@Inject DnsQueue dnsQueue; private final HashFunction hashFunction;
@Inject HashFunction hashFunction; private final CloudTasksUtils cloudTasksUtils;
@Inject TaskQueueUtils taskQueueUtils;
@Inject ReadDnsQueueAction() {} @Inject
ReadDnsQueueAction(
@Config("dnsTldUpdateBatchSize") int tldUpdateBatchSize,
@Config("readDnsQueueActionRuntime") Duration requestedMaximumDuration,
@Parameter(PARAM_JITTER_SECONDS) Optional<Integer> jitterSeconds,
Clock clock,
DnsQueue dnsQueue,
HashFunction hashFunction,
CloudTasksUtils cloudTasksUtils) {
this.tldUpdateBatchSize = tldUpdateBatchSize;
this.requestedMaximumDuration = requestedMaximumDuration;
this.jitterSeconds = jitterSeconds;
this.clock = clock;
this.dnsQueue = dnsQueue;
this.hashFunction = hashFunction;
this.cloudTasksUtils = cloudTasksUtils;
}
/** Container for items we pull out of the DNS pull queue and process for fanout. */ /** Container for items we pull out of the DNS pull queue and process for fanout. */
@AutoValue @AutoValue
@ -322,17 +335,13 @@ public final class ReadDnsQueueAction implements Runnable {
if (numPublishLocks <= 1) { if (numPublishLocks <= 1) {
enqueueUpdates(tld, 1, 1, tldRefreshItemsEntry.getValue()); enqueueUpdates(tld, 1, 1, tldRefreshItemsEntry.getValue());
} else { } else {
tldRefreshItemsEntry tldRefreshItemsEntry.getValue().stream()
.getValue()
.stream()
.collect( .collect(
toImmutableSetMultimap( toImmutableSetMultimap(
refreshItem -> getLockIndex(tld, numPublishLocks, refreshItem), refreshItem -> getLockIndex(tld, numPublishLocks, refreshItem),
refreshItem -> refreshItem)) refreshItem -> refreshItem))
.asMap() .asMap()
.entrySet() .forEach((key, value) -> enqueueUpdates(tld, key, numPublishLocks, value));
.forEach(
entry -> enqueueUpdates(tld, entry.getKey(), numPublishLocks, entry.getValue()));
} }
} }
} }
@ -340,10 +349,10 @@ public final class ReadDnsQueueAction implements Runnable {
/** /**
* Returns the lock index for a given refreshItem. * Returns the lock index for a given refreshItem.
* *
* <p>We hash the second level domain domain for all records, to group in-balliwick hosts (the * <p>We hash the second level domain for all records, to group in-bailiwick hosts (the only ones
* only ones we refresh DNS for) with their superordinate domains. We use consistent hashing to * we refresh DNS for) with their superordinate domains. We use consistent hashing to determine
* determine the lock index because it gives us [0,N) bucketing properties out of the box, then * the lock index because it gives us [0,N) bucketing properties out of the box, then add 1 to
* add 1 to make indexes within [1,N]. * make indexes within [1,N].
*/ */
private int getLockIndex(String tld, int numPublishLocks, RefreshItem refreshItem) { private int getLockIndex(String tld, int numPublishLocks, RefreshItem refreshItem) {
String domain = getSecondLevelDomain(refreshItem.name(), tld); String domain = getSecondLevelDomain(refreshItem.name(), tld);
@ -360,33 +369,32 @@ public final class ReadDnsQueueAction implements Runnable {
DateTime earliestCreateTime = DateTime earliestCreateTime =
chunk.stream().map(RefreshItem::creationTime).min(Comparator.naturalOrder()).get(); chunk.stream().map(RefreshItem::creationTime).min(Comparator.naturalOrder()).get();
for (String dnsWriter : Registry.get(tld).getDnsWriters()) { for (String dnsWriter : Registry.get(tld).getDnsWriters()) {
taskQueueUtils.enqueue( Task task =
dnsPublishPushQueue, cloudTasksUtils.createPostTaskWithJitter(
TaskOptions.Builder.withUrl(PublishDnsUpdatesAction.PATH) PublishDnsUpdatesAction.PATH,
.countdownMillis( Service.BACKEND.toString(),
jitterSeconds ImmutableMultimap.<String, String>builder()
.map(seconds -> random.nextInt((int) SECONDS.toMillis(seconds))) .put(PARAM_TLD, tld)
.orElse(0)) .put(PARAM_DNS_WRITER, dnsWriter)
.param(PARAM_TLD, tld) .put(PARAM_LOCK_INDEX, Integer.toString(lockIndex))
.param(PARAM_DNS_WRITER, dnsWriter) .put(PARAM_NUM_PUBLISH_LOCKS, Integer.toString(numPublishLocks))
.param(PARAM_LOCK_INDEX, Integer.toString(lockIndex)) .put(PARAM_PUBLISH_TASK_ENQUEUED, clock.nowUtc().toString())
.param(PARAM_NUM_PUBLISH_LOCKS, Integer.toString(numPublishLocks)) .put(PARAM_REFRESH_REQUEST_CREATED, earliestCreateTime.toString())
.param(PARAM_PUBLISH_TASK_ENQUEUED, clock.nowUtc().toString()) .put(
.param(PARAM_REFRESH_REQUEST_CREATED, earliestCreateTime.toString())
.param(
PARAM_DOMAINS, PARAM_DOMAINS,
chunk chunk.stream()
.stream()
.filter(item -> item.type() == TargetType.DOMAIN) .filter(item -> item.type() == TargetType.DOMAIN)
.map(RefreshItem::name) .map(RefreshItem::name)
.collect(Collectors.joining(","))) .collect(Collectors.joining(",")))
.param( .put(
PARAM_HOSTS, PARAM_HOSTS,
chunk chunk.stream()
.stream()
.filter(item -> item.type() == TargetType.HOST) .filter(item -> item.type() == TargetType.HOST)
.map(RefreshItem::name) .map(RefreshItem::name)
.collect(Collectors.joining(",")))); .collect(Collectors.joining(",")))
.build(),
jitterSeconds);
cloudTasksUtils.enqueue(DNS_PUBLISH_PUSH_QUEUE_NAME, task);
} }
} }
} }

View file

@ -48,7 +48,6 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedSet; import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Ordering; import com.google.common.collect.Ordering;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.google.common.flogger.FluentLogger;
import com.google.common.net.InternetDomainName; import com.google.common.net.InternetDomainName;
import google.registry.dns.DnsQueue; import google.registry.dns.DnsQueue;
import google.registry.flows.EppException; import google.registry.flows.EppException;
@ -146,8 +145,6 @@ public final class DomainUpdateFlow implements TransactionalFlow {
private static final ImmutableSet<StatusValue> UPDATE_DISALLOWED_STATUSES = private static final ImmutableSet<StatusValue> UPDATE_DISALLOWED_STATUSES =
ImmutableSet.of(StatusValue.PENDING_DELETE, StatusValue.SERVER_UPDATE_PROHIBITED); ImmutableSet.of(StatusValue.PENDING_DELETE, StatusValue.SERVER_UPDATE_PROHIBITED);
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
@Inject ResourceCommand resourceCommand; @Inject ResourceCommand resourceCommand;
@Inject ExtensionManager extensionManager; @Inject ExtensionManager extensionManager;
@Inject EppInput eppInput; @Inject EppInput eppInput;

View file

@ -28,6 +28,7 @@ import google.registry.model.ofy.Ofy;
import google.registry.request.HttpException.NotFoundException; import google.registry.request.HttpException.NotFoundException;
import google.registry.request.RequestModule; import google.registry.request.RequestModule;
import google.registry.testing.AppEngineExtension; import google.registry.testing.AppEngineExtension;
import google.registry.testing.CloudTasksHelper.CloudTasksHelperModule;
import google.registry.testing.FakeClock; import google.registry.testing.FakeClock;
import google.registry.testing.InjectExtension; import google.registry.testing.InjectExtension;
import java.io.PrintWriter; import java.io.PrintWriter;
@ -59,8 +60,10 @@ public final class DnsInjectionTest {
void beforeEach() throws Exception { void beforeEach() throws Exception {
inject.setStaticField(Ofy.class, "clock", clock); inject.setStaticField(Ofy.class, "clock", clock);
when(rsp.getWriter()).thenReturn(new PrintWriter(httpOutput)); when(rsp.getWriter()).thenReturn(new PrintWriter(httpOutput));
component = DaggerDnsTestComponent.builder() component =
DaggerDnsTestComponent.builder()
.requestModule(new RequestModule(req, rsp)) .requestModule(new RequestModule(req, rsp))
.cloudTasksHelperModule(new CloudTasksHelperModule(clock))
.build(); .build();
dnsQueue = component.dnsQueue(); dnsQueue = component.dnsQueue();
createTld("lol"); createTld("lol");

View file

@ -19,12 +19,14 @@ import google.registry.config.RegistryConfig.ConfigModule;
import google.registry.cron.CronModule; import google.registry.cron.CronModule;
import google.registry.dns.writer.VoidDnsWriterModule; import google.registry.dns.writer.VoidDnsWriterModule;
import google.registry.request.RequestModule; import google.registry.request.RequestModule;
import google.registry.testing.CloudTasksHelper.CloudTasksHelperModule;
import google.registry.util.UtilsModule; import google.registry.util.UtilsModule;
import javax.inject.Singleton; import javax.inject.Singleton;
@Singleton @Singleton
@Component( @Component(
modules = { modules = {
CloudTasksHelperModule.class,
ConfigModule.class, ConfigModule.class,
CronModule.class, CronModule.class,
DnsModule.class, DnsModule.class,

View file

@ -17,7 +17,6 @@ package google.registry.dns;
import static com.google.appengine.api.taskqueue.QueueFactory.getQueue; import static com.google.appengine.api.taskqueue.QueueFactory.getQueue;
import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Lists.transform; import static com.google.common.collect.Lists.transform;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat; import static com.google.common.truth.Truth8.assertThat;
import static google.registry.dns.DnsConstants.DNS_PUBLISH_PUSH_QUEUE_NAME; import static google.registry.dns.DnsConstants.DNS_PUBLISH_PUSH_QUEUE_NAME;
@ -28,13 +27,11 @@ import static google.registry.dns.DnsConstants.DNS_TARGET_TYPE_PARAM;
import static google.registry.request.RequestParameters.PARAM_TLD; import static google.registry.request.RequestParameters.PARAM_TLD;
import static google.registry.testing.DatabaseHelper.createTlds; import static google.registry.testing.DatabaseHelper.createTlds;
import static google.registry.testing.DatabaseHelper.persistResource; import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.testing.TaskQueueHelper.assertNoTasksEnqueued;
import static google.registry.testing.TaskQueueHelper.assertTasksEnqueued;
import static google.registry.testing.TaskQueueHelper.getQueuedParams;
import com.google.appengine.api.taskqueue.QueueFactory; import com.google.appengine.api.taskqueue.QueueFactory;
import com.google.appengine.api.taskqueue.TaskOptions; import com.google.appengine.api.taskqueue.TaskOptions;
import com.google.appengine.api.taskqueue.TaskOptions.Method; import com.google.cloud.tasks.v2.HttpMethod;
import com.google.cloud.tasks.v2.Task;
import com.google.common.base.Joiner; import com.google.common.base.Joiner;
import com.google.common.base.Splitter; import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
@ -46,10 +43,12 @@ import google.registry.dns.DnsConstants.TargetType;
import google.registry.model.tld.Registry; import google.registry.model.tld.Registry;
import google.registry.model.tld.Registry.TldType; import google.registry.model.tld.Registry.TldType;
import google.registry.testing.AppEngineExtension; import google.registry.testing.AppEngineExtension;
import google.registry.testing.CloudTasksHelper;
import google.registry.testing.CloudTasksHelper.TaskMatcher;
import google.registry.testing.FakeClock; import google.registry.testing.FakeClock;
import google.registry.testing.TaskQueueHelper.TaskMatcher; import google.registry.testing.TaskQueueHelper;
import google.registry.util.Retrier; import google.registry.testing.UriParameters;
import google.registry.util.TaskQueueUtils; import java.nio.charset.StandardCharsets;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -67,7 +66,8 @@ public class ReadDnsQueueActionTest {
private DnsQueue dnsQueue; private DnsQueue dnsQueue;
// Because of a bug in the queue test environment - b/73372999 - we must set the fake date of the // Because of a bug in the queue test environment - b/73372999 - we must set the fake date of the
// test in the future. Set to year 3000 so it'll remain in the future for a very long time. // test in the future. Set to year 3000 so it'll remain in the future for a very long time.
private FakeClock clock = new FakeClock(DateTime.parse("3000-01-01TZ")); private final FakeClock clock = new FakeClock(DateTime.parse("3000-01-01TZ"));
private final CloudTasksHelper cloudTasksHelper = new CloudTasksHelper(clock);
@RegisterExtension @RegisterExtension
public final AppEngineExtension appEngine = public final AppEngineExtension appEngine =
@ -79,10 +79,6 @@ public class ReadDnsQueueActionTest {
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>", "<?xml version=\"1.0\" encoding=\"UTF-8\"?>",
"<queue-entries>", "<queue-entries>",
" <queue>", " <queue>",
" <name>dns-publish</name>",
" <rate>1/s</rate>",
" </queue>",
" <queue>",
" <name>dns-pull</name>", " <name>dns-pull</name>",
" <mode>pull</mode>", " <mode>pull</mode>",
" </queue>", " </queue>",
@ -116,24 +112,23 @@ public class ReadDnsQueueActionTest {
} }
private void run() { private void run() {
ReadDnsQueueAction action = new ReadDnsQueueAction(); ReadDnsQueueAction action =
action.tldUpdateBatchSize = TEST_TLD_UPDATE_BATCH_SIZE; new ReadDnsQueueAction(
action.requestedMaximumDuration = Duration.standardSeconds(10); TEST_TLD_UPDATE_BATCH_SIZE,
action.clock = clock; Duration.standardSeconds(10),
action.dnsQueue = dnsQueue; Optional.empty(),
action.dnsPublishPushQueue = QueueFactory.getQueue(DNS_PUBLISH_PUSH_QUEUE_NAME); clock,
action.hashFunction = Hashing.murmur3_32(); dnsQueue,
action.taskQueueUtils = new TaskQueueUtils(new Retrier(null, 1)); Hashing.murmur3_32(),
action.jitterSeconds = Optional.empty(); cloudTasksHelper.getTestCloudTasksUtils());
// Advance the time a little, to ensure that leaseTasks() returns all tasks. // Advance the time a little, to ensure that leaseTasks() returns all tasks.
clock.advanceBy(Duration.standardHours(1)); clock.advanceBy(Duration.standardHours(1));
action.run(); action.run();
} }
private static TaskOptions createRefreshTask(String name, TargetType type) { private static TaskOptions createRefreshTask(String name, TargetType type) {
TaskOptions options = TaskOptions options =
TaskOptions.Builder.withMethod(Method.PULL) TaskOptions.Builder.withMethod(TaskOptions.Method.PULL)
.param(DNS_TARGET_TYPE_PARAM, type.toString()) .param(DNS_TARGET_TYPE_PARAM, type.toString())
.param(DNS_TARGET_NAME_PARAM, name) .param(DNS_TARGET_NAME_PARAM, name)
.param(DNS_TARGET_CREATE_TIME_PARAM, "3000-01-01TZ"); .param(DNS_TARGET_CREATE_TIME_PARAM, "3000-01-01TZ");
@ -141,8 +136,8 @@ public class ReadDnsQueueActionTest {
return options.param("tld", tld); return options.param("tld", tld);
} }
private static TaskMatcher createDomainRefreshTaskMatcher(String name) { private static TaskQueueHelper.TaskMatcher createDomainRefreshTaskMatcher(String name) {
return new TaskMatcher() return new TaskQueueHelper.TaskMatcher()
.param(DNS_TARGET_NAME_PARAM, name) .param(DNS_TARGET_NAME_PARAM, name)
.param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString()); .param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString());
} }
@ -150,7 +145,7 @@ public class ReadDnsQueueActionTest {
private void assertTldsEnqueuedInPushQueue(ImmutableMultimap<String, String> tldsToDnsWriters) { private void assertTldsEnqueuedInPushQueue(ImmutableMultimap<String, String> tldsToDnsWriters) {
// By default, the publishDnsUpdates tasks will be enqueued one hour after the update items were // By default, the publishDnsUpdates tasks will be enqueued one hour after the update items were
// created in the pull queue. This is because of the clock.advanceBy in run() // created in the pull queue. This is because of the clock.advanceBy in run()
assertTasksEnqueued( cloudTasksHelper.assertTasksEnqueued(
DNS_PUBLISH_PUSH_QUEUE_NAME, DNS_PUBLISH_PUSH_QUEUE_NAME,
transform( transform(
tldsToDnsWriters.entries().asList(), tldsToDnsWriters.entries().asList(),
@ -175,12 +170,12 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTasksEnqueued( cloudTasksHelper.assertTasksEnqueued(
DNS_PUBLISH_PUSH_QUEUE_NAME, DNS_PUBLISH_PUSH_QUEUE_NAME,
new TaskMatcher().method("POST"), new TaskMatcher().method(HttpMethod.POST),
new TaskMatcher().method("POST"), new TaskMatcher().method(HttpMethod.POST),
new TaskMatcher().method("POST")); new TaskMatcher().method(HttpMethod.POST));
} }
@RetryingTest(4) @RetryingTest(4)
@ -191,7 +186,7 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "net", "netWriter", "example", "exampleWriter")); ImmutableMultimap.of("com", "comWriter", "net", "netWriter", "example", "exampleWriter"));
} }
@ -208,17 +203,24 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
ImmutableList<ImmutableMultimap<String, String>> queuedParams = ImmutableList<Task> queuedTasks =
getQueuedParams(DNS_PUBLISH_PUSH_QUEUE_NAME); ImmutableList.copyOf(cloudTasksHelper.getTestTasksFor(DNS_PUBLISH_PUSH_QUEUE_NAME));
// ReadDnsQueueAction batches items per TLD in batches of size 100. // ReadDnsQueueAction batches items per TLD in batches of size 100.
// So for 1500 items in the DNS queue, we expect 15 items in the push queue // So for 1500 items in the DNS queue, we expect 15 items in the push queue
assertThat(queuedParams).hasSize(15); assertThat(queuedTasks).hasSize(15);
// Check all the expected domains are indeed enqueued // Check all the expected domains are indeed enqueued
assertThat( assertThat(
queuedParams.stream() queuedTasks.stream()
.map(params -> params.get("domains").stream().collect(onlyElement())) .flatMap(
.flatMap(values -> Splitter.on(',').splitToList(values).stream())) task ->
UriParameters.parse(
task.getAppEngineHttpRequest()
.getBody()
.toString(StandardCharsets.UTF_8))
.get("domains")
.stream())
.flatMap(values -> Splitter.on(',').splitToStream(values)))
.containsExactlyElementsIn(domains); .containsExactlyElementsIn(domains);
} }
@ -233,7 +235,7 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTldsEnqueuedInPushQueue(ImmutableMultimap.of("com", "comWriter", "com", "otherWriter")); assertTldsEnqueuedInPushQueue(ImmutableMultimap.of("com", "comWriter", "com", "otherWriter"));
} }
@ -248,18 +250,18 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertThat(getQueuedParams(DNS_PUBLISH_PUSH_QUEUE_NAME)).hasSize(1); cloudTasksHelper.assertTasksEnqueued(
assertThat(getQueuedParams(DNS_PUBLISH_PUSH_QUEUE_NAME).get(0)) DNS_PUBLISH_PUSH_QUEUE_NAME,
.containsExactly( new TaskMatcher()
"enqueued", "3000-02-05T01:00:00.000Z", .param("enqueued", "3000-02-05T01:00:00.000Z")
"itemsCreated", "3000-02-03T00:00:00.000Z", .param("itemsCreated", "3000-02-03T00:00:00.000Z")
"tld", "com", .param("tld", "com")
"dnsWriter", "comWriter", .param("dnsWriter", "comWriter")
"domains", "domain1.com,domain2.com,domain3.com", .param("domains", "domain1.com,domain2.com,domain3.com")
"hosts", "", .param("hosts", "")
"lockIndex", "1", .param("lockIndex", "1")
"numPublishLocks", "1"); .param("numPublishLocks", "1"));
} }
@RetryingTest(4) @RetryingTest(4)
@ -271,7 +273,8 @@ public class ReadDnsQueueActionTest {
run(); run();
assertTasksEnqueued(DNS_PULL_QUEUE_NAME, createDomainRefreshTaskMatcher("domain.net")); TaskQueueHelper.assertTasksEnqueued(
DNS_PULL_QUEUE_NAME, createDomainRefreshTaskMatcher("domain.net"));
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter")); ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter"));
} }
@ -283,7 +286,7 @@ public class ReadDnsQueueActionTest {
QueueFactory.getQueue(DNS_PULL_QUEUE_NAME) QueueFactory.getQueue(DNS_PULL_QUEUE_NAME)
.add( .add(
TaskOptions.Builder.withDefaults() TaskOptions.Builder.withDefaults()
.method(Method.PULL) .method(TaskOptions.Method.PULL)
.param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString()) .param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString())
.param(DNS_TARGET_NAME_PARAM, "domain.unknown") .param(DNS_TARGET_NAME_PARAM, "domain.unknown")
.param(DNS_TARGET_CREATE_TIME_PARAM, "3000-01-01TZ") .param(DNS_TARGET_CREATE_TIME_PARAM, "3000-01-01TZ")
@ -291,7 +294,8 @@ public class ReadDnsQueueActionTest {
run(); run();
assertTasksEnqueued(DNS_PULL_QUEUE_NAME, createDomainRefreshTaskMatcher("domain.unknown")); TaskQueueHelper.assertTasksEnqueued(
DNS_PULL_QUEUE_NAME, createDomainRefreshTaskMatcher("domain.unknown"));
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter")); ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter"));
} }
@ -304,7 +308,7 @@ public class ReadDnsQueueActionTest {
QueueFactory.getQueue(DNS_PULL_QUEUE_NAME) QueueFactory.getQueue(DNS_PULL_QUEUE_NAME)
.add( .add(
TaskOptions.Builder.withDefaults() TaskOptions.Builder.withDefaults()
.method(Method.PULL) .method(TaskOptions.Method.PULL)
.param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString()) .param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString())
.param(DNS_TARGET_NAME_PARAM, "domain.wrongtld") .param(DNS_TARGET_NAME_PARAM, "domain.wrongtld")
.param(DNS_TARGET_CREATE_TIME_PARAM, "3000-01-01TZ") .param(DNS_TARGET_CREATE_TIME_PARAM, "3000-01-01TZ")
@ -312,7 +316,7 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter", "net", "netWriter")); ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter", "net", "netWriter"));
} }
@ -324,14 +328,14 @@ public class ReadDnsQueueActionTest {
QueueFactory.getQueue(DNS_PULL_QUEUE_NAME) QueueFactory.getQueue(DNS_PULL_QUEUE_NAME)
.add( .add(
TaskOptions.Builder.withDefaults() TaskOptions.Builder.withDefaults()
.method(Method.PULL) .method(TaskOptions.Method.PULL)
.param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString()) .param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString())
.param(DNS_TARGET_NAME_PARAM, "domain.net")); .param(DNS_TARGET_NAME_PARAM, "domain.net"));
run(); run();
// The corrupt task isn't in the pull queue, but also isn't in the push queue // The corrupt task isn't in the pull queue, but also isn't in the push queue
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter")); ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter"));
} }
@ -343,14 +347,14 @@ public class ReadDnsQueueActionTest {
QueueFactory.getQueue(DNS_PULL_QUEUE_NAME) QueueFactory.getQueue(DNS_PULL_QUEUE_NAME)
.add( .add(
TaskOptions.Builder.withDefaults() TaskOptions.Builder.withDefaults()
.method(Method.PULL) .method(TaskOptions.Method.PULL)
.param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString()) .param(DNS_TARGET_TYPE_PARAM, TargetType.DOMAIN.toString())
.param(PARAM_TLD, "net")); .param(PARAM_TLD, "net"));
run(); run();
// The corrupt task isn't in the pull queue, but also isn't in the push queue // The corrupt task isn't in the pull queue, but also isn't in the push queue
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter")); ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter"));
} }
@ -362,14 +366,14 @@ public class ReadDnsQueueActionTest {
QueueFactory.getQueue(DNS_PULL_QUEUE_NAME) QueueFactory.getQueue(DNS_PULL_QUEUE_NAME)
.add( .add(
TaskOptions.Builder.withDefaults() TaskOptions.Builder.withDefaults()
.method(Method.PULL) .method(TaskOptions.Method.PULL)
.param(DNS_TARGET_NAME_PARAM, "domain.net") .param(DNS_TARGET_NAME_PARAM, "domain.net")
.param(PARAM_TLD, "net")); .param(PARAM_TLD, "net"));
run(); run();
// The corrupt task isn't in the pull queue, but also isn't in the push queue // The corrupt task isn't in the pull queue, but also isn't in the push queue
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter")); ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter"));
} }
@ -381,7 +385,7 @@ public class ReadDnsQueueActionTest {
QueueFactory.getQueue(DNS_PULL_QUEUE_NAME) QueueFactory.getQueue(DNS_PULL_QUEUE_NAME)
.add( .add(
TaskOptions.Builder.withDefaults() TaskOptions.Builder.withDefaults()
.method(Method.PULL) .method(TaskOptions.Method.PULL)
.param(DNS_TARGET_TYPE_PARAM, "Wrong type") .param(DNS_TARGET_TYPE_PARAM, "Wrong type")
.param(DNS_TARGET_NAME_PARAM, "domain.net") .param(DNS_TARGET_NAME_PARAM, "domain.net")
.param(PARAM_TLD, "net")); .param(PARAM_TLD, "net"));
@ -389,7 +393,7 @@ public class ReadDnsQueueActionTest {
run(); run();
// The corrupt task isn't in the pull queue, but also isn't in the push queue // The corrupt task isn't in the pull queue, but also isn't in the push queue
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTldsEnqueuedInPushQueue( assertTldsEnqueuedInPushQueue(
ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter")); ImmutableMultimap.of("com", "comWriter", "example", "exampleWriter"));
} }
@ -402,8 +406,8 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTasksEnqueued( cloudTasksHelper.assertTasksEnqueued(
DNS_PUBLISH_PUSH_QUEUE_NAME, DNS_PUBLISH_PUSH_QUEUE_NAME,
new TaskMatcher().url(PublishDnsUpdatesAction.PATH).param("domains", "domain.net"), new TaskMatcher().url(PublishDnsUpdatesAction.PATH).param("domains", "domain.net"),
new TaskMatcher().url(PublishDnsUpdatesAction.PATH).param("hosts", "ns1.domain.com")); new TaskMatcher().url(PublishDnsUpdatesAction.PATH).param("hosts", "ns1.domain.com"));
@ -441,8 +445,8 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
assertTasksEnqueued( cloudTasksHelper.assertTasksEnqueued(
DNS_PUBLISH_PUSH_QUEUE_NAME, DNS_PUBLISH_PUSH_QUEUE_NAME,
new TaskMatcher() new TaskMatcher()
.url(PublishDnsUpdatesAction.PATH) .url(PublishDnsUpdatesAction.PATH)
@ -497,9 +501,9 @@ public class ReadDnsQueueActionTest {
run(); run();
assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME); TaskQueueHelper.assertNoTasksEnqueued(DNS_PULL_QUEUE_NAME);
// Expect two different groups; in-balliwick hosts are locked with their superordinate domains. // Expect two different groups; in-balliwick hosts are locked with their superordinate domains.
assertTasksEnqueued( cloudTasksHelper.assertTasksEnqueued(
DNS_PUBLISH_PUSH_QUEUE_NAME, DNS_PUBLISH_PUSH_QUEUE_NAME,
new TaskMatcher() new TaskMatcher()
.url(PublishDnsUpdatesAction.PATH) .url(PublishDnsUpdatesAction.PATH)

View file

@ -13,17 +13,12 @@
// limitations under the License. // limitations under the License.
package google.registry.persistence; package google.registry.persistence;
import static com.google.appengine.api.taskqueue.QueueFactory.getQueue;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat; import static com.google.common.truth.Truth8.assertThat;
import static google.registry.testing.DatabaseHelper.newDomainBase; import static google.registry.testing.DatabaseHelper.newDomainBase;
import static google.registry.testing.DatabaseHelper.persistActiveContact; import static google.registry.testing.DatabaseHelper.persistActiveContact;
import static google.registry.testing.TaskQueueHelper.assertTasksEnqueued;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import com.google.appengine.api.taskqueue.TaskOptions;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableMap;
import com.googlecode.objectify.Key; import com.googlecode.objectify.Key;
import com.googlecode.objectify.annotation.Entity; import com.googlecode.objectify.annotation.Entity;
import google.registry.model.billing.BillingEvent.OneTime; import google.registry.model.billing.BillingEvent.OneTime;
@ -32,10 +27,7 @@ import google.registry.model.domain.DomainBase;
import google.registry.model.host.HostResource; import google.registry.model.host.HostResource;
import google.registry.model.registrar.RegistrarContact; import google.registry.model.registrar.RegistrarContact;
import google.registry.testing.AppEngineExtension; import google.registry.testing.AppEngineExtension;
import google.registry.testing.TaskQueueHelper.TaskMatcher;
import google.registry.testing.TestObject; import google.registry.testing.TestObject;
import google.registry.util.Retrier;
import google.registry.util.TaskQueueUtils;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
@ -48,20 +40,8 @@ class VKeyTest {
AppEngineExtension.builder() AppEngineExtension.builder()
.withDatastoreAndCloudSql() .withDatastoreAndCloudSql()
.withOfyTestEntities(TestObject.class) .withOfyTestEntities(TestObject.class)
.withTaskQueue(
Joiner.on('\n')
.join(
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>",
"<queue-entries>",
" <queue>",
" <name>test-queue-for-vkey</name>",
" <rate>1/s</rate>",
" </queue>",
"</queue-entries>"))
.build(); .build();
private final TaskQueueUtils taskQueueUtils = new TaskQueueUtils(new Retrier(null, 1));
@BeforeAll @BeforeAll
static void beforeAll() { static void beforeAll() {
ClassPathManager.addTestEntityClass(TestObject.class); ClassPathManager.addTestEntityClass(TestObject.class);
@ -347,97 +327,6 @@ class VKeyTest {
assertThat(VKey.create(vkey.stringify())).isEqualTo(vkey); assertThat(VKey.create(vkey.stringify())).isEqualTo(vkey);
} }
/**
* Verifies a complete key can go into task queue and comes out unscathed.
*
* <p>TaskOption objects are being used here instead of Task objects, despite that we are in the
* process of migrating to using Cloud Tasks API, the stringify() and create() were written with
* the intention to handle all types of vkeys, inlcuding ofy only vkeys. The purpose of the
* following test cases is to make sure we don't deploy the system with parameters that don't work
* in the current implementation. Once migration is done, the following test cases with TaskOption
* or TaskHandle will go away.
*/
@Test
void testStringifyThenCreate_ofyOnlyVKeyIntaskQueue_success() throws Exception {
VKey<TestObject> vkey =
VKey.createOfy(TestObject.class, Key.create(TestObject.class, "tmpKey"));
String vkeyStringFromQueue =
ImmutableMap.copyOf(
taskQueueUtils
.enqueue(
getQueue("test-queue-for-vkey"),
TaskOptions.Builder.withUrl("/the/path").param("vkey", vkey.stringify()))
.extractParams())
.get("vkey");
assertTasksEnqueued(
"test-queue-for-vkey", new TaskMatcher().url("/the/path").param("vkey", vkey.stringify()));
assertThat(vkeyStringFromQueue).isEqualTo(vkey.stringify());
assertThat(VKey.create(vkeyStringFromQueue)).isEqualTo(vkey);
}
@Test
void testStringifyThenCreate_sqlOnlyVKeyIntaskQueue_success() throws Exception {
VKey<TestObject> vkey = VKey.createSql(TestObject.class, "sqlKey");
String vkeyStringFromQueue =
ImmutableMap.copyOf(
taskQueueUtils
.enqueue(
getQueue("test-queue-for-vkey"),
TaskOptions.Builder.withUrl("/the/path").param("vkey", vkey.stringify()))
.extractParams())
.get("vkey");
assertTasksEnqueued(
"test-queue-for-vkey", new TaskMatcher().url("/the/path").param("vkey", vkey.stringify()));
assertThat(vkeyStringFromQueue).isEqualTo(vkey.stringify());
assertThat(VKey.create(vkeyStringFromQueue)).isEqualTo(vkey);
}
@Test
void testStringifyThenCreate_generalVKeyIntaskQueue_success() throws Exception {
VKey<TestObject> vkey =
VKey.create(TestObject.class, "12345", Key.create(TestObject.class, "12345"));
String vkeyStringFromQueue =
ImmutableMap.copyOf(
taskQueueUtils
.enqueue(
getQueue("test-queue-for-vkey"),
TaskOptions.Builder.withUrl("/the/path").param("vkey", vkey.stringify()))
.extractParams())
.get("vkey");
assertTasksEnqueued(
"test-queue-for-vkey", new TaskMatcher().url("/the/path").param("vkey", vkey.stringify()));
assertThat(vkeyStringFromQueue).isEqualTo(vkey.stringify());
assertThat(VKey.create(vkeyStringFromQueue)).isEqualTo(vkey);
}
@Test
void testStringifyThenCreate_vkeyFromWebsafeStringIntaskQueue_success() throws Exception {
VKey<DomainBase> vkey =
VKey.fromWebsafeKey(
Key.create(newDomainBase("example.com", "ROID-1", persistActiveContact("contact-1")))
.getString());
String vkeyStringFromQueue =
ImmutableMap.copyOf(
taskQueueUtils
.enqueue(
getQueue("test-queue-for-vkey"),
TaskOptions.Builder.withUrl("/the/path").param("vkey", vkey.stringify()))
.extractParams())
.get("vkey");
assertTasksEnqueued(
"test-queue-for-vkey", new TaskMatcher().url("/the/path").param("vkey", vkey.stringify()));
assertThat(vkeyStringFromQueue).isEqualTo(vkey.stringify());
assertThat(VKey.create(vkeyStringFromQueue)).isEqualTo(vkey);
}
@Test @Test
void testToString_sqlOnlyVKey() { void testToString_sqlOnlyVKey() {
assertThat(VKey.createSql(TestObject.class, "testId").toString()) assertThat(VKey.createSql(TestObject.class, "testId").toString())

View file

@ -41,6 +41,8 @@ import com.google.common.net.MediaType;
import com.google.common.truth.Truth8; import com.google.common.truth.Truth8;
import com.google.protobuf.Timestamp; import com.google.protobuf.Timestamp;
import com.google.protobuf.util.Timestamps; import com.google.protobuf.util.Timestamps;
import dagger.Module;
import dagger.Provides;
import google.registry.model.ImmutableObject; import google.registry.model.ImmutableObject;
import google.registry.util.CloudTasksUtils; import google.registry.util.CloudTasksUtils;
import google.registry.util.Retrier; import google.registry.util.Retrier;
@ -61,6 +63,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.inject.Singleton;
import org.joda.time.DateTime; import org.joda.time.DateTime;
/** /**
@ -181,6 +184,28 @@ public class CloudTasksHelper implements Serializable {
} }
} }
@Module
public static class CloudTasksHelperModule {
private final FakeClock clock;
public CloudTasksHelperModule(FakeClock clock) {
this.clock = clock;
}
@Singleton
@Provides
CloudTasksUtils provideCloudTasksUtils(CloudTasksHelper cloudTasksHelper) {
return cloudTasksHelper.getTestCloudTasksUtils();
}
@Singleton
@Provides
CloudTasksHelper provideCloudTasksHelper() {
return new CloudTasksHelper(clock);
}
}
private class FakeCloudTasksClient extends CloudTasksUtils.SerializableCloudTasksClient { private class FakeCloudTasksClient extends CloudTasksUtils.SerializableCloudTasksClient {
private static final long serialVersionUID = 6661964844791720639L; private static final long serialVersionUID = 6661964844791720639L;

View file

@ -26,7 +26,13 @@ import java.io.Serializable;
import java.util.List; import java.util.List;
import javax.inject.Inject; import javax.inject.Inject;
/** Utilities for dealing with App Engine task queues. */ /**
* Utilities for dealing with App Engine task queues.
*
* <p>Use {@link CloudTasksUtils} to interact with push queues (Cloud Task queues). Pull queues will
* be implemented separately in SQL and you can continue using this class for that for now.
*/
@Deprecated
public class TaskQueueUtils implements Serializable { public class TaskQueueUtils implements Serializable {
private static final long serialVersionUID = 7893211200220508362L; private static final long serialVersionUID = 7893211200220508362L;