Refine tests in GCP proxy

Previously the ssl initializer tests always uses JDK, which is not really testing what happens in production when we take advantage of the OpenSSL provider. Now the tests will run with all providers that are available (through JUnit parameterization). Some bugs that may cause flakiness are fixed in the process.

Change how SNI is verified in tests. It turns out that the old method (only verifying the SSL parameters in the SSL engine) does not actually ensure that the SNI address is sent to the peer, but only that the SSL engine is configured to send it (this value exists even before a handshake is performed). Also there's likely a bug in Netty's SSL engine that does not set this parameter when created with a peer host.

Lastly HTTP test utils are changed so that they do not use pre-defined constants for header names and values. We want the test to confirm that these constants are what we expect they are. Using string literals makes these tests also more explicit.

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=207930282
This commit is contained in:
jianglai 2018-08-08 12:44:36 -07:00
parent d80f431e21
commit 9eec70729f
8 changed files with 120 additions and 136 deletions

View file

@ -17,6 +17,9 @@ java_library(
"resources/*", "resources/*",
"config/*.yaml", "config/*.yaml",
]), ]),
runtime_deps = [
"@io_netty_tcnative",
],
deps = [ deps = [
"//java/google/registry/config", "//java/google/registry/config",
"//java/google/registry/util", "//java/google/registry/util",
@ -53,7 +56,6 @@ java_binary(
main_class = "google.registry.proxy.ProxyServer", main_class = "google.registry.proxy.ProxyServer",
runtime_deps = [ runtime_deps = [
":proxy", ":proxy",
"@io_netty_tcnative",
], ],
) )

View file

@ -39,6 +39,7 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Supplier; import java.util.function.Supplier;
import javax.net.ssl.SSLHandshakeException;
/** /**
* Handler that relays a single (framed) ByteBuf message to an HTTPS server. * Handler that relays a single (framed) ByteBuf message to an HTTPS server.
@ -168,7 +169,11 @@ abstract class HttpsRelayServiceHandler extends ByteToMessageCodec<FullHttpRespo
// IllegalArgumentException is thrown by the checkArgument in the #encode command, it just means // IllegalArgumentException is thrown by the checkArgument in the #encode command, it just means
// that GAE returns a non-200 response and the connection should be killed. The request is still // that GAE returns a non-200 response and the connection should be killed. The request is still
// processed by GAE, so this is not an unexpected behavior. // processed by GAE, so this is not an unexpected behavior.
if (cause instanceof ReadTimeoutException || cause instanceof IllegalArgumentException) { // SslHandshakeException is caused by the client not able to complete the handshake, we should
// not log it at error as we do not control client behavior.
if (cause instanceof ReadTimeoutException
|| cause instanceof IllegalArgumentException
|| cause instanceof SSLHandshakeException) {
logger.atWarning().withCause(cause).log( logger.atWarning().withCause(cause).log(
"Inbound exception caught for channel %s", ctx.channel()); "Inbound exception caught for channel %s", ctx.channel());
} else { } else {

View file

@ -17,8 +17,8 @@ package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static google.registry.proxy.Protocol.PROTOCOL_KEY; import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.flogger.FluentLogger; import com.google.common.flogger.FluentLogger;
import google.registry.proxy.HttpsRelayProtocolModule.HttpsRelayProtocol;
import google.registry.proxy.Protocol.BackendProtocol; import google.registry.proxy.Protocol.BackendProtocol;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandler.Sharable;
@ -28,7 +28,6 @@ import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider; import io.netty.handler.ssl.SslProvider;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import javax.annotation.Nullable;
import javax.inject.Inject; import javax.inject.Inject;
import javax.inject.Singleton; import javax.inject.Singleton;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
@ -46,12 +45,18 @@ import javax.net.ssl.SSLParameters;
public class SslClientInitializer<C extends Channel> extends ChannelInitializer<C> { public class SslClientInitializer<C extends Channel> extends ChannelInitializer<C> {
private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private final SslProvider sslProvider; private final SslProvider sslProvider;
private final X509Certificate[] trustedCertificates; private final X509Certificate[] trustedCertificates;
@Inject @Inject
SslClientInitializer( public SslClientInitializer(SslProvider sslProvider) {
SslProvider sslProvider, @Nullable @HttpsRelayProtocol X509Certificate... trustCertificates) { // null uses the system default trust store.
this(sslProvider, null);
}
@VisibleForTesting
SslClientInitializer(SslProvider sslProvider, X509Certificate[] trustCertificates) {
logger.atInfo().log("Client SSL Provider: %s", sslProvider); logger.atInfo().log("Client SSL Provider: %s", sslProvider);
this.sslProvider = sslProvider; this.sslProvider = sslProvider;
this.trustedCertificates = trustCertificates; this.trustedCertificates = trustCertificates;

View file

@ -15,10 +15,6 @@
package google.registry.proxy; package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.handler.EppServiceHandler.EPP_CONTENT_TYPE;
import static google.registry.proxy.handler.EppServiceHandler.FORWARDED_FOR_FIELD;
import static google.registry.proxy.handler.EppServiceHandler.REQUESTED_SERVERNAME_VIA_SNI_FIELD;
import static google.registry.proxy.handler.EppServiceHandler.SSL_CLIENT_CERTIFICATE_HASH_FIELD;
import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.US_ASCII;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
@ -28,8 +24,6 @@ import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpMessage; import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMessage; import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
@ -48,29 +42,29 @@ public class TestUtils {
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, path, buf); new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, path, buf);
request request
.headers() .headers()
.set(HttpHeaderNames.USER_AGENT, "Proxy") .set("user-agent", "Proxy")
.set(HttpHeaderNames.HOST, host) .set("host", host)
.setInt(HttpHeaderNames.CONTENT_LENGTH, buf.readableBytes()); .setInt("content-length", buf.readableBytes());
return request; return request;
} }
public static FullHttpRequest makeHttpGetRequest(String host, String path) { public static FullHttpRequest makeHttpGetRequest(String host, String path) {
FullHttpRequest request = FullHttpRequest request =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
request.headers().set(HttpHeaderNames.HOST, host).setInt(HttpHeaderNames.CONTENT_LENGTH, 0); request.headers().set("host", host).setInt("content-length", 0);
return request; return request;
} }
public static FullHttpResponse makeHttpResponse(String content, HttpResponseStatus status) { public static FullHttpResponse makeHttpResponse(String content, HttpResponseStatus status) {
ByteBuf buf = Unpooled.wrappedBuffer(content.getBytes(US_ASCII)); ByteBuf buf = Unpooled.wrappedBuffer(content.getBytes(US_ASCII));
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buf); FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buf);
response.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, buf.readableBytes()); response.headers().setInt("content-length", buf.readableBytes());
return response; return response;
} }
public static FullHttpResponse makeHttpResponse(HttpResponseStatus status) { public static FullHttpResponse makeHttpResponse(HttpResponseStatus status) {
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status); FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
response.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0); response.headers().setInt("content-length", 0);
return response; return response;
} }
@ -79,9 +73,9 @@ public class TestUtils {
FullHttpRequest request = makeHttpPostRequest(content, host, path); FullHttpRequest request = makeHttpPostRequest(content, host, path);
request request
.headers() .headers()
.set(HttpHeaderNames.AUTHORIZATION, "Bearer " + accessToken) .set("authorization", "Bearer " + accessToken)
.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN) .set("content-type", "text/plain")
.set(HttpHeaderNames.ACCEPT, HttpHeaderValues.TEXT_PLAIN); .set("accept", "text/plain");
return request; return request;
} }
@ -97,30 +91,30 @@ public class TestUtils {
FullHttpRequest request = makeHttpPostRequest(content, host, path); FullHttpRequest request = makeHttpPostRequest(content, host, path);
request request
.headers() .headers()
.set(HttpHeaderNames.AUTHORIZATION, "Bearer " + accessToken) .set("authorization", "Bearer " + accessToken)
.set(HttpHeaderNames.CONTENT_TYPE, EPP_CONTENT_TYPE) .set("content-type", "application/epp+xml")
.set(HttpHeaderNames.ACCEPT, EPP_CONTENT_TYPE) .set("accept", "application/epp+xml")
.set(SSL_CLIENT_CERTIFICATE_HASH_FIELD, sslClientCertificateHash) .set("X-SSL-Certificate", sslClientCertificateHash)
.set(REQUESTED_SERVERNAME_VIA_SNI_FIELD, serverHostname) .set("X-Requested-Servername-SNI", serverHostname)
.set(FORWARDED_FOR_FIELD, clientAddress); .set("X-Forwarded-For", clientAddress);
if (cookies.length != 0) { if (cookies.length != 0) {
request.headers().set(HttpHeaderNames.COOKIE, ClientCookieEncoder.STRICT.encode(cookies)); request.headers().set("cookie", ClientCookieEncoder.STRICT.encode(cookies));
} }
return request; return request;
} }
public static FullHttpResponse makeWhoisHttpResponse(String content, HttpResponseStatus status) { public static FullHttpResponse makeWhoisHttpResponse(String content, HttpResponseStatus status) {
FullHttpResponse response = makeHttpResponse(content, status); FullHttpResponse response = makeHttpResponse(content, status);
response.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); response.headers().set("content-type", "text/plain");
return response; return response;
} }
public static FullHttpResponse makeEppHttpResponse( public static FullHttpResponse makeEppHttpResponse(
String content, HttpResponseStatus status, Cookie... cookies) { String content, HttpResponseStatus status, Cookie... cookies) {
FullHttpResponse response = makeHttpResponse(content, status); FullHttpResponse response = makeHttpResponse(content, status);
response.headers().set(HttpHeaderNames.CONTENT_TYPE, EPP_CONTENT_TYPE); response.headers().set("content-type", "application/epp+xml");
for (Cookie cookie : cookies) { for (Cookie cookie : cookies) {
response.headers().add(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookie)); response.headers().add("set-cookie", ServerCookieEncoder.STRICT.encode(cookie));
} }
return response; return response;
} }

View file

@ -22,6 +22,7 @@ import static google.registry.proxy.handler.SslInitializerTestUtils.setUpServer;
import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair; import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair;
import static google.registry.proxy.handler.SslInitializerTestUtils.verifySslChannel; import static google.registry.proxy.handler.SslInitializerTestUtils.verifySslChannel;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol; import google.registry.proxy.Protocol;
import google.registry.proxy.Protocol.BackendProtocol; import google.registry.proxy.Protocol.BackendProtocol;
@ -37,7 +38,8 @@ import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.DecoderException; import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
@ -51,10 +53,12 @@ import java.security.cert.X509Certificate;
import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
import sun.security.provider.certpath.SunCertPathBuilderException;
/** /**
* Unit tests for {@link SslClientInitializer}. * Unit tests for {@link SslClientInitializer}.
@ -67,7 +71,7 @@ import org.junit.runners.JUnit4;
* <p>The local addresses used in each test method must to be different, otherwise tests run in * <p>The local addresses used in each test method must to be different, otherwise tests run in
* parallel may interfere with each other. * parallel may interfere with each other.
*/ */
@RunWith(JUnit4.class) @RunWith(Parameterized.class)
public class SslClientInitializerTest { public class SslClientInitializerTest {
/** Fake host to test if the SSL engine gets the correct peer host. */ /** Fake host to test if the SSL engine gets the correct peer host. */
@ -76,6 +80,18 @@ public class SslClientInitializerTest {
/** Fake port to test if the SSL engine gets the correct peer port. */ /** Fake port to test if the SSL engine gets the correct peer port. */
private static final int SSL_PORT = 12345; private static final int SSL_PORT = 12345;
@Parameter(0)
public SslProvider sslProvider;
@Parameters(name = "{0}")
public static SslProvider[] data() {
return OpenSsl.isAvailable()
? new SslProvider[] {SslProvider.JDK, SslProvider.OPENSSL}
: new SslProvider[] {SslProvider.JDK};
}
private String sniHostReceived;
/** Fake protocol saved in channel attribute. */ /** Fake protocol saved in channel attribute. */
private static final BackendProtocol PROTOCOL = private static final BackendProtocol PROTOCOL =
Protocol.backendBuilder() Protocol.backendBuilder()
@ -97,7 +113,12 @@ public class SslClientInitializerTest {
protected void initChannel(LocalChannel ch) throws Exception { protected void initChannel(LocalChannel ch) throws Exception {
ch.pipeline() ch.pipeline()
.addLast( .addLast(
sslContext.newHandler(ch.alloc()), new EchoHandler(serverLock, serverException)); new SniHandler(
hostname -> {
sniHostReceived = hostname;
return sslContext;
}),
new EchoHandler(serverLock, serverException));
} }
}; };
} }
@ -119,7 +140,7 @@ public class SslClientInitializerTest {
@Test @Test
public void testSuccess_swappedInitializerWithSslHandler() throws Exception { public void testSuccess_swappedInitializerWithSslHandler() throws Exception {
SslClientInitializer<EmbeddedChannel> sslClientInitializer = SslClientInitializer<EmbeddedChannel> sslClientInitializer =
new SslClientInitializer<>(SslProvider.JDK, (X509Certificate[]) null); new SslClientInitializer<>(sslProvider);
EmbeddedChannel channel = new EmbeddedChannel(); EmbeddedChannel channel = new EmbeddedChannel();
channel.attr(PROTOCOL_KEY).set(PROTOCOL); channel.attr(PROTOCOL_KEY).set(PROTOCOL);
ChannelPipeline pipeline = channel.pipeline(); ChannelPipeline pipeline = channel.pipeline();
@ -135,7 +156,7 @@ public class SslClientInitializerTest {
@Test @Test
public void testSuccess_protocolAttributeNotSet() { public void testSuccess_protocolAttributeNotSet() {
SslClientInitializer<EmbeddedChannel> sslClientInitializer = SslClientInitializer<EmbeddedChannel> sslClientInitializer =
new SslClientInitializer<>(SslProvider.JDK, (X509Certificate[]) null); new SslClientInitializer<>(sslProvider);
EmbeddedChannel channel = new EmbeddedChannel(); EmbeddedChannel channel = new EmbeddedChannel();
ChannelPipeline pipeline = channel.pipeline(); ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslClientInitializer); pipeline.addLast(sslClientInitializer);
@ -146,7 +167,8 @@ public class SslClientInitializerTest {
@Test @Test
public void testFailure_defaultTrustManager_rejectSelfSignedCert() throws Exception { public void testFailure_defaultTrustManager_rejectSelfSignedCert() throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate(SSL_HOST); SelfSignedCertificate ssc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress = new LocalAddress("DEFAULT_TRUST_MANAGER_REJECT_SELF_SIGNED_CERT"); LocalAddress localAddress =
new LocalAddress("DEFAULT_TRUST_MANAGER_REJECT_SELF_SIGNED_CERT_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -156,7 +178,7 @@ public class SslClientInitializerTest {
setUpServer( setUpServer(
getServerInitializer(ssc.key(), ssc.cert(), serverLock, serverException), localAddress); getServerInitializer(ssc.key(), ssc.cert(), serverLock, serverException), localAddress);
SslClientInitializer<LocalChannel> sslClientInitializer = SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(SslProvider.JDK, (X509Certificate[]) null); new SslClientInitializer<>(sslProvider);
Channel channel = Channel channel =
setUpClient( setUpClient(
eventLoopGroup, eventLoopGroup,
@ -169,13 +191,9 @@ public class SslClientInitializerTest {
// The connection is now terminated, both the client side and the server side should get // The connection is now terminated, both the client side and the server side should get
// exceptions (caught in the caughtException method in EchoHandler and DumpHandler, // exceptions (caught in the caughtException method in EchoHandler and DumpHandler,
// respectively). // respectively).
assertThat(clientException).hasCauseThat().isInstanceOf(DecoderException.class); assertThat(Throwables.getRootCause(clientException))
assertThat(clientException) .isInstanceOf(SunCertPathBuilderException.class);
.hasCauseThat() assertThat(Throwables.getRootCause(serverException)).isInstanceOf(SSLException.class);
.hasCauseThat()
.isInstanceOf(SSLHandshakeException.class);
assertThat(serverException).hasCauseThat().isInstanceOf(DecoderException.class);
assertThat(serverException).hasCauseThat().hasCauseThat().isInstanceOf(SSLException.class);
assertThat(channel.isActive()).isFalse(); assertThat(channel.isActive()).isFalse();
Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly(); Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly();
@ -184,7 +202,7 @@ public class SslClientInitializerTest {
@Test @Test
public void testSuccess_customTrustManager_acceptCertSignedByTrustedCa() throws Exception { public void testSuccess_customTrustManager_acceptCertSignedByTrustedCa() throws Exception {
LocalAddress localAddress = LocalAddress localAddress =
new LocalAddress("CUSTOM_TRUST_MANAGER_ACCEPT_CERT_SIGNED_BY_TRUSTED_CA"); new LocalAddress("CUSTOM_TRUST_MANAGER_ACCEPT_CERT_SIGNED_BY_TRUSTED_CA_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -206,7 +224,7 @@ public class SslClientInitializerTest {
// Set up the client to trust the self signed cert used to sign the cert that server provides. // Set up the client to trust the self signed cert used to sign the cert that server provides.
SslClientInitializer<LocalChannel> sslClientInitializer = SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(SslProvider.JDK, ssc.cert()); new SslClientInitializer<>(sslProvider, new X509Certificate[] {ssc.cert()});
Channel channel = Channel channel =
setUpClient( setUpClient(
eventLoopGroup, eventLoopGroup,
@ -214,14 +232,18 @@ public class SslClientInitializerTest {
localAddress, localAddress,
PROTOCOL); PROTOCOL);
verifySslChannel(channel, ImmutableList.of(cert), clientLock, serverLock, buffer, SSL_HOST); verifySslChannel(channel, ImmutableList.of(cert), clientLock, serverLock, buffer);
// Verify that the SNI extension is sent during handshake.
assertThat(sniHostReceived).isEqualTo(SSL_HOST);
Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly(); Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly();
} }
@Test @Test
public void testFailure_customTrustManager_wrongHostnameInCertificate() throws Exception { public void testFailure_customTrustManager_wrongHostnameInCertificate() throws Exception {
LocalAddress localAddress = new LocalAddress("CUSTOM_TRUST_MANAGER_WRONG_HOSTNAME"); LocalAddress localAddress =
new LocalAddress("CUSTOM_TRUST_MANAGER_WRONG_HOSTNAME_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -243,7 +265,7 @@ public class SslClientInitializerTest {
// Set up the client to trust the self signed cert used to sign the cert that server provides. // Set up the client to trust the self signed cert used to sign the cert that server provides.
SslClientInitializer<LocalChannel> sslClientInitializer = SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(SslProvider.JDK, ssc.cert()); new SslClientInitializer<>(sslProvider, new X509Certificate[] {ssc.cert()});
Channel channel = Channel channel =
setUpClient( setUpClient(
eventLoopGroup, eventLoopGroup,
@ -256,31 +278,10 @@ public class SslClientInitializerTest {
// When the client rejects the server cert due to wrong hostname, the client error is wrapped // When the client rejects the server cert due to wrong hostname, the client error is wrapped
// several layers in the exception. The server also throws an exception. // several layers in the exception. The server also throws an exception.
assertThat(clientException).hasCauseThat().isInstanceOf(DecoderException.class); Throwable rootCause = Throwables.getRootCause(clientException);
assertThat(clientException) assertThat(rootCause).isInstanceOf(CertificateException.class);
.hasCauseThat() assertThat(rootCause).hasMessageThat().contains(SSL_HOST);
.hasCauseThat() assertThat(Throwables.getRootCause(serverException)).isInstanceOf(SSLException.class);
.isInstanceOf(SSLHandshakeException.class);
assertThat(clientException)
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.isInstanceOf(SSLHandshakeException.class);
assertThat(clientException)
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.isInstanceOf(CertificateException.class);
assertThat(clientException)
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.hasMessageThat()
.contains(SSL_HOST);
assertThat(serverException).hasCauseThat().isInstanceOf(DecoderException.class);
assertThat(serverException).hasCauseThat().hasCauseThat().isInstanceOf(SSLException.class);
assertThat(channel.isActive()).isFalse(); assertThat(channel.isActive()).isFalse();
Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly(); Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly();

View file

@ -16,7 +16,6 @@ package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.Protocol.PROTOCOL_KEY; import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol.BackendProtocol; import google.registry.proxy.Protocol.BackendProtocol;
@ -240,8 +239,7 @@ public class SslInitializerTestUtils {
ImmutableList<X509Certificate> certs, ImmutableList<X509Certificate> certs,
Lock clientLock, Lock clientLock,
Lock serverLock, Lock serverLock,
ByteBuf buffer, ByteBuf buffer)
String sniHostname)
throws Exception { throws Exception {
SslHandler sslHandler = channel.pipeline().get(SslHandler.class); SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
// Wait till the handshake is complete. // Wait till the handshake is complete.
@ -253,10 +251,6 @@ public class SslInitializerTestUtils {
assertThat(sslHandler.engine().getSession().getPeerCertificates()) assertThat(sslHandler.engine().getSession().getPeerCertificates())
.asList() .asList()
.containsExactly(certs.toArray()); .containsExactly(certs.toArray());
// Verify that the client sent expected SNI name during handshake.
assertThat(sslHandler.engine().getSSLParameters().getServerNames()).hasSize(1);
assertThat(sslHandler.engine().getSSLParameters().getServerNames().get(0).getEncoded())
.isEqualTo(sniHostname.getBytes(UTF_8));
// Test that message can go through, bound inbound and outbound. // Test that message can go through, bound inbound and outbound.
String inputString = "Hello, world!"; String inputString = "Hello, world!";

View file

@ -22,6 +22,7 @@ import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair;
import static google.registry.proxy.handler.SslInitializerTestUtils.verifySslChannel; import static google.registry.proxy.handler.SslInitializerTestUtils.verifySslChannel;
import com.google.common.base.Suppliers; import com.google.common.base.Suppliers;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol; import google.registry.proxy.Protocol;
import google.registry.proxy.Protocol.BackendProtocol; import google.registry.proxy.Protocol.BackendProtocol;
@ -37,7 +38,7 @@ import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.DecoderException; import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider; import io.netty.handler.ssl.SslProvider;
@ -56,7 +57,9 @@ import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSession;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
/** /**
* Unit tests for {@link SslServerInitializer}. * Unit tests for {@link SslServerInitializer}.
@ -69,7 +72,7 @@ import org.junit.runners.JUnit4;
* <p>The local addresses used in each test method must to be different, otherwise tests run in * <p>The local addresses used in each test method must to be different, otherwise tests run in
* parallel may interfere with each other. * parallel may interfere with each other.
*/ */
@RunWith(JUnit4.class) @RunWith(Parameterized.class)
public class SslServerInitializerTest { public class SslServerInitializerTest {
/** Fake host to test if the SSL engine gets the correct peer host. */ /** Fake host to test if the SSL engine gets the correct peer host. */
@ -87,6 +90,16 @@ public class SslServerInitializerTest {
.handlerProviders(ImmutableList.of()) .handlerProviders(ImmutableList.of())
.build(); .build();
@Parameter(0)
public SslProvider sslProvider;
@Parameters(name = "{0}")
public static SslProvider[] data() {
return OpenSsl.isAvailable()
? new SslProvider[] {SslProvider.OPENSSL, SslProvider.JDK}
: new SslProvider[] {SslProvider.JDK};
}
private ChannelInitializer<LocalChannel> getServerInitializer( private ChannelInitializer<LocalChannel> getServerInitializer(
boolean requireClientCert, boolean requireClientCert,
Lock serverLock, Lock serverLock,
@ -101,7 +114,7 @@ public class SslServerInitializerTest {
.addLast( .addLast(
new SslServerInitializer<LocalChannel>( new SslServerInitializer<LocalChannel>(
requireClientCert, requireClientCert,
SslProvider.JDK, sslProvider,
Suppliers.ofInstance(privateKey), Suppliers.ofInstance(privateKey),
Suppliers.ofInstance(certificates)), Suppliers.ofInstance(certificates)),
new EchoHandler(serverLock, serverException)); new EchoHandler(serverLock, serverException));
@ -129,7 +142,7 @@ public class SslServerInitializerTest {
@Override @Override
protected void initChannel(LocalChannel ch) throws Exception { protected void initChannel(LocalChannel ch) throws Exception {
SslContextBuilder sslContextBuilder = SslContextBuilder sslContextBuilder =
SslContextBuilder.forClient().trustManager(trustedCertificate); SslContextBuilder.forClient().trustManager(trustedCertificate).sslProvider(sslProvider);
if (privateKey != null && certificate != null) { if (privateKey != null && certificate != null) {
sslContextBuilder.keyManager(privateKey, certificate); sslContextBuilder.keyManager(privateKey, certificate);
} }
@ -154,7 +167,7 @@ public class SslServerInitializerTest {
SslServerInitializer<EmbeddedChannel> sslServerInitializer = SslServerInitializer<EmbeddedChannel> sslServerInitializer =
new SslServerInitializer<>( new SslServerInitializer<>(
true, true,
SslProvider.JDK, sslProvider,
Suppliers.ofInstance(ssc.key()), Suppliers.ofInstance(ssc.key()),
Suppliers.ofInstance(new X509Certificate[] {ssc.cert()})); Suppliers.ofInstance(new X509Certificate[] {ssc.cert()}));
EmbeddedChannel channel = new EmbeddedChannel(); EmbeddedChannel channel = new EmbeddedChannel();
@ -170,7 +183,7 @@ public class SslServerInitializerTest {
@Test @Test
public void testSuccess_trustAnyClientCert() throws Exception { public void testSuccess_trustAnyClientCert() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST); SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress = new LocalAddress("TRUST_ANY_CLIENT_CERT"); LocalAddress localAddress = new LocalAddress("TRUST_ANY_CLIENT_CERT_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -196,7 +209,7 @@ public class SslServerInitializerTest {
SSLSession sslSession = SSLSession sslSession =
verifySslChannel( verifySslChannel(
channel, ImmutableList.of(serverSsc.cert()), clientLock, serverLock, buffer, SSL_HOST); channel, ImmutableList.of(serverSsc.cert()), clientLock, serverLock, buffer);
// Verify that the SSL session gets the client cert. Note that this SslSession is for the client // Verify that the SSL session gets the client cert. Note that this SslSession is for the client
// channel, therefore its local certificates are the remote certificates of the SslSession for // channel, therefore its local certificates are the remote certificates of the SslSession for
// the server channel, and vice versa. // the server channel, and vice versa.
@ -209,7 +222,7 @@ public class SslServerInitializerTest {
@Test @Test
public void testSuccess_doesNotRequireClientCert() throws Exception { public void testSuccess_doesNotRequireClientCert() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST); SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress = new LocalAddress("DOES_NOT_REQUIRE_CLIENT_CERT"); LocalAddress localAddress = new LocalAddress("DOES_NOT_REQUIRE_CLIENT_CERT_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -229,7 +242,7 @@ public class SslServerInitializerTest {
SSLSession sslSession = SSLSession sslSession =
verifySslChannel( verifySslChannel(
channel, ImmutableList.of(serverSsc.cert()), clientLock, serverLock, buffer, SSL_HOST); channel, ImmutableList.of(serverSsc.cert()), clientLock, serverLock, buffer);
// Verify that the SSL session does not contain any client cert. Note that this SslSession is // Verify that the SSL session does not contain any client cert. Note that this SslSession is
// for the client channel, therefore its local certificates are the remote certificates of the // for the client channel, therefore its local certificates are the remote certificates of the
// SslSession for the server channel, and vice versa. // SslSession for the server channel, and vice versa.
@ -245,7 +258,7 @@ public class SslServerInitializerTest {
SelfSignedCertificate caSsc = new SelfSignedCertificate(); SelfSignedCertificate caSsc = new SelfSignedCertificate();
KeyPair keyPair = getKeyPair(); KeyPair keyPair = getKeyPair();
X509Certificate serverCert = signKeyPair(caSsc, keyPair, SSL_HOST); X509Certificate serverCert = signKeyPair(caSsc, keyPair, SSL_HOST);
LocalAddress localAddress = new LocalAddress("CERT_SIGNED_BY_OTHER_CA"); LocalAddress localAddress = new LocalAddress("CERT_SIGNED_BY_OTHER_CA_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -278,12 +291,7 @@ public class SslServerInitializerTest {
SSLSession sslSession = SSLSession sslSession =
verifySslChannel( verifySslChannel(
channel, channel, ImmutableList.of(serverCert, caSsc.cert()), clientLock, serverLock, buffer);
ImmutableList.of(serverCert, caSsc.cert()),
clientLock,
serverLock,
buffer,
SSL_HOST);
assertThat(sslSession.getLocalCertificates()).asList().containsExactly(clientSsc.cert()); assertThat(sslSession.getLocalCertificates()).asList().containsExactly(clientSsc.cert());
assertThat(sslSession.getPeerCertificates()) assertThat(sslSession.getPeerCertificates())
@ -297,7 +305,7 @@ public class SslServerInitializerTest {
@Test @Test
public void testFailure_requireClientCertificate() throws Exception { public void testFailure_requireClientCertificate() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST); SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress = new LocalAddress("REQUIRE_CLIENT_CERT"); LocalAddress localAddress = new LocalAddress("REQUIRE_CLIENT_CERT_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -322,14 +330,11 @@ public class SslServerInitializerTest {
PROTOCOL); PROTOCOL);
serverLock.lock(); serverLock.lock();
clientLock.lock();
// When the server rejects the client during handshake due to lack of client certificate, only // When the server rejects the client during handshake due to lack of client certificate, only
// the server throws an exception. // the server throws an exception.
assertThat(serverException).hasCauseThat().isInstanceOf(DecoderException.class); assertThat(Throwables.getRootCause(serverException)).isInstanceOf(SSLHandshakeException.class);
assertThat(serverException)
.hasCauseThat()
.hasCauseThat()
.isInstanceOf(SSLHandshakeException.class);
assertThat(channel.isActive()).isFalse(); assertThat(channel.isActive()).isFalse();
Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly(); Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly();
@ -338,7 +343,7 @@ public class SslServerInitializerTest {
@Test @Test
public void testFailure_wrongHostnameInCertificate() throws Exception { public void testFailure_wrongHostnameInCertificate() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate("wrong.com"); SelfSignedCertificate serverSsc = new SelfSignedCertificate("wrong.com");
LocalAddress localAddress = new LocalAddress("REQUIRE_CLIENT_CERT"); LocalAddress localAddress = new LocalAddress("WRONG_HOSTNAME_" + sslProvider);
Lock clientLock = new ReentrantLock(); Lock clientLock = new ReentrantLock();
Lock serverLock = new ReentrantLock(); Lock serverLock = new ReentrantLock();
ByteBuf buffer = Unpooled.buffer(); ByteBuf buffer = Unpooled.buffer();
@ -367,31 +372,10 @@ public class SslServerInitializerTest {
// When the client rejects the server cert due to wrong hostname, the client error is wrapped // When the client rejects the server cert due to wrong hostname, the client error is wrapped
// several layers in the exception. The server also throws an exception. // several layers in the exception. The server also throws an exception.
assertThat(clientException).hasCauseThat().isInstanceOf(DecoderException.class); Throwable rootCause = Throwables.getRootCause(clientException);
assertThat(clientException) assertThat(rootCause).isInstanceOf(CertificateException.class);
.hasCauseThat() assertThat(rootCause).hasMessageThat().contains(SSL_HOST);
.hasCauseThat() assertThat(Throwables.getRootCause(serverException)).isInstanceOf(SSLException.class);
.isInstanceOf(SSLHandshakeException.class);
assertThat(clientException)
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.isInstanceOf(SSLHandshakeException.class);
assertThat(clientException)
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.isInstanceOf(CertificateException.class);
assertThat(clientException)
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.hasCauseThat()
.hasMessageThat()
.contains(SSL_HOST);
assertThat(serverException).hasCauseThat().isInstanceOf(DecoderException.class);
assertThat(serverException).hasCauseThat().hasCauseThat().isInstanceOf(SSLException.class);
assertThat(channel.isActive()).isFalse(); assertThat(channel.isActive()).isFalse();
Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly(); Future<?> unusedFuture = eventLoopGroup.shutdownGracefully().syncUninterruptibly();

View file

@ -19,7 +19,6 @@ import static google.registry.proxy.TestUtils.assertHttpResponseEquivalent;
import static google.registry.proxy.TestUtils.makeHttpGetRequest; import static google.registry.proxy.TestUtils.makeHttpGetRequest;
import static google.registry.proxy.TestUtils.makeHttpPostRequest; import static google.registry.proxy.TestUtils.makeHttpPostRequest;
import static google.registry.proxy.TestUtils.makeHttpResponse; import static google.registry.proxy.TestUtils.makeHttpResponse;
import static io.netty.handler.codec.http.HttpHeaderNames.HOST;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
@ -90,7 +89,7 @@ public class WebWhoisRedirectHandlerTest {
public void testSuccess_http_noHost() { public void testSuccess_http_noHost() {
setupChannel(false); setupChannel(false);
request = makeHttpGetRequest("", "/"); request = makeHttpGetRequest("", "/");
request.headers().remove(HOST); request.headers().remove("host");
// No inbound message passed to the next handler. // No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse(); assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound(); response = channel.readOutbound();
@ -184,7 +183,7 @@ public class WebWhoisRedirectHandlerTest {
public void testSuccess_https_noHost() { public void testSuccess_https_noHost() {
setupChannel(true); setupChannel(true);
request = makeHttpGetRequest("", "/"); request = makeHttpGetRequest("", "/");
request.headers().remove(HOST); request.headers().remove("host");
// No inbound message passed to the next handler. // No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse(); assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound(); response = channel.readOutbound();