From a90ef39a405c1ce171ead191bd65be9a10ea348e Mon Sep 17 00:00:00 2001 From: gbrodman Date: Tue, 9 May 2023 16:02:12 -0400 Subject: [PATCH] Allow usage of standard HTTP requests in CloudTasksUtils (#2013) This adds a possible configuration point "defaultServiceAccount" (which in GAE will be the standard GAE service account). If this is configured, CloudTasksUtils can create tasks with standard HTTP requests with an OIDC token corresponding to that service account, as opposed to using the AppEngine-specific request methods. This also works with IAP, in that if IAP is on and we specify the IAP client ID in the config, CloudTasksUtils will use the IAP client ID as the token audience and the request will successfully be passed through the IAP layer. Tetsted in QA. --- .../registry/batch/CloudTasksUtils.java | 122 ++++++-- .../registry/config/RegistryConfig.java | 18 +- .../config/RegistryConfigSettings.java | 1 + .../registry/config/files/default-config.yaml | 3 + .../google/registry/cron/TldFanoutAction.java | 26 +- .../registry/tools/ServiceConnection.java | 4 +- .../registry/batch/CloudTasksUtilsTest.java | 262 +++++++++++++++++- .../registry/testing/CloudTasksHelper.java | 3 + 8 files changed, 395 insertions(+), 44 deletions(-) diff --git a/core/src/main/java/google/registry/batch/CloudTasksUtils.java b/core/src/main/java/google/registry/batch/CloudTasksUtils.java index cb1c62814..55305032f 100644 --- a/core/src/main/java/google/registry/batch/CloudTasksUtils.java +++ b/core/src/main/java/google/registry/batch/CloudTasksUtils.java @@ -16,6 +16,7 @@ package google.registry.batch; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static google.registry.tools.ServiceConnection.getServer; import static java.util.concurrent.TimeUnit.SECONDS; import com.google.api.gax.rpc.ApiException; @@ -23,6 +24,8 @@ import com.google.cloud.tasks.v2.AppEngineHttpRequest; import com.google.cloud.tasks.v2.AppEngineRouting; import com.google.cloud.tasks.v2.CloudTasksClient; import com.google.cloud.tasks.v2.HttpMethod; +import com.google.cloud.tasks.v2.HttpRequest; +import com.google.cloud.tasks.v2.OidcToken; import com.google.cloud.tasks.v2.QueueName; import com.google.cloud.tasks.v2.Task; import com.google.common.base.Joiner; @@ -46,7 +49,10 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Optional; import java.util.Random; +import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; +import javax.annotation.Nullable; import javax.inject.Inject; import org.joda.time.Duration; @@ -61,6 +67,9 @@ public class CloudTasksUtils implements Serializable { private final Clock clock; private final String projectId; private final String locationId; + // defaultServiceAccount and iapClientId are nullable because Optional isn't serializable + @Nullable private final String defaultServiceAccount; + @Nullable private final String iapClientId; private final SerializableCloudTasksClient client; @Inject @@ -69,11 +78,15 @@ public class CloudTasksUtils implements Serializable { Clock clock, @Config("projectId") String projectId, @Config("locationId") String locationId, + @Config("defaultServiceAccount") Optional defaultServiceAccount, + @Config("iapClientId") Optional iapClientId, SerializableCloudTasksClient client) { this.retrier = retrier; this.clock = clock; this.projectId = projectId; this.locationId = locationId; + this.defaultServiceAccount = defaultServiceAccount.orElse(null); + this.iapClientId = iapClientId.orElse(null); this.client = client; } @@ -98,6 +111,74 @@ public class CloudTasksUtils implements Serializable { return enqueue(queue, Arrays.asList(tasks)); } + /** + * Converts a (possible) set of params into an HTTP request via the appropriate method. + * + *

For GET requests we add them on to the URL, and for POST requests we add them in the body of + * the request. + * + *

The parameters {@code putHeadersFunction} and {@code setBodyFunction} are used so that this + * method can be called with either an AppEngine HTTP request or a standard non-AppEngine HTTP + * request. The two objects do not have the same methods, but both have ways of setting headers / + * body. + * + * @return the resulting path (unchanged for POST requests, with params added for GET requests) + */ + private String processRequestParameters( + String path, + HttpMethod method, + Multimap params, + BiConsumer putHeadersFunction, + Consumer setBodyFunction) { + if (CollectionUtils.isNullOrEmpty(params)) { + return path; + } + Escaper escaper = UrlEscapers.urlPathSegmentEscaper(); + String encodedParams = + Joiner.on("&") + .join( + params.entries().stream() + .map( + entry -> + String.format( + "%s=%s", + escaper.escape(entry.getKey()), escaper.escape(entry.getValue()))) + .collect(toImmutableList())); + if (method.equals(HttpMethod.GET)) { + return String.format("%s?%s", path, encodedParams); + } + putHeadersFunction.accept(HttpHeaders.CONTENT_TYPE, MediaType.FORM_DATA.toString()); + setBodyFunction.accept(ByteString.copyFrom(encodedParams, StandardCharsets.UTF_8)); + return path; + } + + /** + * Creates a {@link Task} that does not use AppEngine for submission. + * + *

This uses the standard Cloud Tasks auth format to create and send an OIDC ID token set to + * the default service account. That account must have permission to submit tasks to Cloud Tasks. + */ + private Task createNonAppEngineTask( + String path, HttpMethod method, Service service, Multimap params) { + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder().setHttpMethod(method); + path = + processRequestParameters( + path, method, params, requestBuilder::putHeaders, requestBuilder::setBody); + OidcToken.Builder oidcTokenBuilder = + OidcToken.newBuilder().setServiceAccountEmail(defaultServiceAccount); + // If the service is using IAP, add that as the audience for the token so the request can be + // appropriately authed. Otherwise, use the project name. + if (iapClientId != null) { + oidcTokenBuilder.setAudience(iapClientId); + } else { + oidcTokenBuilder.setAudience(projectId); + } + requestBuilder.setOidcToken(oidcTokenBuilder.build()); + String totalPath = String.format("%s%s", getServer(service), path); + requestBuilder.setUrl(totalPath); + return Task.newBuilder().setHttpRequest(requestBuilder.build()).build(); + } + /** * Create a {@link Task} to be enqueued. * @@ -123,34 +204,21 @@ public class CloudTasksUtils implements Serializable { method.equals(HttpMethod.GET) || method.equals(HttpMethod.POST), "HTTP method %s is used. Only GET and POST are allowed.", method); - AppEngineHttpRequest.Builder requestBuilder = - AppEngineHttpRequest.newBuilder() - .setHttpMethod(method) - .setAppEngineRouting( - AppEngineRouting.newBuilder().setService(service.toString()).build()); - - if (!CollectionUtils.isNullOrEmpty(params)) { - Escaper escaper = UrlEscapers.urlPathSegmentEscaper(); - String encodedParams = - Joiner.on("&") - .join( - params.entries().stream() - .map( - entry -> - String.format( - "%s=%s", - escaper.escape(entry.getKey()), escaper.escape(entry.getValue()))) - .collect(toImmutableList())); - if (method == HttpMethod.GET) { - path = String.format("%s?%s", path, encodedParams); - } else { - requestBuilder - .putHeaders(HttpHeaders.CONTENT_TYPE, MediaType.FORM_DATA.toString()) - .setBody(ByteString.copyFrom(encodedParams, StandardCharsets.UTF_8)); - } + // If the default service account is configured, send a standard non-AppEngine HTTP request + if (defaultServiceAccount != null) { + return createNonAppEngineTask(path, method, service, params); + } else { + AppEngineHttpRequest.Builder requestBuilder = + AppEngineHttpRequest.newBuilder() + .setHttpMethod(method) + .setAppEngineRouting( + AppEngineRouting.newBuilder().setService(service.toString()).build()); + path = + processRequestParameters( + path, method, params, requestBuilder::putHeaders, requestBuilder::setBody); + requestBuilder.setRelativeUri(path); + return Task.newBuilder().setAppEngineHttpRequest(requestBuilder.build()).build(); } - requestBuilder.setRelativeUri(path); - return Task.newBuilder().setAppEngineHttpRequest(requestBuilder.build()).build(); } /** diff --git a/core/src/main/java/google/registry/config/RegistryConfig.java b/core/src/main/java/google/registry/config/RegistryConfig.java index 60c151ca5..958b15153 100644 --- a/core/src/main/java/google/registry/config/RegistryConfig.java +++ b/core/src/main/java/google/registry/config/RegistryConfig.java @@ -108,12 +108,6 @@ public final class RegistryConfig { return config.gcpProject.projectId; } - @Provides - @Config("serviceAccountEmails") - public static ImmutableList provideServiceAccountEmails(RegistryConfigSettings config) { - return ImmutableList.copyOf(config.gcpProject.serviceAccountEmails); - } - @Provides @Config("projectIdNumber") public static long provideProjectIdNumber(RegistryConfigSettings config) { @@ -126,6 +120,18 @@ public final class RegistryConfig { return config.gcpProject.locationId; } + @Provides + @Config("serviceAccountEmails") + public static ImmutableList provideServiceAccountEmails(RegistryConfigSettings config) { + return ImmutableList.copyOf(config.gcpProject.serviceAccountEmails); + } + + @Provides + @Config("defaultServiceAccount") + public static Optional provideDefaultServiceAccount(RegistryConfigSettings config) { + return Optional.ofNullable(config.gcpProject.defaultServiceAccount); + } + /** * The filename of the logo to be displayed in the header of the registrar console. * diff --git a/core/src/main/java/google/registry/config/RegistryConfigSettings.java b/core/src/main/java/google/registry/config/RegistryConfigSettings.java index 741e7897f..d191e307a 100644 --- a/core/src/main/java/google/registry/config/RegistryConfigSettings.java +++ b/core/src/main/java/google/registry/config/RegistryConfigSettings.java @@ -55,6 +55,7 @@ public class RegistryConfigSettings { public String toolsServiceUrl; public String pubapiServiceUrl; public List serviceAccountEmails; + public String defaultServiceAccount; } /** Configuration options for OAuth settings for authenticating users. */ diff --git a/core/src/main/java/google/registry/config/files/default-config.yaml b/core/src/main/java/google/registry/config/files/default-config.yaml index 6f7f19924..3d5107b11 100644 --- a/core/src/main/java/google/registry/config/files/default-config.yaml +++ b/core/src/main/java/google/registry/config/files/default-config.yaml @@ -27,6 +27,9 @@ gcpProject: serviceAccountEmails: - default-service-account-email@email.com - cloud-scheduler-email@email.com + # The default service account with which the service is running. For example, + # on GAE this would be {project-id}@appspot.gserviceaccount.com + defaultServiceAccount: null gSuite: # Publicly accessible domain name of the running G Suite instance. diff --git a/core/src/main/java/google/registry/cron/TldFanoutAction.java b/core/src/main/java/google/registry/cron/TldFanoutAction.java index 46d0d7f55..709e76663 100644 --- a/core/src/main/java/google/registry/cron/TldFanoutAction.java +++ b/core/src/main/java/google/registry/cron/TldFanoutAction.java @@ -140,13 +140,25 @@ public final class TldFanoutAction implements Runnable { for (String tld : tlds) { Task task = createTask(tld, flowThruParams); Task createdTask = cloudTasksUtils.enqueue(queue, task); - outputPayload.append( - String.format( - "- Task: '%s', tld: '%s', endpoint: '%s'\n", - createdTask.getName(), tld, createdTask.getAppEngineHttpRequest().getRelativeUri())); - logger.atInfo().log( - "Task: '%s', tld: '%s', endpoint: '%s'.", - createdTask.getName(), tld, createdTask.getAppEngineHttpRequest().getRelativeUri()); + if (createdTask.hasAppEngineHttpRequest()) { + outputPayload.append( + String.format( + "- Task: '%s', tld: '%s', endpoint: '%s'\n", + createdTask.getName(), + tld, + createdTask.getAppEngineHttpRequest().getRelativeUri())); + logger.atInfo().log( + "Task: '%s', tld: '%s', endpoint: '%s'.", + createdTask.getName(), tld, createdTask.getAppEngineHttpRequest().getRelativeUri()); + } else { + outputPayload.append( + String.format( + "- Task: '%s', tld: '%s', endpoint: '%s'\n", + createdTask.getName(), tld, createdTask.getHttpRequest().getUrl())); + logger.atInfo().log( + "Task: '%s', tld: '%s', endpoint: '%s'.", + createdTask.getName(), tld, createdTask.getHttpRequest().getUrl()); + } } response.setContentType(PLAIN_TEXT_UTF_8); response.setPayload(outputPayload.toString()); diff --git a/core/src/main/java/google/registry/tools/ServiceConnection.java b/core/src/main/java/google/registry/tools/ServiceConnection.java index 8c19c8fdd..e35f3f12a 100644 --- a/core/src/main/java/google/registry/tools/ServiceConnection.java +++ b/core/src/main/java/google/registry/tools/ServiceConnection.java @@ -85,7 +85,7 @@ public class ServiceConnection { private String internalSend( String endpoint, Map params, MediaType contentType, @Nullable byte[] payload) throws IOException { - GenericUrl url = new GenericUrl(String.format("%s%s", getServer(), endpoint)); + GenericUrl url = new GenericUrl(String.format("%s%s", getServer(service), endpoint)); url.putAll(params); HttpRequest request = (payload != null) @@ -141,7 +141,7 @@ public class ServiceConnection { return (Map) JSONValue.parse(response.substring(JSON_SAFETY_PREFIX.length())); } - public URL getServer() { + public static URL getServer(Service service) { switch (service) { case DEFAULT: return RegistryConfig.getDefaultServer(); diff --git a/core/src/test/java/google/registry/batch/CloudTasksUtilsTest.java b/core/src/test/java/google/registry/batch/CloudTasksUtilsTest.java index 77b92aa20..ef0c06070 100644 --- a/core/src/test/java/google/registry/batch/CloudTasksUtilsTest.java +++ b/core/src/test/java/google/registry/batch/CloudTasksUtilsTest.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.cloud.tasks.v2.HttpMethod; +import com.google.cloud.tasks.v2.OidcToken; import com.google.cloud.tasks.v2.Task; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; @@ -46,9 +47,15 @@ public class CloudTasksUtilsTest { private final LinkedListMultimap params = LinkedListMultimap.create(); private final SerializableCloudTasksClient mockClient = mock(SerializableCloudTasksClient.class); private final FakeClock clock = new FakeClock(DateTime.parse("2021-11-08")); - private final CloudTasksUtils cloudTasksUtils = + private CloudTasksUtils cloudTasksUtils = new CloudTasksUtils( - new Retrier(new FakeSleeper(clock), 1), clock, "project", "location", mockClient); + new Retrier(new FakeSleeper(clock), 1), + clock, + "project", + "location", + Optional.empty(), + Optional.empty(), + mockClient); @BeforeEach void beforeEach() { @@ -348,4 +355,255 @@ public class CloudTasksUtilsTest { verify(mockClient).enqueue("project", "location", "test-queue", task1); verify(mockClient).enqueue("project", "location", "test-queue", task2); } + + @Test + void testSuccess_nonAppEngine_createGetTasks() { + createOidcTasksUtils(); + Task task = cloudTasksUtils.createGetTask("/the/path", Service.BACKEND, params); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()) + .isEqualTo("https://localhost/the/path?key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createPostTasks() { + createOidcTasksUtils(); + Task task = cloudTasksUtils.createPostTask("/the/path", Service.BACKEND, params); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createGetTasks_withNullParams() { + createOidcTasksUtils(); + Task task = cloudTasksUtils.createGetTask("/the/path", Service.BACKEND, null); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createPostTasks_withNullParams() { + createOidcTasksUtils(); + Task task = cloudTasksUtils.createPostTask("/the/path", Service.BACKEND, null); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)).isEmpty(); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createGetTasks_withEmptyParams() { + createOidcTasksUtils(); + Task task = cloudTasksUtils.createGetTask("/the/path", Service.BACKEND, ImmutableMultimap.of()); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createPostTasks_withEmptyParams() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createPostTask("/the/path", Service.BACKEND, ImmutableMultimap.of()); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)).isEmpty(); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @SuppressWarnings("ProtoTimestampGetSecondsGetNano") + @Test + void testSuccess_nonAppEngine_createGetTasks_withJitterSeconds() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createGetTaskWithJitter( + "/the/path", Service.BACKEND, params, Optional.of(100)); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()) + .isEqualTo("https://localhost/the/path?key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + + Instant scheduleTime = Instant.ofEpochSecond(task.getScheduleTime().getSeconds()); + Instant lowerBoundTime = Instant.ofEpochMilli(clock.nowUtc().getMillis()); + Instant upperBound = Instant.ofEpochMilli(clock.nowUtc().plusSeconds(100).getMillis()); + + assertThat(scheduleTime.isBefore(lowerBoundTime)).isFalse(); + assertThat(upperBound.isBefore(scheduleTime)).isFalse(); + } + + @SuppressWarnings("ProtoTimestampGetSecondsGetNano") + @Test + void testSuccess_nonAppEngine_createPostTasks_withJitterSeconds() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createPostTaskWithJitter( + "/the/path", Service.BACKEND, params, Optional.of(1)); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isNotEqualTo(0); + + Instant scheduleTime = Instant.ofEpochSecond(task.getScheduleTime().getSeconds()); + Instant lowerBoundTime = Instant.ofEpochMilli(clock.nowUtc().getMillis()); + Instant upperBound = Instant.ofEpochMilli(clock.nowUtc().plusSeconds(1).getMillis()); + + assertThat(scheduleTime.isBefore(lowerBoundTime)).isFalse(); + assertThat(upperBound.isBefore(scheduleTime)).isFalse(); + } + + @Test + void testSuccess_nonAppEngine_createPostTasks_withEmptyJitterSeconds() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createPostTaskWithJitter( + "/the/path", Service.BACKEND, params, Optional.empty()); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createGetTasks_withEmptyJitterSeconds() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createGetTaskWithJitter( + "/the/path", Service.BACKEND, params, Optional.empty()); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()) + .isEqualTo("https://localhost/the/path?key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createPostTasks_withZeroJitterSeconds() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createPostTaskWithJitter( + "/the/path", Service.BACKEND, params, Optional.of(0)); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createGetTasks_withZeroJitterSeconds() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createGetTaskWithJitter( + "/the/path", Service.BACKEND, params, Optional.of(0)); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()) + .isEqualTo("https://localhost/the/path?key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createGetTasks_withDelay() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createGetTaskWithDelay( + "/the/path", Service.BACKEND, params, Duration.standardMinutes(10)); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()) + .isEqualTo("https://localhost/the/path?key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(Instant.ofEpochSecond(task.getScheduleTime().getSeconds())) + .isEqualTo(Instant.ofEpochMilli(clock.nowUtc().plusMinutes(10).getMillis())); + } + + @Test + void testSuccess_nonAppEngine_createPostTasks_withDelay() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createPostTaskWithDelay( + "/the/path", Service.BACKEND, params, Duration.standardMinutes(10)); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isNotEqualTo(0); + assertThat(Instant.ofEpochSecond(task.getScheduleTime().getSeconds())) + .isEqualTo(Instant.ofEpochMilli(clock.nowUtc().plusMinutes(10).getMillis())); + } + + @Test + void testSuccess_nonAppEngine_createPostTasks_withZeroDelay() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createPostTaskWithDelay( + "/the/path", Service.BACKEND, params, Duration.ZERO); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.POST); + assertThat(task.getHttpRequest().getUrl()).isEqualTo("https://localhost/the/path"); + assertThat(task.getHttpRequest().getHeadersMap().get("Content-Type")) + .isEqualTo("application/x-www-form-urlencoded"); + assertThat(task.getHttpRequest().getBody().toString(StandardCharsets.UTF_8)) + .isEqualTo("key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + @Test + void testSuccess_nonAppEngine_createGetTasks_withZeroDelay() { + createOidcTasksUtils(); + Task task = + cloudTasksUtils.createGetTaskWithDelay("/the/path", Service.BACKEND, params, Duration.ZERO); + assertThat(task.getHttpRequest().getHttpMethod()).isEqualTo(HttpMethod.GET); + assertThat(task.getHttpRequest().getUrl()) + .isEqualTo("https://localhost/the/path?key1=val1&key2=val2&key1=val3"); + verifyOidcToken(task); + assertThat(task.getScheduleTime().getSeconds()).isEqualTo(0); + } + + private void createOidcTasksUtils() { + cloudTasksUtils = + new CloudTasksUtils( + new Retrier(new FakeSleeper(clock), 1), + clock, + "project", + "location", + Optional.of("defaultServiceAccount"), + Optional.of("iapClientId"), + mockClient); + } + + private void verifyOidcToken(Task task) { + assertThat(task.getHttpRequest().getOidcToken()) + .isEqualTo( + OidcToken.newBuilder() + .setServiceAccountEmail("defaultServiceAccount") + .setAudience("iapClientId") + .build()); + } } diff --git a/core/src/test/java/google/registry/testing/CloudTasksHelper.java b/core/src/test/java/google/registry/testing/CloudTasksHelper.java index 45f0fbb55..50f2e8b43 100644 --- a/core/src/test/java/google/registry/testing/CloudTasksHelper.java +++ b/core/src/test/java/google/registry/testing/CloudTasksHelper.java @@ -58,6 +58,7 @@ import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; @@ -103,6 +104,8 @@ public class CloudTasksHelper implements Serializable { clock, PROJECT_ID, LOCATION_ID, + Optional.empty(), + Optional.empty(), new FakeCloudTasksClient()); testTasks.put(instanceId, Multimaps.synchronizedListMultimap(LinkedListMultimap.create())); }