Implement a util class to manage push queues using Cloud Tasks API (#1290)

* Implement a util class to manage push queues using Cloud Tasks API

Push queues were part of App Engine when they debuted. As a result the
Task Queue API were part of the App Engine SDK and can only be used in
App Engine classic runtime. The new Cloud Tasks API can be used in any
runtime but it only supports push queues. In this PR we implement a util
class (CloudTasksUtils) like TaskQueueUtils to handle enqueuing tasks to
push queues using Cloud Tasks. One action (TldFanoutAction) was
converted to use the new API as a demo. Mass migration of other call sites of
the old API will follow in a separate PR.

TESTED=deployed to alpha and verified that tasks are corrected enqueued
and executed.
This commit is contained in:
Lai Jiang 2021-08-24 21:13:54 -04:00 committed by GitHub
parent d3b07f6ab0
commit bc62e13e41
113 changed files with 3822 additions and 1994 deletions

View file

@ -0,0 +1,65 @@
// Copyright 2021 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.config;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.cloud.tasks.v2.CloudTasksClient;
import com.google.cloud.tasks.v2.CloudTasksSettings;
import dagger.Module;
import dagger.Provides;
import google.registry.config.CredentialModule.DefaultCredential;
import google.registry.config.RegistryConfig.Config;
import google.registry.util.CloudTasksUtils;
import google.registry.util.GoogleCredentialsBundle;
import google.registry.util.Retrier;
import java.io.IOException;
import javax.inject.Provider;
import javax.inject.Singleton;
/**
* A {@link Module} that provides {@link CloudTasksUtils}.
*
* <p>The class itself cannot be annotated as {@code Inject} because its requested dependencies use
* the {@link Config} qualifier which is not available in the {@code util} package.
*/
@Module
public abstract class CloudTasksUtilsModule {
@Singleton
@Provides
public static CloudTasksUtils provideCloudTasksUtils(
@Config("projectId") String projectId,
@Config("locationId") String locationId,
// Use a provider so that we can use try-with-resources with the client, which implements
// Autocloseable.
Provider<CloudTasksClient> clientProvider,
Retrier retrier) {
return new CloudTasksUtils(retrier, projectId, locationId, clientProvider);
}
@Provides
public static CloudTasksClient provideCloudTasksClient(
@DefaultCredential GoogleCredentialsBundle credentials) {
try {
return CloudTasksClient.create(
CloudTasksSettings.newBuilder()
.setCredentialsProvider(
FixedCredentialsProvider.create(credentials.getGoogleCredentials()))
.build());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

View file

@ -105,6 +105,12 @@ public final class RegistryConfig {
return config.appEngine.projectId;
}
@Provides
@Config("locationId")
public static String provideLocationId(RegistryConfigSettings config) {
return config.appEngine.locationId;
}
/**
* The filename of the logo to be displayed in the header of the registrar console.
*

View file

@ -45,6 +45,7 @@ public class RegistryConfigSettings {
/** Configuration options that apply to the entire App Engine project. */
public static class AppEngine {
public String projectId;
public String locationId;
public boolean isLocal;
public String defaultServiceUrl;
public String backendServiceUrl;

View file

@ -8,6 +8,10 @@
appEngine:
# Globally unique App Engine project ID
projectId: registry-project-id
# Location of the App engine project, note that us-central1 and europe-west1 are special in that
# they are used without the trailing number in App Engine commands and Google Cloud Console.
# See: https://cloud.google.com/appengine/docs/locations
locationId: registry-location-id
# whether to use local/test credentials when connecting to the servers
isLocal: true

View file

@ -14,14 +14,10 @@
package google.registry.cron;
import static com.google.appengine.api.taskqueue.QueueFactory.getQueue;
import static com.google.appengine.api.taskqueue.TaskOptions.Builder.withUrl;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.in;
import static com.google.common.base.Predicates.not;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getFirst;
import static com.google.common.collect.Multimaps.filterKeys;
import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8;
import static google.registry.cron.CronModule.ENDPOINT_PARAM;
@ -36,21 +32,24 @@ import static google.registry.model.tld.Registry.TldType.REAL;
import static google.registry.model.tld.Registry.TldType.TEST;
import static java.util.concurrent.TimeUnit.SECONDS;
import com.google.appengine.api.taskqueue.Queue;
import com.google.appengine.api.taskqueue.TaskHandle;
import com.google.appengine.api.taskqueue.TaskOptions;
import com.google.cloud.tasks.v2.Task;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Streams;
import com.google.common.flogger.FluentLogger;
import com.google.protobuf.Timestamp;
import google.registry.request.Action;
import google.registry.request.Action.Service;
import google.registry.request.Parameter;
import google.registry.request.ParameterMap;
import google.registry.request.RequestParameters;
import google.registry.request.Response;
import google.registry.request.auth.Auth;
import google.registry.util.TaskQueueUtils;
import google.registry.util.Clock;
import google.registry.util.CloudTasksUtils;
import java.time.Instant;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Stream;
@ -105,7 +104,8 @@ public final class TldFanoutAction implements Runnable {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
@Inject TaskQueueUtils taskQueueUtils;
@Inject Clock clock;
@Inject CloudTasksUtils cloudTasksUtils;
@Inject Response response;
@Inject @Parameter(ENDPOINT_PARAM) String endpoint;
@Inject @Parameter(QUEUE_PARAM) String queue;
@ -115,7 +115,9 @@ public final class TldFanoutAction implements Runnable {
@Inject @Parameter(EXCLUDE_PARAM) ImmutableSet<String> excludes;
@Inject @Parameter(JITTER_SECONDS_PARAM) Optional<Integer> jitterSeconds;
@Inject @ParameterMap ImmutableListMultimap<String, String> params;
@Inject TldFanoutAction() {}
@Inject
TldFanoutAction() {}
@Override
public void run() {
@ -126,8 +128,7 @@ public final class TldFanoutAction implements Runnable {
runInEmpty || forEachTestTld || forEachRealTld,
"At least one of runInEmpty, forEachTestTld, forEachRealTld must be given");
checkArgument(
!(runInEmpty && !excludes.isEmpty()),
"Can't specify 'exclude' with 'runInEmpty'");
!(runInEmpty && !excludes.isEmpty()), "Can't specify 'exclude' with 'runInEmpty'");
ImmutableSet<String> tlds =
runInEmpty
? ImmutableSet.of("")
@ -137,7 +138,6 @@ public final class TldFanoutAction implements Runnable {
.filter(not(in(excludes)))
.collect(toImmutableSet());
Multimap<String, String> flowThruParams = filterKeys(params, not(in(CONTROL_PARAMS)));
Queue taskQueue = getQueue(queue);
StringBuilder outputPayload =
new StringBuilder(
String.format("OK: Launched the following %d tasks in queue %s\n", tlds.size(), queue));
@ -146,33 +146,41 @@ public final class TldFanoutAction implements Runnable {
logger.atWarning().log("No TLDs to fan-out!");
}
for (String tld : tlds) {
TaskOptions taskOptions = createTaskOptions(tld, flowThruParams);
TaskHandle taskHandle = taskQueueUtils.enqueue(taskQueue, taskOptions);
Task task = createTask(tld, flowThruParams);
Task createdTask = cloudTasksUtils.enqueue(queue, task);
outputPayload.append(
String.format(
"- Task: '%s', tld: '%s', endpoint: '%s'\n",
taskHandle.getName(), tld, taskOptions.getUrl()));
createdTask.getName(), tld, createdTask.getAppEngineHttpRequest().getRelativeUri()));
logger.atInfo().log(
"Task: '%s', tld: '%s', endpoint: '%s'", taskHandle.getName(), tld, taskOptions.getUrl());
"Task: '%s', tld: '%s', endpoint: '%s'",
createdTask.getName(), tld, createdTask.getAppEngineHttpRequest().getRelativeUri());
}
response.setContentType(PLAIN_TEXT_UTF_8);
response.setPayload(outputPayload.toString());
}
private TaskOptions createTaskOptions(String tld, Multimap<String, String> params) {
TaskOptions options =
withUrl(endpoint)
.countdownMillis(
jitterSeconds
.map(seconds -> random.nextInt((int) SECONDS.toMillis(seconds)))
.orElse(0));
private Task createTask(String tld, Multimap<String, String> params) {
if (!tld.isEmpty()) {
options.param(RequestParameters.PARAM_TLD, tld);
params = ArrayListMultimap.create(params);
params.put(RequestParameters.PARAM_TLD, tld);
}
for (String param : params.keySet()) {
// TaskOptions.param() does not accept null values.
options.param(param, nullToEmpty(getFirst(params.get(param), null)));
}
return options;
Instant scheduleTime =
Instant.ofEpochMilli(
clock
.nowUtc()
.plusMillis(
jitterSeconds
.map(seconds -> random.nextInt((int) SECONDS.toMillis(seconds)))
.orElse(0))
.getMillis());
return Task.newBuilder(
CloudTasksUtils.createPostTask(endpoint, Service.BACKEND.toString(), params))
.setScheduleTime(
Timestamp.newBuilder()
.setSeconds(scheduleTime.getEpochSecond())
.setNanos(scheduleTime.getNano())
.build())
.build();
}
}

View file

@ -18,6 +18,7 @@ import com.google.monitoring.metrics.MetricReporter;
import dagger.Component;
import dagger.Lazy;
import google.registry.bigquery.BigqueryModule;
import google.registry.config.CloudTasksUtilsModule;
import google.registry.config.CredentialModule;
import google.registry.config.RegistryConfig.ConfigModule;
import google.registry.dns.writer.VoidDnsWriterModule;
@ -56,6 +57,7 @@ import javax.inject.Singleton;
BackendRequestComponentModule.class,
BigqueryModule.class,
ConfigModule.class,
CloudTasksUtilsModule.class,
CredentialModule.class,
CustomLogicFactoryModule.class,
DatastoreAdminModule.class,

View file

@ -15,27 +15,23 @@
package google.registry.cron;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getLast;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.testing.DatabaseHelper.createTlds;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.testing.TaskQueueHelper.assertNoTasksEnqueued;
import static google.registry.testing.TaskQueueHelper.assertTasksEnqueued;
import static org.junit.jupiter.api.Assertions.assertThrows;
import com.google.appengine.api.taskqueue.dev.QueueStateInfo.TaskStateInfo;
import com.google.appengine.tools.development.testing.LocalTaskQueueTestConfig;
import com.google.common.base.Joiner;
import com.google.cloud.tasks.v2.HttpMethod;
import com.google.cloud.tasks.v2.Task;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import google.registry.model.tld.Registry;
import google.registry.model.tld.Registry.TldType;
import google.registry.testing.AppEngineExtension;
import google.registry.testing.CloudTasksHelper;
import google.registry.testing.CloudTasksHelper.TaskMatcher;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeResponse;
import google.registry.testing.TaskQueueHelper.TaskMatcher;
import google.registry.util.Retrier;
import google.registry.util.TaskQueueUtils;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
@ -49,27 +45,14 @@ class TldFanoutActionTest {
private static final String ENDPOINT = "/the/servlet";
private static final String QUEUE = "the-queue";
private final FakeResponse response = new FakeResponse();
private final CloudTasksHelper cloudTasksHelper = new CloudTasksHelper();
@RegisterExtension
final AppEngineExtension appEngine =
AppEngineExtension.builder()
.withDatastoreAndCloudSql()
.withTaskQueue(
Joiner.on('\n')
.join(
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>",
"<queue-entries>",
" <queue>",
" <name>the-queue</name>",
" <rate>1/s</rate>",
" </queue>",
"</queue-entries>"))
.build();
AppEngineExtension.builder().withDatastoreAndCloudSql().build();
private static ImmutableListMultimap<String, String> getParamsMap(String... keysAndValues) {
ImmutableListMultimap.Builder<String, String> params = new ImmutableListMultimap.Builder<>();
params.put("queue", QUEUE);
params.put("endpoint", ENDPOINT);
for (int i = 0; i < keysAndValues.length; i += 2) {
params.put(keysAndValues[i], keysAndValues[i + 1]);
}
@ -78,13 +61,15 @@ class TldFanoutActionTest {
private void run(ImmutableListMultimap<String, String> params) {
TldFanoutAction action = new TldFanoutAction();
action.clock = new FakeClock();
action.params = params;
action.endpoint = getLast(params.get("endpoint"));
action.queue = getLast(params.get("queue"));
action.excludes = params.containsKey("exclude")
? ImmutableSet.copyOf(Splitter.on(',').split(params.get("exclude").get(0)))
: ImmutableSet.of();
action.taskQueueUtils = new TaskQueueUtils(new Retrier(null, 1));
action.endpoint = ENDPOINT;
action.queue = QUEUE;
action.excludes =
params.containsKey("exclude")
? ImmutableSet.copyOf(Splitter.on(',').split(params.get("exclude").get(0)))
: ImmutableSet.of();
action.cloudTasksUtils = cloudTasksHelper.getTestCloudTasksUtils();
action.response = response;
action.runInEmpty = params.containsKey("runInEmpty");
action.forEachRealTld = params.containsKey("forEachRealTld");
@ -99,20 +84,21 @@ class TldFanoutActionTest {
persistResource(Registry.get("example").asBuilder().setTldType(TldType.TEST).build());
}
private static void assertTasks(String... tasks) {
assertTasksEnqueued(
private void assertTasks(String... tasks) {
cloudTasksHelper.assertTasksEnqueued(
QUEUE,
Stream.of(tasks).map(
namespace ->
new TaskMatcher()
.url(ENDPOINT)
.header("content-type", "application/x-www-form-urlencoded")
.param("tld", namespace))
.collect(toImmutableList()));
Stream.of(tasks)
.map(
namespace ->
new TaskMatcher()
.url(ENDPOINT)
.header("content-type", "application/x-www-form-urlencoded")
.param("tld", namespace))
.collect(toImmutableList()));
}
private static void assertTaskWithoutTld() {
assertTasksEnqueued(
private void assertTaskWithoutTld() {
cloudTasksHelper.assertTasksEnqueued(
QUEUE,
new TaskMatcher()
.url(ENDPOINT)
@ -120,9 +106,10 @@ class TldFanoutActionTest {
}
@Test
void testSuccess_methodPostIsDefault() {
void testSuccess_methodPostAndServiceBackendAreDefault() {
run(getParamsMap("runInEmpty", ""));
assertTasksEnqueued(QUEUE, new TaskMatcher().method("POST"));
cloudTasksHelper.assertTasksEnqueued(
QUEUE, new TaskMatcher().method(HttpMethod.POST).service("backend"));
}
@Test
@ -150,9 +137,10 @@ class TldFanoutActionTest {
@Test
void testSuccess_forEachTestTldAndForEachRealTld() {
run(getParamsMap(
"forEachTestTld", "",
"forEachRealTld", ""));
run(
getParamsMap(
"forEachTestTld", "",
"forEachRealTld", ""));
assertTasks("com", "net", "org", "example");
}
@ -164,26 +152,29 @@ class TldFanoutActionTest {
@Test
void testSuccess_excludeRealTlds() {
run(getParamsMap(
"forEachRealTld", "",
"exclude", "com,net"));
run(
getParamsMap(
"forEachRealTld", "",
"exclude", "com,net"));
assertTasks("org");
}
@Test
void testSuccess_excludeTestTlds() {
run(getParamsMap(
"forEachTestTld", "",
"exclude", "example"));
assertNoTasksEnqueued(QUEUE);
run(
getParamsMap(
"forEachTestTld", "",
"exclude", "example"));
cloudTasksHelper.assertNoTasksEnqueued(QUEUE);
}
@Test
void testSuccess_excludeNonexistentTlds() {
run(getParamsMap(
"forEachTestTld", "",
"forEachRealTld", "",
"exclude", "foo"));
run(
getParamsMap(
"forEachTestTld", "",
"forEachRealTld", "",
"exclude", "foo"));
assertTasks("com", "net", "org", "example");
}
@ -223,26 +214,24 @@ class TldFanoutActionTest {
@Test
void testSuccess_additionalArgsFlowThroughToPostParams() {
run(getParamsMap("forEachTestTld", "", "newkey", "newval"));
assertTasksEnqueued(QUEUE,
new TaskMatcher().url("/the/servlet").param("newkey", "newval"));
cloudTasksHelper.assertTasksEnqueued(
QUEUE, new TaskMatcher().url("/the/servlet").param("newkey", "newval"));
}
@Test
void testSuccess_returnHttpResponse() {
run(getParamsMap("forEachRealTld", "", "endpoint", "/the/servlet"));
List<TaskStateInfo> taskList =
LocalTaskQueueTestConfig.getLocalTaskQueue().getQueueStateInfo().get(QUEUE).getTaskInfo();
List<Task> taskList = cloudTasksHelper.getTestTasksFor(QUEUE);
assertThat(taskList).hasSize(3);
String expectedResponse = String.format(
"OK: Launched the following 3 tasks in queue the-queue\n"
+ "- Task: '%s', tld: 'com', endpoint: '/the/servlet'\n"
+ "- Task: '%s', tld: 'net', endpoint: '/the/servlet'\n"
+ "- Task: '%s', tld: 'org', endpoint: '/the/servlet'\n",
taskList.get(0).getTaskName(),
taskList.get(1).getTaskName(),
taskList.get(2).getTaskName());
String expectedResponse =
String.format(
"OK: Launched the following 3 tasks in queue the-queue\n"
+ "- Task: '%s', tld: 'com', endpoint: '/the/servlet'\n"
+ "- Task: '%s', tld: 'net', endpoint: '/the/servlet'\n"
+ "- Task: '%s', tld: 'org', endpoint: '/the/servlet'\n",
taskList.get(0).getName(), taskList.get(1).getName(), taskList.get(2).getName());
assertThat(response.getPayload()).isEqualTo(expectedResponse);
}
@ -250,14 +239,14 @@ class TldFanoutActionTest {
void testSuccess_returnHttpResponse_runInEmpty() {
run(getParamsMap("runInEmpty", "", "endpoint", "/the/servlet"));
List<TaskStateInfo> taskList =
LocalTaskQueueTestConfig.getLocalTaskQueue().getQueueStateInfo().get(QUEUE).getTaskInfo();
List<Task> taskList = cloudTasksHelper.getTestTasksFor(QUEUE);
assertThat(taskList).hasSize(1);
String expectedResponse = String.format(
"OK: Launched the following 1 tasks in queue the-queue\n"
+ "- Task: '%s', tld: '', endpoint: '/the/servlet'\n",
taskList.get(0).getTaskName());
String expectedResponse =
String.format(
"OK: Launched the following 1 tasks in queue the-queue\n"
+ "- Task: '%s', tld: '', endpoint: '/the/servlet'\n",
taskList.get(0).getName());
assertThat(response.getPayload()).isEqualTo(expectedResponse);
}
}

View file

@ -0,0 +1,315 @@
// Copyright 2021 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.testing;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Predicates.in;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.Multisets.containsOccurrences;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static google.registry.util.DiffUtils.prettyPrintEntityDeepDiff;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.joining;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.cloud.tasks.v2.CloudTasksClient;
import com.google.cloud.tasks.v2.HttpMethod;
import com.google.cloud.tasks.v2.QueueName;
import com.google.cloud.tasks.v2.Task;
import com.google.common.base.Ascii;
import com.google.common.base.Joiner;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.net.HttpHeaders;
import com.google.common.net.MediaType;
import com.google.common.truth.Truth8;
import google.registry.model.ImmutableObject;
import google.registry.util.CloudTasksUtils;
import google.registry.util.Retrier;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;
import javax.annotation.Nonnull;
/**
* Static utility functions for testing task queues.
*
* <p>This class is mostly derived from {@link TaskQueueHelper}. It does not implement as many
* helper methods because we have not yet encountered all the use cases with Cloud Tasks. As more
* and more Task Queue API usage is migrated to Cloud Tasks we may replicate more methods from the
* latter.
*/
public class CloudTasksHelper {
private static final String PROJECT_ID = "test-project";
private static final String LOCATION_ID = "test-location";
private final Retrier retrier = new Retrier(new FakeSleeper(new FakeClock()), 1);
private final LinkedListMultimap<String, Task> testTasks = LinkedListMultimap.create();
private final CloudTasksClient mockClient = mock(CloudTasksClient.class);
private final CloudTasksUtils cloudTasksUtils =
new CloudTasksUtils(retrier, PROJECT_ID, LOCATION_ID, () -> mockClient);
public CloudTasksHelper() {
when(mockClient.createTask(any(QueueName.class), any(Task.class)))
.thenAnswer(
invocation -> {
QueueName queue = invocation.getArgument(0);
Task task = invocation.getArgument(1);
if (task.getName().isEmpty()) {
task = task.toBuilder().setName(String.format("test-%d", testTasks.size())).build();
}
testTasks.put(queue.getQueue(), task);
return task;
});
}
public CloudTasksUtils getTestCloudTasksUtils() {
return cloudTasksUtils;
}
public List<Task> getTestTasksFor(String queue) {
return testTasks.get(queue);
}
/**
* Ensures that the tasks in the named queue are exactly those with the expected property values
* after being transformed with the provided property getter function.
*/
public void assertTasksEnqueuedWithProperty(
String queueName, Function<Task, String> propertyGetter, String... expectedTaskProperties) {
// Ordering is irrelevant but duplicates should be considered independently.
Truth8.assertThat(getTestTasksFor(queueName).stream().map(propertyGetter))
.containsExactly((Object[]) expectedTaskProperties);
}
/** Ensures that the tasks in the named queue are exactly those with the expected names. */
public void assertTasksEnqueued(String queueName, String... expectedTaskNames) {
Function<Task, String> nameGetter = Task::getName;
assertTasksEnqueuedWithProperty(queueName, nameGetter, expectedTaskNames);
}
/**
* Ensures that the only tasks in the named queue are exactly those that match the expected
* matchers.
*/
public void assertTasksEnqueued(String queueName, TaskMatcher... taskMatchers) {
assertTasksEnqueued(queueName, asList(taskMatchers));
}
/** Ensures that the named queue contains no tasks. */
public void assertNoTasksEnqueued(String... queueNames) {
for (String queueName : queueNames) {
assertThat(getTestTasksFor(queueName)).isEmpty();
}
}
/**
* Ensures that the only tasks in the named queue are exactly those that match the expected
* matchers.
*/
public void assertTasksEnqueued(String queueName, Collection<TaskMatcher> taskMatchers) {
List<Task> tasks = getTestTasksFor(queueName);
assertThat(tasks.size()).isEqualTo(taskMatchers.size());
for (final TaskMatcher taskMatcher : taskMatchers) {
try {
tasks.remove(tasks.stream().filter(taskMatcher).findFirst().get());
} catch (NoSuchElementException e) {
final Map<String, Object> taskMatcherMap = taskMatcher.expected.toMap();
assertWithMessage(
"Task not found in queue %s:\n\n%s\n\nPotential candidate match diffs:\n\n%s",
queueName,
taskMatcher,
tasks.stream()
.map(
input ->
prettyPrintEntityDeepDiff(
taskMatcherMap,
Maps.filterKeys(
new MatchableTask(input).toMap(), in(taskMatcherMap.keySet()))))
.collect(joining("\n")))
.fail();
}
}
}
/** An adapter to clean up a {@link Task} for ease of matching. */
private static class MatchableTask extends ImmutableObject {
String taskName;
String service;
HttpMethod method;
String url;
Multimap<String, String> headers = ArrayListMultimap.create();
Multimap<String, String> params = ArrayListMultimap.create();
MatchableTask() {}
MatchableTask(Task task) {
URI uri;
try {
// Construct a fake full URI for parsing purpose. The relative URI must start with a slash.
uri =
new URI(
String.format(
"https://nomulus.foo%s", task.getAppEngineHttpRequest().getRelativeUri()));
} catch (java.net.URISyntaxException e) {
throw new IllegalArgumentException(e);
}
taskName = task.getName();
service =
Ascii.toLowerCase(task.getAppEngineHttpRequest().getAppEngineRouting().getService());
method = task.getAppEngineHttpRequest().getHttpMethod();
url = uri.getPath();
ImmutableMultimap.Builder<String, String> headerBuilder = new ImmutableMultimap.Builder<>();
task.getAppEngineHttpRequest()
.getHeadersMap()
.forEach(
(key, value) -> {
// Lowercase header name for comparison since HTTP header names are
// case-insensitive.
headerBuilder.put(Ascii.toLowerCase(key), value);
});
headers = headerBuilder.build();
ImmutableMultimap.Builder<String, String> paramBuilder = new ImmutableMultimap.Builder<>();
String query = null;
if (method == HttpMethod.GET) {
query = uri.getQuery();
} else if (method == HttpMethod.POST) {
assertThat(
headers.containsEntry(
Ascii.toLowerCase(HttpHeaders.CONTENT_TYPE), MediaType.FORM_DATA.toString()))
.isTrue();
query = task.getAppEngineHttpRequest().getBody().toString(StandardCharsets.UTF_8);
}
if (query != null) {
// Note that UriParameters.parse() does not throw an IAE on a bad query string (e.g. one
// where parameters are not properly URL-encoded); it always does a best-effort parse.
paramBuilder.putAll(UriParameters.parse(query));
params = paramBuilder.build();
}
}
public Map<String, Object> toMap() {
Map<String, Object> builder = new HashMap<>();
builder.put("taskName", taskName);
builder.put("method", method);
builder.put("service", service);
builder.put("url", url);
builder.put("headers", headers);
builder.put("params", params);
return Maps.filterValues(builder, not(in(asList(null, "", Collections.EMPTY_MAP))));
}
}
/**
* Matcher to match against the tasks in the task queue. Fields that aren't set are not compared.
*/
public static class TaskMatcher implements Predicate<Task> {
private final MatchableTask expected;
public TaskMatcher() {
expected = new MatchableTask();
}
public TaskMatcher taskName(String taskName) {
expected.taskName = taskName;
return this;
}
public TaskMatcher url(String url) {
expected.url = url;
return this;
}
public TaskMatcher service(String service) {
// Lowercase for case-insensitive comparison.
expected.service = Ascii.toLowerCase(service);
return this;
}
public TaskMatcher method(HttpMethod method) {
expected.method = method;
return this;
}
public TaskMatcher header(String name, String value) {
// Lowercase for case-insensitive comparison.
expected.headers.put(Ascii.toLowerCase(name), value);
return this;
}
public TaskMatcher param(String key, String value) {
checkNotNull(value, "Test error: A param can never have a null value, so don't assert it");
expected.params.put(key, value);
return this;
}
/**
* Returns {@code true} if there are not more occurrences in {@code sub} of each of its entries
* than there are in {@code super}.
*/
private static boolean containsEntries(
Multimap<?, ?> superMultimap, Multimap<?, ?> subMultimap) {
return containsOccurrences(
ImmutableMultiset.copyOf(superMultimap.entries()),
ImmutableMultiset.copyOf(subMultimap.entries()));
}
/**
* Returns true if the fields set on the current object match the given task. This is not quite
* the same contract as {@link #equals}, since it will ignore null fields.
*
* <p>Match fails if any headers or params expected on the TaskMatcher are not found on the
* Task. Note that the inverse is not true (i.e. there may be extra headers on the Task).
*/
@Override
public boolean test(@Nonnull Task task) {
MatchableTask actual = new MatchableTask(task);
return (expected.taskName == null || Objects.equals(expected.taskName, actual.taskName))
&& (expected.url == null || Objects.equals(expected.url, actual.url))
&& (expected.method == null || Objects.equals(expected.method, actual.method))
&& (expected.service == null || Objects.equals(expected.service, actual.service))
&& containsEntries(actual.params, expected.params)
&& containsEntries(actual.headers, expected.headers);
}
@Override
public String toString() {
return Joiner.on('\n')
.withKeyValueSeparator(":\n")
.join(
Maps.transformValues(
expected.toMap(),
input -> "\t" + String.valueOf(input).replaceAll("\n", "\n\t")));
}
}
}