diff --git a/core/src/main/java/google/registry/cron/CommitLogFanoutAction.java b/core/src/main/java/google/registry/cron/CommitLogFanoutAction.java index a57d78000..1e5c26fa6 100644 --- a/core/src/main/java/google/registry/cron/CommitLogFanoutAction.java +++ b/core/src/main/java/google/registry/cron/CommitLogFanoutAction.java @@ -14,18 +14,15 @@ package google.registry.cron; -import static com.google.appengine.api.taskqueue.QueueFactory.getQueue; - -import com.google.appengine.api.taskqueue.Queue; -import com.google.appengine.api.taskqueue.TaskOptions; +import com.google.common.collect.ImmutableMultimap; import google.registry.model.ofy.CommitLogBucket; import google.registry.request.Action; +import google.registry.request.Action.Service; import google.registry.request.Parameter; import google.registry.request.auth.Auth; -import google.registry.util.TaskQueueUtils; -import java.time.Duration; +import google.registry.util.Clock; +import google.registry.util.CloudTasksUtils; import java.util.Optional; -import java.util.Random; import javax.inject.Inject; /** Action for fanning out cron tasks for each commit log bucket. */ @@ -38,25 +35,27 @@ public final class CommitLogFanoutAction implements Runnable { public static final String BUCKET_PARAM = "bucket"; - private static final Random random = new Random(); + @Inject Clock clock; + @Inject CloudTasksUtils cloudTasksUtils; - @Inject TaskQueueUtils taskQueueUtils; @Inject @Parameter("endpoint") String endpoint; @Inject @Parameter("queue") String queue; @Inject @Parameter("jitterSeconds") Optional jitterSeconds; @Inject CommitLogFanoutAction() {} + + @Override public void run() { - Queue taskQueue = getQueue(queue); for (int bucketId : CommitLogBucket.getBucketIds()) { - long delay = - jitterSeconds.map(i -> random.nextInt((int) Duration.ofSeconds(i).toMillis())).orElse(0); - TaskOptions taskOptions = - TaskOptions.Builder.withUrl(endpoint) - .param(BUCKET_PARAM, Integer.toString(bucketId)) - .countdownMillis(delay); - taskQueueUtils.enqueue(taskQueue, taskOptions); + cloudTasksUtils.enqueue( + queue, + CloudTasksUtils.createPostTask( + endpoint, + Service.BACKEND.toString(), + ImmutableMultimap.of(BUCKET_PARAM, Integer.toString(bucketId)), + clock, + jitterSeconds)); } } } diff --git a/core/src/test/java/google/registry/cron/CommitLogFanoutActionTest.java b/core/src/test/java/google/registry/cron/CommitLogFanoutActionTest.java index f079a3a3d..988228cb9 100644 --- a/core/src/test/java/google/registry/cron/CommitLogFanoutActionTest.java +++ b/core/src/test/java/google/registry/cron/CommitLogFanoutActionTest.java @@ -15,14 +15,13 @@ package google.registry.cron; import static google.registry.cron.CommitLogFanoutAction.BUCKET_PARAM; -import static google.registry.testing.TaskQueueHelper.assertTasksEnqueued; import com.google.common.base.Joiner; import google.registry.model.ofy.CommitLogBucket; import google.registry.testing.AppEngineExtension; -import google.registry.testing.TaskQueueHelper.TaskMatcher; -import google.registry.util.Retrier; -import google.registry.util.TaskQueueUtils; +import google.registry.testing.CloudTasksHelper; +import google.registry.testing.CloudTasksHelper.TaskMatcher; +import google.registry.testing.FakeClock; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -34,6 +33,7 @@ class CommitLogFanoutActionTest { private static final String ENDPOINT = "/the/servlet"; private static final String QUEUE = "the-queue"; + private final CloudTasksHelper cloudTasksHelper = new CloudTasksHelper(); @RegisterExtension final AppEngineExtension appEngineExtension = @@ -54,15 +54,16 @@ class CommitLogFanoutActionTest { @Test void testSuccess() { CommitLogFanoutAction action = new CommitLogFanoutAction(); - action.taskQueueUtils = new TaskQueueUtils(new Retrier(null, 1)); + action.cloudTasksUtils = cloudTasksHelper.getTestCloudTasksUtils(); action.endpoint = ENDPOINT; action.queue = QUEUE; action.jitterSeconds = Optional.empty(); + action.clock = new FakeClock(); action.run(); List matchers = new ArrayList<>(); for (int bucketId : CommitLogBucket.getBucketIds()) { matchers.add(new TaskMatcher().url(ENDPOINT).param(BUCKET_PARAM, Integer.toString(bucketId))); } - assertTasksEnqueued(QUEUE, matchers); + cloudTasksHelper.assertTasksEnqueued(QUEUE, matchers); } } diff --git a/util/src/main/java/google/registry/util/CloudTasksUtils.java b/util/src/main/java/google/registry/util/CloudTasksUtils.java index 5604f254c..7afed817b 100644 --- a/util/src/main/java/google/registry/util/CloudTasksUtils.java +++ b/util/src/main/java/google/registry/util/CloudTasksUtils.java @@ -16,6 +16,7 @@ package google.registry.util; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.concurrent.TimeUnit.SECONDS; import com.google.api.gax.rpc.ApiException; import com.google.cloud.tasks.v2.AppEngineHttpRequest; @@ -34,9 +35,13 @@ import com.google.common.net.HttpHeaders; import com.google.common.net.MediaType; import com.google.common.net.UrlEscapers; import com.google.protobuf.ByteString; +import com.google.protobuf.Timestamp; import java.io.Serializable; import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.Arrays; +import java.util.Optional; +import java.util.Random; import java.util.function.Supplier; /** Utilities for dealing with Cloud Tasks. */ @@ -44,6 +49,7 @@ public class CloudTasksUtils implements Serializable { private static final long serialVersionUID = -7605156291755534069L; private static final FluentLogger logger = FluentLogger.forEnclosingClass(); + private static final Random random = new Random(); private final Retrier retrier; private final String projectId; @@ -88,7 +94,7 @@ public class CloudTasksUtils implements Serializable { * Queue API if no service is specified, the service which enqueues the task will be used to * process the task. Cloud Tasks API does not support this feature so the service will always * needs to be explicitly specified. - * @param params A multi-map of URL query parameters. Duplicate keys are saved as is, and it is up + * @param params a multi-map of URL query parameters. Duplicate keys are saved as is, and it is up * to the server to process the duplicate keys. * @return the enqueued task. * @see Specifyinig + * the worker service + */ + private static Task createTask( + String path, + HttpMethod method, + String service, + Multimap params, + Clock clock, + Optional jitterSeconds) { + if (jitterSeconds.isEmpty() || jitterSeconds.get() <= 0) { + return createTask(path, method, service, params); + } + Instant scheduleTime = + Instant.ofEpochMilli( + clock + .nowUtc() + .plusMillis(random.nextInt((int) SECONDS.toMillis(jitterSeconds.get()))) + .getMillis()); + return Task.newBuilder(createTask(path, method, service, params)) + .setScheduleTime( + Timestamp.newBuilder() + .setSeconds(scheduleTime.getEpochSecond()) + .setNanos(scheduleTime.getNano()) + .build()) + .build(); + } + public static Task createPostTask(String path, String service, Multimap params) { return createTask(path, HttpMethod.POST, service, params); } @@ -137,6 +186,30 @@ public class CloudTasksUtils implements Serializable { return createTask(path, HttpMethod.GET, service, params); } + /** + * Create a {@link Task} via HTTP.POST that will be randomly delayed up to {@code jitterSeconds}. + */ + public static Task createPostTask( + String path, + String service, + Multimap params, + Clock clock, + Optional jitterSeconds) { + return createTask(path, HttpMethod.POST, service, params, clock, jitterSeconds); + } + + /** + * Create a {@link Task} via HTTP.GET that will be randomly delayed up to {@code jitterSeconds}. + */ + public static Task createGetTask( + String path, + String service, + Multimap params, + Clock clock, + Optional jitterSeconds) { + return createTask(path, HttpMethod.GET, service, params, clock, jitterSeconds); + } + public abstract static class SerializableCloudTasksClient implements Serializable { public abstract Task enqueue(String projectId, String locationId, String queueName, Task task); } diff --git a/util/src/test/java/google/registry/util/CloudTasksUtilsTest.java b/util/src/test/java/google/registry/util/CloudTasksUtilsTest.java index 6a0bc6e36..cab58118b 100644 --- a/util/src/test/java/google/registry/util/CloudTasksUtilsTest.java +++ b/util/src/test/java/google/registry/util/CloudTasksUtilsTest.java @@ -30,6 +30,9 @@ import google.registry.testing.FakeClock; import google.registry.testing.FakeSleeper; import google.registry.util.CloudTasksUtils.SerializableCloudTasksClient; import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Optional; +import org.joda.time.DateTime; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -41,6 +44,7 @@ public class CloudTasksUtilsTest { private final CloudTasksUtils cloudTasksUtils = new CloudTasksUtils( new Retrier(new FakeSleeper(new FakeClock()), 1), "project", "location", mockClient); + private final Clock clock = new FakeClock(DateTime.parse("2021-11-08")); @BeforeEach void beforeEach() { @@ -59,6 +63,7 @@ public class CloudTasksUtilsTest { .isEqualTo("/the/path?key1=val1&key2=val2&key1=val3"); assertThat(task.getAppEngineHttpRequest().getAppEngineRouting().getService()) .isEqualTo("myservice"); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); } @Test @@ -72,6 +77,103 @@ public class CloudTasksUtilsTest { .isEqualTo("application/x-www-form-urlencoded"); assertThat(task.getAppEngineHttpRequest().getBody().toString(StandardCharsets.UTF_8)) .isEqualTo("key1=val1&key2=val2&key1=val3"); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @SuppressWarnings("ProtoTimestampGetSecondsGetNano") + @Test + void testSuccess_createGetTasks_withJitterSeconds() { + Task task = + CloudTasksUtils.createGetTask("/the/path", "myservice", params, clock, Optional.of(100)); + assertThat(task.getAppEngineHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getAppEngineHttpRequest().getRelativeUri()) + .isEqualTo("/the/path?key1=val1&key2=val2&key1=val3"); + assertThat(task.getAppEngineHttpRequest().getAppEngineRouting().getService()) + .isEqualTo("myservice"); + + Instant scheduleTime = Instant.ofEpochSecond(task.getScheduleTime().getSeconds()); + Instant lowerBoundTime = Instant.ofEpochMilli(clock.nowUtc().getMillis()); + Instant upperBound = Instant.ofEpochMilli(clock.nowUtc().plusSeconds(100).getMillis()); + + assertThat(scheduleTime.isBefore(lowerBoundTime)).isFalse(); + assertThat(upperBound.isBefore(scheduleTime)).isFalse(); + } + + @SuppressWarnings("ProtoTimestampGetSecondsGetNano") + @Test + void testSuccess_createPostTasks_withJitterSeconds() { + Task task = + CloudTasksUtils.createPostTask("/the/path", "myservice", params, clock, Optional.of(1)); + assertThat(task.getAppEngineHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getAppEngineHttpRequest().getRelativeUri()).isEqualTo("/the/path"); + assertThat(task.getAppEngineHttpRequest().getAppEngineRouting().getService()) + .isEqualTo("myservice"); + assertThat(task.getAppEngineHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getAppEngineHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + assertThat(task.getScheduleTime().getSeconds()).isNotEqualTo(0); + + Instant scheduleTime = Instant.ofEpochSecond(task.getScheduleTime().getSeconds()); + Instant lowerBoundTime = Instant.ofEpochMilli(clock.nowUtc().getMillis()); + Instant upperBound = Instant.ofEpochMilli(clock.nowUtc().plusSeconds(1).getMillis()); + + assertThat(scheduleTime.isBefore(lowerBoundTime)).isFalse(); + assertThat(upperBound.isBefore(scheduleTime)).isFalse(); + } + + @Test + void testSuccess_createPostTasks_withEmptyJitterSeconds() { + Task task = + CloudTasksUtils.createPostTask("/the/path", "myservice", params, clock, Optional.empty()); + assertThat(task.getAppEngineHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getAppEngineHttpRequest().getRelativeUri()).isEqualTo("/the/path"); + assertThat(task.getAppEngineHttpRequest().getAppEngineRouting().getService()) + .isEqualTo("myservice"); + assertThat(task.getAppEngineHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getAppEngineHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_createGetTasks_withEmptyJitterSeconds() { + Task task = + CloudTasksUtils.createGetTask("/the/path", "myservice", params, clock, Optional.empty()); + assertThat(task.getAppEngineHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getAppEngineHttpRequest().getRelativeUri()) + .isEqualTo("/the/path?key1=val1&key2=val2&key1=val3"); + assertThat(task.getAppEngineHttpRequest().getAppEngineRouting().getService()) + .isEqualTo("myservice"); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_createPostTasks_withZeroJitterSeconds() { + Task task = + CloudTasksUtils.createPostTask("/the/path", "myservice", params, clock, Optional.of(0)); + assertThat(task.getAppEngineHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getAppEngineHttpRequest().getRelativeUri()).isEqualTo("/the/path"); + assertThat(task.getAppEngineHttpRequest().getAppEngineRouting().getService()) + .isEqualTo("myservice"); + assertThat(task.getAppEngineHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getAppEngineHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_createGetTasks_withZeroJitterSeconds() { + Task task = + CloudTasksUtils.createGetTask("/the/path", "myservice", params, clock, Optional.of(0)); + assertThat(task.getAppEngineHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getAppEngineHttpRequest().getRelativeUri()) + .isEqualTo("/the/path?key1=val1&key2=val2&key1=val3"); + assertThat(task.getAppEngineHttpRequest().getAppEngineRouting().getService()) + .isEqualTo("myservice"); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); } @Test