Generate and use an IAP-enabled ID token in the proxy (#1926)

This is only generated and used if "iapClientId" is set in the proxy
config. If so, we use code similar to
https://cloud.google.com/iap/docs/authentication-howto#obtaining_an_oidc_token_for_the_default_service_account
to generate an ID token that is valid for IAP. We set the token on the
Proxy-Authorization header so that we can keep using the pre-existing
access token as well -- IAP allows for us to use either the
Authorization header or the Proxy-Authorization header.
This commit is contained in:
gbrodman 2023-02-09 14:50:35 -05:00 committed by GitHub
parent de1f56393d
commit e26e5adf5c
14 changed files with 272 additions and 51 deletions

View file

@ -16,6 +16,7 @@ package google.registry.proxy;
import static google.registry.util.ResourceUtils.readResourceBytes; import static google.registry.util.ResourceUtils.readResourceBytes;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import dagger.Module; import dagger.Module;
import dagger.Provides; import dagger.Provides;
@ -43,6 +44,7 @@ import io.netty.handler.timeout.ReadTimeoutHandler;
import java.io.IOException; import java.io.IOException;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Optional;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -147,12 +149,18 @@ public final class EppProtocolModule {
@Provides @Provides
static EppServiceHandler provideEppServiceHandler( static EppServiceHandler provideEppServiceHandler(
@Named("accessToken") Supplier<String> accessTokenSupplier, Supplier<GoogleCredentials> refreshedCredentialsSupplier,
@Named("iapClientId") Optional<String> iapClientId,
@Named("hello") byte[] helloBytes, @Named("hello") byte[] helloBytes,
FrontendMetrics metrics, FrontendMetrics metrics,
ProxyConfig config) { ProxyConfig config) {
return new EppServiceHandler( return new EppServiceHandler(
config.epp.relayHost, config.epp.relayPath, accessTokenSupplier, helloBytes, metrics); config.epp.relayHost,
config.epp.relayPath,
refreshedCredentialsSupplier,
iapClientId,
helloBytes,
metrics);
} }
@Singleton @Singleton

View file

@ -40,6 +40,7 @@ public class ProxyConfig {
private static final String CUSTOM_CONFIG_FORMATTER = "config/proxy-config-%s.yaml"; private static final String CUSTOM_CONFIG_FORMATTER = "config/proxy-config-%s.yaml";
public String projectId; public String projectId;
public String iapClientId;
public List<String> gcpScopes; public List<String> gcpScopes;
public int serverCertificateCacheSeconds; public int serverCertificateCacheSeconds;
public Gcs gcs; public Gcs gcs;

View file

@ -157,6 +157,13 @@ public class ProxyModule {
return this; return this;
} }
@Provides
@Named("iapClientId")
@Singleton
Optional<String> provideIapClientId(ProxyConfig config) {
return Optional.ofNullable(config.iapClientId);
}
@Provides @Provides
@WhoisProtocol @WhoisProtocol
int provideWhoisPort(ProxyConfig config) { int provideWhoisPort(ProxyConfig config) {
@ -207,7 +214,7 @@ public class ProxyModule {
@Singleton @Singleton
@Provides @Provides
static GoogleCredentialsBundle provideCredential(ProxyConfig config) { static GoogleCredentialsBundle provideCredentialsBundle(ProxyConfig config) {
try { try {
GoogleCredentials credentials = GoogleCredentials.getApplicationDefault(); GoogleCredentials credentials = GoogleCredentials.getApplicationDefault();
if (credentials.createScopedRequired()) { if (credentials.createScopedRequired()) {
@ -219,19 +226,19 @@ public class ProxyModule {
} }
} }
/** Access token supplier that auto refreshes 1 minute before expiry. */ /** Provides a set of credentials that auto refreshes 1 minute before expiry. */
@Singleton @Singleton
@Provides @Provides
@Named("accessToken") static Supplier<GoogleCredentials> provideRefreshedCredentialsSupplier(
static Supplier<String> provideAccessTokenSupplier(GoogleCredentialsBundle credentialsBundle) { GoogleCredentialsBundle credentialsBundle) {
return () -> { return () -> {
GoogleCredentials credentials = credentialsBundle.getGoogleCredentials(); GoogleCredentials credentials = credentialsBundle.getGoogleCredentials();
try { try {
credentials.refreshIfExpired(); credentials.refreshIfExpired();
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("Cannot refresh access token.", e); throw new RuntimeException("Cannot refresh credentials.", e);
} }
return credentials.getAccessToken().getTokenValue(); return credentials;
}; };
} }

View file

@ -14,6 +14,7 @@
package google.registry.proxy; package google.registry.proxy;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import dagger.Module; import dagger.Module;
import dagger.Provides; import dagger.Provides;
@ -34,6 +35,7 @@ import google.registry.util.Clock;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.handler.codec.LineBasedFrameDecoder; import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.handler.timeout.ReadTimeoutHandler;
import java.util.Optional;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -91,10 +93,15 @@ public class WhoisProtocolModule {
@Provides @Provides
static WhoisServiceHandler provideWhoisServiceHandler( static WhoisServiceHandler provideWhoisServiceHandler(
ProxyConfig config, ProxyConfig config,
@Named("accessToken") Supplier<String> accessTokenSupplier, Supplier<GoogleCredentials> refreshedCredentialsSupplier,
@Named("iapClientId") Optional<String> iapClientId,
FrontendMetrics metrics) { FrontendMetrics metrics) {
return new WhoisServiceHandler( return new WhoisServiceHandler(
config.whois.relayHost, config.whois.relayPath, accessTokenSupplier, metrics); config.whois.relayHost,
config.whois.relayPath,
refreshedCredentialsSupplier,
iapClientId,
metrics);
} }
@Provides @Provides

View file

@ -8,6 +8,9 @@
# GCP project ID # GCP project ID
projectId: your-gcp-project-id projectId: your-gcp-project-id
# IAP client ID, if IAP is enabled for this project
iapClientId: null
# OAuth scope that the GoogleCredential will be constructed with. This list # OAuth scope that the GoogleCredential will be constructed with. This list
# should include all service scopes that the proxy depends on. # should include all service scopes that the proxy depends on.
gcpScopes: gcpScopes:

View file

@ -20,6 +20,7 @@ import static google.registry.networking.handler.SslServerInitializer.CLIENT_CER
import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY; import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY;
import static google.registry.util.X509Utils.getCertificateHash; import static google.registry.util.X509Utils.getCertificateHash;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.flogger.FluentLogger; import com.google.common.flogger.FluentLogger;
import google.registry.proxy.metric.FrontendMetrics; import google.registry.proxy.metric.FrontendMetrics;
import google.registry.util.ProxyHttpHeaders; import google.registry.util.ProxyHttpHeaders;
@ -36,6 +37,7 @@ import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.AttributeKey; import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Optional;
import java.util.function.Supplier; import java.util.function.Supplier;
/** Handler that processes EPP protocol logic. */ /** Handler that processes EPP protocol logic. */
@ -60,10 +62,11 @@ public class EppServiceHandler extends HttpsRelayServiceHandler {
public EppServiceHandler( public EppServiceHandler(
String relayHost, String relayHost,
String relayPath, String relayPath,
Supplier<String> accessTokenSupplier, Supplier<GoogleCredentials> refreshedCredentialsSupplier,
Optional<String> iapClientId,
byte[] helloBytes, byte[] helloBytes,
FrontendMetrics metrics) { FrontendMetrics metrics) {
super(relayHost, relayPath, accessTokenSupplier, metrics); super(relayHost, relayPath, refreshedCredentialsSupplier, iapClientId, metrics);
this.helloBytes = helloBytes; this.helloBytes = helloBytes;
} }

View file

@ -16,7 +16,12 @@ package google.registry.proxy.handler;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.IdToken;
import com.google.auth.oauth2.IdTokenProvider;
import com.google.auth.oauth2.IdTokenProvider.Option;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.FluentLogger; import com.google.common.flogger.FluentLogger;
import google.registry.proxy.metric.FrontendMetrics; import google.registry.proxy.metric.FrontendMetrics;
@ -37,9 +42,11 @@ import io.netty.handler.codec.http.cookie.ClientCookieDecoder;
import io.netty.handler.codec.http.cookie.ClientCookieEncoder; import io.netty.handler.codec.http.cookie.ClientCookieEncoder;
import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutException;
import java.io.IOException;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier; import java.util.function.Supplier;
import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLHandshakeException;
@ -72,18 +79,21 @@ public abstract class HttpsRelayServiceHandler extends ByteToMessageCodec<FullHt
private final Map<String, Cookie> cookieStore = new LinkedHashMap<>(); private final Map<String, Cookie> cookieStore = new LinkedHashMap<>();
private final String relayHost; private final String relayHost;
private final String relayPath; private final String relayPath;
private final Supplier<String> accessTokenSupplier; private final Supplier<GoogleCredentials> refreshedCredentialsSupplier;
private final Optional<String> iapClientId;
protected final FrontendMetrics metrics; protected final FrontendMetrics metrics;
HttpsRelayServiceHandler( HttpsRelayServiceHandler(
String relayHost, String relayHost,
String relayPath, String relayPath,
Supplier<String> accessTokenSupplier, Supplier<GoogleCredentials> refreshedCredentialsSupplier,
Optional<String> iapClientId,
FrontendMetrics metrics) { FrontendMetrics metrics) {
this.relayHost = relayHost; this.relayHost = relayHost;
this.relayPath = relayPath; this.relayPath = relayPath;
this.accessTokenSupplier = accessTokenSupplier; this.refreshedCredentialsSupplier = refreshedCredentialsSupplier;
this.iapClientId = iapClientId;
this.metrics = metrics; this.metrics = metrics;
} }
@ -91,19 +101,37 @@ public abstract class HttpsRelayServiceHandler extends ByteToMessageCodec<FullHt
* Construct the {@link FullHttpRequest}. * Construct the {@link FullHttpRequest}.
* *
* <p>This default method creates a bare-bone {@link FullHttpRequest} that may need to be * <p>This default method creates a bare-bone {@link FullHttpRequest} that may need to be
* modified, e. g. adding headers specific for each protocol. * modified, e.g. adding headers specific for each protocol.
* *
* @param byteBuf inbound message. * @param byteBuf inbound message.
*/ */
protected FullHttpRequest decodeFullHttpRequest(ByteBuf byteBuf) { protected FullHttpRequest decodeFullHttpRequest(ByteBuf byteBuf) {
FullHttpRequest request = FullHttpRequest request =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, relayPath); new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, relayPath);
GoogleCredentials credentials = refreshedCredentialsSupplier.get();
request request
.headers() .headers()
.set(HttpHeaderNames.USER_AGENT, "Proxy") .set(HttpHeaderNames.USER_AGENT, "Proxy")
.set(HttpHeaderNames.HOST, relayHost) .set(HttpHeaderNames.HOST, relayHost)
.set(HttpHeaderNames.AUTHORIZATION, "Bearer " + accessTokenSupplier.get()) .set(
HttpHeaderNames.AUTHORIZATION, "Bearer " + credentials.getAccessToken().getTokenValue())
.setInt(HttpHeaderNames.CONTENT_LENGTH, byteBuf.readableBytes()); .setInt(HttpHeaderNames.CONTENT_LENGTH, byteBuf.readableBytes());
// Set the Proxy-Authorization header if using IAP
if (iapClientId.isPresent()) {
IdTokenProvider idTokenProvider = (IdTokenProvider) credentials;
try {
// Note: we use Option.FORMAT_FULL to make sure the JWT we receive contains the email
// address (as is required by IAP)
IdToken idToken =
idTokenProvider.idTokenWithAudience(
iapClientId.get(), ImmutableList.of(Option.FORMAT_FULL));
request
.headers()
.set(HttpHeaderNames.PROXY_AUTHORIZATION, "Bearer " + idToken.getTokenValue());
} catch (IOException e) {
logger.atSevere().withCause(e).log("Error when attempting to retrieve IAP ID token");
}
}
request.content().writeBytes(byteBuf); request.content().writeBytes(byteBuf);
return request; return request;
} }

View file

@ -16,6 +16,7 @@ package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import com.google.auth.oauth2.GoogleCredentials;
import google.registry.proxy.metric.FrontendMetrics; import google.registry.proxy.metric.FrontendMetrics;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
@ -25,6 +26,7 @@ import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import java.util.Optional;
import java.util.function.Supplier; import java.util.function.Supplier;
/** Handler that processes WHOIS protocol logic. */ /** Handler that processes WHOIS protocol logic. */
@ -33,9 +35,10 @@ public final class WhoisServiceHandler extends HttpsRelayServiceHandler {
public WhoisServiceHandler( public WhoisServiceHandler(
String relayHost, String relayHost,
String relayPath, String relayPath,
Supplier<String> accessTokenSupplier, Supplier<GoogleCredentials> refreshedCredentialsSupplier,
Optional<String> iapClientId,
FrontendMetrics metrics) { FrontendMetrics metrics) {
super(relayHost, relayPath, accessTokenSupplier, metrics); super(relayHost, relayPath, refreshedCredentialsSupplier, iapClientId, metrics);
} }
@Override @Override

View file

@ -36,6 +36,7 @@ import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.DefaultCookie; import io.netty.handler.codec.http.cookie.DefaultCookie;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import java.io.IOException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -96,14 +97,15 @@ class EppProtocolModuleTest extends ProtocolModuleTest {
return buffer; return buffer;
} }
private FullHttpRequest makeEppHttpRequest(byte[] content, Cookie... cookies) { private FullHttpRequest makeEppHttpRequest(byte[] content, Cookie... cookies) throws IOException {
return TestUtils.makeEppHttpRequest( return TestUtils.makeEppHttpRequest(
new String(content, UTF_8), new String(content, UTF_8),
PROXY_CONFIG.epp.relayHost, PROXY_CONFIG.epp.relayHost,
PROXY_CONFIG.epp.relayPath, PROXY_CONFIG.epp.relayPath,
TestModule.provideFakeAccessToken().get(), TestModule.provideFakeCredentials().get(),
getCertificateHash(certificate), getCertificateHash(certificate),
CLIENT_ADDRESS, CLIENT_ADDRESS,
TestModule.provideIapClientId(),
cookies); cookies);
} }

View file

@ -17,7 +17,14 @@ package google.registry.proxy;
import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableList.toImmutableList;
import static google.registry.proxy.ProxyConfig.Environment.LOCAL; import static google.registry.proxy.ProxyConfig.Environment.LOCAL;
import static google.registry.proxy.ProxyConfig.getProxyConfig; import static google.registry.proxy.ProxyConfig.getProxyConfig;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.IdToken;
import com.google.auth.oauth2.IdTokenProvider.Option;
import com.google.common.base.Suppliers; import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
@ -52,7 +59,9 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.logging.LoggingHandler; import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslProvider; import io.netty.handler.ssl.SslProvider;
import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.handler.timeout.ReadTimeoutHandler;
import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@ -223,7 +232,7 @@ public abstract class ProtocolModuleTest {
* should be provided in the respective {@code ProtocolModule} instead. * should be provided in the respective {@code ProtocolModule} instead.
*/ */
@Module @Module
static class TestModule { public static class TestModule {
/** /**
* A fake clock that is explicitly provided. Users can construct a module with a controller * A fake clock that is explicitly provided. Users can construct a module with a controller
@ -235,6 +244,12 @@ public abstract class ProtocolModuleTest {
this.fakeClock = fakeClock; this.fakeClock = fakeClock;
} }
@Provides
@Named("iapClientId")
public static Optional<String> provideIapClientId() {
return Optional.of("iapClientId");
}
@Singleton @Singleton
@Provides @Provides
static ProxyConfig provideProxyConfig() { static ProxyConfig provideProxyConfig() {
@ -249,9 +264,19 @@ public abstract class ProtocolModuleTest {
@Singleton @Singleton
@Provides @Provides
@Named("accessToken") static Supplier<GoogleCredentials> provideFakeCredentials() {
static Supplier<String> provideFakeAccessToken() { ComputeEngineCredentials mockCredentials = mock(ComputeEngineCredentials.class);
return Suppliers.ofInstance("fake.test.token"); when(mockCredentials.getAccessToken()).thenReturn(new AccessToken("fake.test.token", null));
IdToken mockIdToken = mock(IdToken.class);
when(mockIdToken.getTokenValue()).thenReturn("fake.test.id.token");
try {
when(mockCredentials.idTokenWithAudience(
"iapClientId", ImmutableList.of(Option.FORMAT_FULL)))
.thenReturn(mockIdToken);
} catch (IOException e) {
throw new RuntimeException(e);
}
return Suppliers.ofInstance(mockCredentials);
} }
@Singleton @Singleton

View file

@ -17,6 +17,10 @@ package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.US_ASCII;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.IdTokenProvider;
import com.google.auth.oauth2.IdTokenProvider.Option;
import com.google.common.collect.ImmutableList;
import google.registry.util.ProxyHttpHeaders; import google.registry.util.ProxyHttpHeaders;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -34,6 +38,8 @@ import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.cookie.ClientCookieEncoder; import io.netty.handler.codec.http.cookie.ClientCookieEncoder;
import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import java.io.IOException;
import java.util.Optional;
/** Utility class for various helper methods used in testing. */ /** Utility class for various helper methods used in testing. */
public class TestUtils { public class TestUtils {
@ -71,13 +77,19 @@ public class TestUtils {
} }
public static FullHttpRequest makeWhoisHttpRequest( public static FullHttpRequest makeWhoisHttpRequest(
String content, String host, String path, String accessToken) { String content,
String host,
String path,
GoogleCredentials credentials,
Optional<String> iapClientId)
throws IOException {
FullHttpRequest request = makeHttpPostRequest(content, host, path); FullHttpRequest request = makeHttpPostRequest(content, host, path);
request request
.headers() .headers()
.set("authorization", "Bearer " + accessToken) .set("authorization", "Bearer " + credentials.getAccessToken().getTokenValue())
.set(HttpHeaderNames.CONTENT_TYPE, "text/plain") .set(HttpHeaderNames.CONTENT_TYPE, "text/plain")
.set("accept", "text/plain"); .set("accept", "text/plain");
maybeSetProxyAuthForIap(request, credentials, iapClientId);
return request; return request;
} }
@ -85,18 +97,21 @@ public class TestUtils {
String content, String content,
String host, String host,
String path, String path,
String accessToken, GoogleCredentials credentials,
String sslClientCertificateHash, String sslClientCertificateHash,
String clientAddress, String clientAddress,
Cookie... cookies) { Optional<String> iapClientId,
Cookie... cookies)
throws IOException {
FullHttpRequest request = makeHttpPostRequest(content, host, path); FullHttpRequest request = makeHttpPostRequest(content, host, path);
request request
.headers() .headers()
.set("authorization", "Bearer " + accessToken) .set("authorization", "Bearer " + credentials.getAccessToken().getTokenValue())
.set(HttpHeaderNames.CONTENT_TYPE, "application/epp+xml") .set(HttpHeaderNames.CONTENT_TYPE, "application/epp+xml")
.set("accept", "application/epp+xml") .set("accept", "application/epp+xml")
.set(ProxyHttpHeaders.CERTIFICATE_HASH, sslClientCertificateHash) .set(ProxyHttpHeaders.CERTIFICATE_HASH, sslClientCertificateHash)
.set(ProxyHttpHeaders.IP_ADDRESS, clientAddress); .set(ProxyHttpHeaders.IP_ADDRESS, clientAddress);
maybeSetProxyAuthForIap(request, credentials, iapClientId);
if (cookies.length != 0) { if (cookies.length != 0) {
request.headers().set("cookie", ClientCookieEncoder.STRICT.encode(cookies)); request.headers().set("cookie", ClientCookieEncoder.STRICT.encode(cookies));
} }
@ -146,4 +161,16 @@ public class TestUtils {
public static void assertHttpRequestEquivalent(HttpRequest req1, HttpRequest req2) { public static void assertHttpRequestEquivalent(HttpRequest req1, HttpRequest req2) {
assertHttpMessageEquivalent(req1, req2); assertHttpMessageEquivalent(req1, req2);
} }
private static void maybeSetProxyAuthForIap(
FullHttpRequest request, GoogleCredentials credentials, Optional<String> iapClientId)
throws IOException {
if (iapClientId.isPresent()) {
String idTokenValue =
((IdTokenProvider) credentials)
.idTokenWithAudience(iapClientId.get(), ImmutableList.of(Option.FORMAT_FULL))
.getTokenValue();
request.headers().set("proxy-authorization", "Bearer " + idTokenValue);
}
}
} }

View file

@ -41,7 +41,7 @@ class WhoisProtocolModuleTest extends ProtocolModuleTest {
} }
@Test @Test
void testSuccess_singleFrameInboundMessage() { void testSuccess_singleFrameInboundMessage() throws Exception {
String inputString = "test.tld\r\n"; String inputString = "test.tld\r\n";
// Inbound message processed and passed along. // Inbound message processed and passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(inputString.getBytes(US_ASCII)))) assertThat(channel.writeInbound(Unpooled.wrappedBuffer(inputString.getBytes(US_ASCII))))
@ -53,7 +53,8 @@ class WhoisProtocolModuleTest extends ProtocolModuleTest {
"test.tld", "test.tld",
PROXY_CONFIG.whois.relayHost, PROXY_CONFIG.whois.relayHost,
PROXY_CONFIG.whois.relayPath, PROXY_CONFIG.whois.relayPath,
TestModule.provideFakeAccessToken().get()); TestModule.provideFakeCredentials().get(),
TestModule.provideIapClientId());
assertThat(actualRequest).isEqualTo(expectedRequest); assertThat(actualRequest).isEqualTo(expectedRequest);
assertThat(channel.isActive()).isTrue(); assertThat(channel.isActive()).isTrue();
// Nothing more to read. // Nothing more to read.
@ -70,7 +71,7 @@ class WhoisProtocolModuleTest extends ProtocolModuleTest {
} }
@Test @Test
void testSuccess_multiFrameInboundMessage() { void testSuccess_multiFrameInboundMessage() throws Exception {
String frame1 = "test"; String frame1 = "test";
String frame2 = "1.tld"; String frame2 = "1.tld";
String frame3 = "\r\nte"; String frame3 = "\r\nte";
@ -88,7 +89,8 @@ class WhoisProtocolModuleTest extends ProtocolModuleTest {
"test1.tld", "test1.tld",
PROXY_CONFIG.whois.relayHost, PROXY_CONFIG.whois.relayHost,
PROXY_CONFIG.whois.relayPath, PROXY_CONFIG.whois.relayPath,
TestModule.provideFakeAccessToken().get()); TestModule.provideFakeCredentials().get(),
TestModule.provideIapClientId());
assertThat(actualRequest1).isEqualTo(expectedRequest1); assertThat(actualRequest1).isEqualTo(expectedRequest1);
// No more message at this point. // No more message at this point.
assertThat((Object) channel.readInbound()).isNull(); assertThat((Object) channel.readInbound()).isNull();
@ -102,7 +104,8 @@ class WhoisProtocolModuleTest extends ProtocolModuleTest {
"test2.tld", "test2.tld",
PROXY_CONFIG.whois.relayHost, PROXY_CONFIG.whois.relayHost,
PROXY_CONFIG.whois.relayPath, PROXY_CONFIG.whois.relayPath,
TestModule.provideFakeAccessToken().get()); TestModule.provideFakeCredentials().get(),
TestModule.provideIapClientId());
assertThat(actualRequest2).isEqualTo(expectedRequest2); assertThat(actualRequest2).isEqualTo(expectedRequest2);
// The third message is not complete yet. // The third message is not complete yet.
assertThat(channel.isActive()).isTrue(); assertThat(channel.isActive()).isTrue();

View file

@ -25,8 +25,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.auth.oauth2.IdToken;
import com.google.auth.oauth2.IdTokenProvider.Option;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.TestUtils; import google.registry.proxy.TestUtils;
import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException; import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException;
import google.registry.proxy.metric.FrontendMetrics; import google.registry.proxy.metric.FrontendMetrics;
@ -44,7 +50,9 @@ import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.DefaultCookie; import io.netty.handler.codec.http.cookie.DefaultCookie;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import java.io.IOException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -59,9 +67,12 @@ class EppServiceHandlerTest {
private static final String RELAY_HOST = "registry.example.tld"; private static final String RELAY_HOST = "registry.example.tld";
private static final String RELAY_PATH = "/epp"; private static final String RELAY_PATH = "/epp";
private static final String ACCESS_TOKEN = "this.access.token";
private static final String CLIENT_ADDRESS = "epp.client.tld"; private static final String CLIENT_ADDRESS = "epp.client.tld";
private static final String PROTOCOL = "epp"; private static final String PROTOCOL = "epp";
private static final String IAP_CLIENT_ID = "iapClientId";
private static final ComputeEngineCredentials mockCredentials =
mock(ComputeEngineCredentials.class);
private X509Certificate clientCertificate; private X509Certificate clientCertificate;
@ -69,7 +80,12 @@ class EppServiceHandlerTest {
private final EppServiceHandler eppServiceHandler = private final EppServiceHandler eppServiceHandler =
new EppServiceHandler( new EppServiceHandler(
RELAY_HOST, RELAY_PATH, () -> ACCESS_TOKEN, HELLO.getBytes(UTF_8), metrics); RELAY_HOST,
RELAY_PATH,
() -> mockCredentials,
Optional.of(IAP_CLIENT_ID),
HELLO.getBytes(UTF_8),
metrics);
private EmbeddedChannel channel; private EmbeddedChannel channel;
@ -79,7 +95,7 @@ class EppServiceHandlerTest {
channel.attr(CLIENT_CERTIFICATE_PROMISE_KEY).get().setSuccess(certificate); channel.attr(CLIENT_CERTIFICATE_PROMISE_KEY).get().setSuccess(certificate);
} }
private void setHandshakeSuccess() throws Exception { private void setHandshakeSuccess() {
setHandshakeSuccess(channel, clientCertificate); setHandshakeSuccess(channel, clientCertificate);
} }
@ -91,23 +107,29 @@ class EppServiceHandlerTest {
.setFailure(new Exception("Handshake Failure")); .setFailure(new Exception("Handshake Failure"));
} }
private void setHandshakeFailure() throws Exception { private void setHandshakeFailure() {
setHandshakeFailure(channel); setHandshakeFailure(channel);
} }
private FullHttpRequest makeEppHttpRequest(String content, Cookie... cookies) { private FullHttpRequest makeEppHttpRequest(String content, Cookie... cookies) throws IOException {
return TestUtils.makeEppHttpRequest( return TestUtils.makeEppHttpRequest(
content, content,
RELAY_HOST, RELAY_HOST,
RELAY_PATH, RELAY_PATH,
ACCESS_TOKEN, mockCredentials,
getCertificateHash(clientCertificate), getCertificateHash(clientCertificate),
CLIENT_ADDRESS, CLIENT_ADDRESS,
Optional.of(IAP_CLIENT_ID),
cookies); cookies);
} }
@BeforeEach @BeforeEach
void beforeEach() throws Exception { void beforeEach() throws Exception {
when(mockCredentials.getAccessToken()).thenReturn(new AccessToken("this.access.token", null));
IdToken mockIdToken = mock(IdToken.class);
when(mockIdToken.getTokenValue()).thenReturn("fake.test.id.token");
when(mockCredentials.idTokenWithAudience(IAP_CLIENT_ID, ImmutableList.of(Option.FORMAT_FULL)))
.thenReturn(mockIdToken);
clientCertificate = SelfSignedCaCertificate.create().cert(); clientCertificate = SelfSignedCaCertificate.create().cert();
channel = setUpNewChannel(eppServiceHandler); channel = setUpNewChannel(eppServiceHandler);
} }
@ -140,10 +162,15 @@ class EppServiceHandlerTest {
String certHash = getCertificateHash(clientCertificate); String certHash = getCertificateHash(clientCertificate);
assertThat(channel.isActive()).isTrue(); assertThat(channel.isActive()).isTrue();
// Setup the second channel. // Set up the second channel.
EppServiceHandler eppServiceHandler2 = EppServiceHandler eppServiceHandler2 =
new EppServiceHandler( new EppServiceHandler(
RELAY_HOST, RELAY_PATH, () -> ACCESS_TOKEN, HELLO.getBytes(UTF_8), metrics); RELAY_HOST,
RELAY_PATH,
() -> mockCredentials,
Optional.empty(),
HELLO.getBytes(UTF_8),
metrics);
EmbeddedChannel channel2 = setUpNewChannel(eppServiceHandler2); EmbeddedChannel channel2 = setUpNewChannel(eppServiceHandler2);
setHandshakeSuccess(channel2, clientCertificate); setHandshakeSuccess(channel2, clientCertificate);
@ -160,10 +187,15 @@ class EppServiceHandlerTest {
String certHash = getCertificateHash(clientCertificate); String certHash = getCertificateHash(clientCertificate);
assertThat(channel.isActive()).isTrue(); assertThat(channel.isActive()).isTrue();
// Setup the second channel. // Set up the second channel.
EppServiceHandler eppServiceHandler2 = EppServiceHandler eppServiceHandler2 =
new EppServiceHandler( new EppServiceHandler(
RELAY_HOST, RELAY_PATH, () -> ACCESS_TOKEN, HELLO.getBytes(UTF_8), metrics); RELAY_HOST,
RELAY_PATH,
() -> mockCredentials,
Optional.empty(),
HELLO.getBytes(UTF_8),
metrics);
EmbeddedChannel channel2 = setUpNewChannel(eppServiceHandler2); EmbeddedChannel channel2 = setUpNewChannel(eppServiceHandler2);
X509Certificate clientCertificate2 = SelfSignedCaCertificate.create().cert(); X509Certificate clientCertificate2 = SelfSignedCaCertificate.create().cert();
setHandshakeSuccess(channel2, clientCertificate2); setHandshakeSuccess(channel2, clientCertificate2);
@ -326,4 +358,38 @@ class EppServiceHandlerTest {
assertThat((Object) channel.readOutbound()).isNull(); assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isTrue(); assertThat(channel.isActive()).isTrue();
} }
@Test
void testSuccess_withoutIapClientId() throws Exception {
// Without an IAP client ID configured, we shouldn't include the proxy-authorization header
EppServiceHandler nonIapServiceHandler =
new EppServiceHandler(
RELAY_HOST,
RELAY_PATH,
() -> mockCredentials,
Optional.empty(),
HELLO.getBytes(UTF_8),
metrics);
channel = setUpNewChannel(nonIapServiceHandler);
setHandshakeSuccess();
// First inbound message is hello.
channel.readInbound();
String content = "<epp>stuff</epp>";
channel.writeInbound(Unpooled.wrappedBuffer(content.getBytes(UTF_8)));
FullHttpRequest request = channel.readInbound();
assertThat(request)
.isEqualTo(
TestUtils.makeEppHttpRequest(
content,
RELAY_HOST,
RELAY_PATH,
mockCredentials,
getCertificateHash(clientCertificate),
CLIENT_ADDRESS,
Optional.empty()));
// Nothing further to pass to the next handler.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
} }

View file

@ -22,8 +22,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.auth.oauth2.IdToken;
import com.google.auth.oauth2.IdTokenProvider.Option;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException; import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException;
import google.registry.proxy.metric.FrontendMetrics; import google.registry.proxy.metric.FrontendMetrics;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
@ -34,6 +40,7 @@ import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -43,18 +50,26 @@ class WhoisServiceHandlerTest {
private static final String RELAY_HOST = "www.example.tld"; private static final String RELAY_HOST = "www.example.tld";
private static final String RELAY_PATH = "/test"; private static final String RELAY_PATH = "/test";
private static final String QUERY_CONTENT = "test.tld"; private static final String QUERY_CONTENT = "test.tld";
private static final String ACCESS_TOKEN = "this.access.token";
private static final String PROTOCOL = "whois"; private static final String PROTOCOL = "whois";
private static final String CLIENT_HASH = "none"; private static final String CLIENT_HASH = "none";
private static final String IAP_CLIENT_ID = "iapClientId";
private static final ComputeEngineCredentials mockCredentials =
mock(ComputeEngineCredentials.class);
private final FrontendMetrics metrics = mock(FrontendMetrics.class); private final FrontendMetrics metrics = mock(FrontendMetrics.class);
private final WhoisServiceHandler whoisServiceHandler = private final WhoisServiceHandler whoisServiceHandler =
new WhoisServiceHandler(RELAY_HOST, RELAY_PATH, () -> ACCESS_TOKEN, metrics); new WhoisServiceHandler(
RELAY_HOST, RELAY_PATH, () -> mockCredentials, Optional.of(IAP_CLIENT_ID), metrics);
private EmbeddedChannel channel; private EmbeddedChannel channel;
@BeforeEach @BeforeEach
void beforeEach() { void beforeEach() throws Exception {
when(mockCredentials.getAccessToken()).thenReturn(new AccessToken("this.access.token", null));
IdToken mockIdToken = mock(IdToken.class);
when(mockIdToken.getTokenValue()).thenReturn("fake.test.id.token");
when(mockCredentials.idTokenWithAudience(IAP_CLIENT_ID, ImmutableList.of(Option.FORMAT_FULL)))
.thenReturn(mockIdToken);
// Need to reset metrics for each test method, since they are static fields on the class and // Need to reset metrics for each test method, since they are static fields on the class and
// shared between each run. // shared between each run.
channel = new EmbeddedChannel(whoisServiceHandler); channel = new EmbeddedChannel(whoisServiceHandler);
@ -74,7 +89,8 @@ class WhoisServiceHandlerTest {
// Setup second channel. // Setup second channel.
WhoisServiceHandler whoisServiceHandler2 = WhoisServiceHandler whoisServiceHandler2 =
new WhoisServiceHandler(RELAY_HOST, RELAY_PATH, () -> ACCESS_TOKEN, metrics); new WhoisServiceHandler(
RELAY_HOST, RELAY_PATH, () -> mockCredentials, Optional.empty(), metrics);
EmbeddedChannel channel2 = EmbeddedChannel channel2 =
// We need a new channel id so that it has a different hash code. // We need a new channel id so that it has a different hash code.
// This only is needed for EmbeddedChannel because it has a dummy hash code implementation. // This only is needed for EmbeddedChannel because it has a dummy hash code implementation.
@ -85,10 +101,11 @@ class WhoisServiceHandlerTest {
} }
@Test @Test
void testSuccess_fireInboundHttpRequest() { void testSuccess_fireInboundHttpRequest() throws Exception {
ByteBuf inputBuffer = Unpooled.wrappedBuffer(QUERY_CONTENT.getBytes(US_ASCII)); ByteBuf inputBuffer = Unpooled.wrappedBuffer(QUERY_CONTENT.getBytes(US_ASCII));
FullHttpRequest expectedRequest = FullHttpRequest expectedRequest =
makeWhoisHttpRequest(QUERY_CONTENT, RELAY_HOST, RELAY_PATH, ACCESS_TOKEN); makeWhoisHttpRequest(
QUERY_CONTENT, RELAY_HOST, RELAY_PATH, mockCredentials, Optional.of(IAP_CLIENT_ID));
// Input data passed to next handler // Input data passed to next handler
assertThat(channel.writeInbound(inputBuffer)).isTrue(); assertThat(channel.writeInbound(inputBuffer)).isTrue();
FullHttpRequest inputRequest = channel.readInbound(); FullHttpRequest inputRequest = channel.readInbound();
@ -111,6 +128,27 @@ class WhoisServiceHandlerTest {
assertThat(channel.isActive()).isFalse(); assertThat(channel.isActive()).isFalse();
} }
@Test
void testSuccess_withoutIapClientId() throws Exception {
// Without an IAP client ID configured, we shouldn't include the proxy-authorization header
WhoisServiceHandler nonIapHandler =
new WhoisServiceHandler(
RELAY_HOST, RELAY_PATH, () -> mockCredentials, Optional.empty(), metrics);
channel = new EmbeddedChannel(nonIapHandler);
ByteBuf inputBuffer = Unpooled.wrappedBuffer(QUERY_CONTENT.getBytes(US_ASCII));
FullHttpRequest expectedRequest =
makeWhoisHttpRequest(
QUERY_CONTENT, RELAY_HOST, RELAY_PATH, mockCredentials, Optional.empty());
// Input data passed to next handler
assertThat(channel.writeInbound(inputBuffer)).isTrue();
FullHttpRequest inputRequest = channel.readInbound();
assertThat(inputRequest).isEqualTo(expectedRequest);
// The channel is still open, and nothing else is to be read from it.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test @Test
void testFailure_OutboundHttpResponseNotOK() { void testFailure_OutboundHttpResponseNotOK() {
String outputString = "line1\r\nline2\r\n"; String outputString = "line1\r\nline2\r\n";