diff --git a/javatests/google/registry/proxy/handler/NettyRule.java b/javatests/google/registry/proxy/handler/NettyRule.java new file mode 100644 index 000000000..c0fbdae28 --- /dev/null +++ b/javatests/google/registry/proxy/handler/NettyRule.java @@ -0,0 +1,223 @@ +// Copyright 2018 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.proxy.handler; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.truth.Truth.assertThat; +import static google.registry.proxy.Protocol.PROTOCOL_KEY; +import static google.registry.testing.JUnitBackports.assertThrows; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Throwables; +import com.google.common.truth.ThrowableSubject; +import google.registry.proxy.Protocol.BackendProtocol; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.util.ReferenceCountUtil; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import org.junit.rules.ExternalResource; + +/** + * Helper for setting up and testing client / server connection with netty. + * + *

Used in {@link SslClientInitializerTest} and {@link SslServerInitializerTest}. + */ +final class NettyRule extends ExternalResource { + + // All I/O operations are done inside the single thread within this event loop group, which is + // different from the main test thread. Therefore synchronizations are required to make sure that + // certain I/O activities are finished when assertions are performed. + private final EventLoopGroup eventLoopGroup = new NioEventLoopGroup(1); + + // Handler attached to server's channel to record the request received. + private EchoHandler echoHandler; + + // Handler attached to client's channel to record the response received. + private DumpHandler dumpHandler; + + private Channel channel; + + /** Sets up a server channel bound to the given local address. */ + void setUpServer(LocalAddress localAddress, ChannelHandler handler) { + checkState(echoHandler == null, "Can't call setUpServer twice"); + echoHandler = new EchoHandler(); + ChannelInitializer serverInitializer = + new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) { + // Add the given handler + ch.pipeline().addLast(handler); + // Add the "echoHandler" last to log the incoming message and send it back + ch.pipeline().addLast(echoHandler); + } + }; + ServerBootstrap sb = + new ServerBootstrap() + .group(eventLoopGroup) + .channel(LocalServerChannel.class) + .childHandler(serverInitializer); + ChannelFuture unusedFuture = sb.bind(localAddress).syncUninterruptibly(); + } + + /** Sets up a client channel connecting to the give local address. */ + void setUpClient( + LocalAddress localAddress, + BackendProtocol protocol, + ChannelHandler handler) { + checkState(echoHandler != null, "Must call setUpServer before setUpClient"); + checkState(dumpHandler == null, "Can't call setUpClient twice"); + dumpHandler = new DumpHandler(); + ChannelInitializer clientInitializer = + new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) throws Exception { + // Add the given handler + ch.pipeline().addLast(handler); + // Add the "dumpHandler" last to log the incoming message + ch.pipeline().addLast(dumpHandler); + } + }; + Bootstrap b = + new Bootstrap() + .group(eventLoopGroup) + .channel(LocalChannel.class) + .handler(clientInitializer) + .attr(PROTOCOL_KEY, protocol); + channel = b.connect(localAddress).syncUninterruptibly().channel(); + } + + void checkReady() { + checkState(channel != null, "Must call setUpClient to finish NettyRule setup"); + } + + /** + * Test that a message can go through, both inbound and outbound. + * + *

The client writes the message to the server, which echos it back and saves the string in its + * promise. The client receives the echo and saves it in its promise. All these activities happens + * in the I/O thread, and this call itself returns immediately. + */ + void assertThatMessagesWork() throws Exception { + checkReady(); + assertThat(channel.isActive()).isTrue(); + + writeToChannelAndFlush(channel, "Hello, world!"); + assertThat(echoHandler.getRequestFuture().get()).isEqualTo("Hello, world!"); + assertThat(dumpHandler.getResponseFuture().get()).isEqualTo("Hello, world!"); + } + + Channel getChannel() { + checkReady(); + return channel; + } + + ThrowableSubject assertThatServerRootCause() { + checkReady(); + return assertThat( + Throwables.getRootCause( + assertThrows(ExecutionException.class, () -> echoHandler.getRequestFuture().get()))); + } + + ThrowableSubject assertThatClientRootCause() { + checkReady(); + return assertThat( + Throwables.getRootCause( + assertThrows(ExecutionException.class, () -> dumpHandler.getResponseFuture().get()))); + } + + /** + * A handler that echoes back its inbound message. The message is also saved in a promise for + * inspection later. + */ + private static class EchoHandler extends ChannelInboundHandlerAdapter { + + private final CompletableFuture requestFuture = new CompletableFuture<>(); + + Future getRequestFuture() { + return requestFuture; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // In the test we only send messages of type ByteBuf. + assertThat(msg).isInstanceOf(ByteBuf.class); + String request = ((ByteBuf) msg).toString(UTF_8); + // After the message is written back to the client, fulfill the promise. + ChannelFuture unusedFuture = + ctx.writeAndFlush(msg).addListener(f -> requestFuture.complete(request)); + } + + /** Saves any inbound error as the cause of the promise failure. */ + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ChannelFuture unusedFuture = + ctx.channel().closeFuture().addListener(f -> requestFuture.completeExceptionally(cause)); + } + } + + /** A handler that dumps its inbound message to a promise that can be inspected later. */ + private static class DumpHandler extends ChannelInboundHandlerAdapter { + + private final CompletableFuture responseFuture = new CompletableFuture<>(); + + Future getResponseFuture() { + return responseFuture; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // In the test we only send messages of type ByteBuf. + assertThat(msg).isInstanceOf(ByteBuf.class); + String response = ((ByteBuf) msg).toString(UTF_8); + // There is no more use of this message, we should release its reference count so that it + // can be more effectively garbage collected by Netty. + ReferenceCountUtil.release(msg); + // Save the string in the promise and make it as complete. + responseFuture.complete(response); + } + + /** Saves any inbound error into the failure cause of the promise. */ + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ctx.channel().closeFuture().addListener(f -> responseFuture.completeExceptionally(cause)); + } + } + + @Override + protected void after() { + Future unusedFuture = eventLoopGroup.shutdownGracefully(); + } + + private static void writeToChannelAndFlush(Channel channel, String data) { + ChannelFuture unusedFuture = + channel.writeAndFlush(Unpooled.wrappedBuffer(data.getBytes(US_ASCII))); + } +} diff --git a/javatests/google/registry/proxy/handler/SslClientInitializerTest.java b/javatests/google/registry/proxy/handler/SslClientInitializerTest.java index 6ca921745..34b2b720e 100644 --- a/javatests/google/registry/proxy/handler/SslClientInitializerTest.java +++ b/javatests/google/registry/proxy/handler/SslClientInitializerTest.java @@ -17,26 +17,17 @@ package google.registry.proxy.handler; import static com.google.common.truth.Truth.assertThat; import static google.registry.proxy.Protocol.PROTOCOL_KEY; import static google.registry.proxy.handler.SslInitializerTestUtils.getKeyPair; -import static google.registry.proxy.handler.SslInitializerTestUtils.setUpClient; -import static google.registry.proxy.handler.SslInitializerTestUtils.setUpServer; +import static google.registry.proxy.handler.SslInitializerTestUtils.setUpSslChannel; import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair; -import static google.registry.proxy.handler.SslInitializerTestUtils.verifySslChannel; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import google.registry.proxy.Protocol; import google.registry.proxy.Protocol.BackendProtocol; -import google.registry.proxy.handler.SslInitializerTestUtils.DumpHandler; -import google.registry.proxy.handler.SslInitializerTestUtils.EchoHandler; -import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; -import io.netty.channel.EventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; -import io.netty.channel.nio.NioEventLoopGroup; import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.SniHandler; import io.netty.handler.ssl.SslContext; @@ -44,13 +35,12 @@ import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslProvider; import io.netty.handler.ssl.util.SelfSignedCertificate; -import io.netty.util.concurrent.Future; import java.security.KeyPair; import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import javax.net.ssl.SSLException; -import org.junit.After; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -78,6 +68,9 @@ public class SslClientInitializerTest { /** Fake port to test if the SSL engine gets the correct peer port. */ private static final int SSL_PORT = 12345; + @Rule + public NettyRule nettyRule = new NettyRule(); + @Parameter(0) public SslProvider sslProvider; @@ -92,17 +85,6 @@ public class SslClientInitializerTest { /** Saves the SNI hostname received by the server, if sent by the client. */ private String sniHostReceived; - // All I/O operations are done inside the single thread within this event loop group, which is - // different from the main test thread. Therefore synchronizations are required to make sure that - // certain I/O activities are finished when assertions are performed. - private final EventLoopGroup eventLoopGroup = new NioEventLoopGroup(1); - - // Handler attached to server's channel to record the request received. - private final EchoHandler echoHandler = new EchoHandler(); - - // Handler attached to client's channel to record the response received. - private final DumpHandler dumpHandler = new DumpHandler(); - /** Fake protocol saved in channel attribute. */ private static final BackendProtocol PROTOCOL = Protocol.backendBuilder() @@ -112,37 +94,14 @@ public class SslClientInitializerTest { .handlerProviders(ImmutableList.of()) .build(); - @After - public void shutDown() { - Future unusedFuture = eventLoopGroup.shutdownGracefully(); - } - - private ChannelInitializer getServerInitializer( - PrivateKey privateKey, X509Certificate certificate) throws Exception { + private ChannelHandler getServerHandler(PrivateKey privateKey, X509Certificate certificate) + throws Exception { SslContext sslContext = SslContextBuilder.forServer(privateKey, certificate).build(); - return new ChannelInitializer() { - @Override - protected void initChannel(LocalChannel ch) { - ch.pipeline() - .addLast( - new SniHandler( - hostname -> { - sniHostReceived = hostname; - return sslContext; - }), - echoHandler); - } - }; - } - - private ChannelInitializer getClientInitializer( - SslClientInitializer sslClientInitializer) { - return new ChannelInitializer() { - @Override - protected void initChannel(LocalChannel ch) { - ch.pipeline().addLast(sslClientInitializer, dumpHandler); - } - }; + return new SniHandler( + hostname -> { + sniHostReceived = hostname; + return sslContext; + }); } @Test @@ -177,21 +136,15 @@ public class SslClientInitializerTest { SelfSignedCertificate ssc = new SelfSignedCertificate(SSL_HOST); LocalAddress localAddress = new LocalAddress("DEFAULT_TRUST_MANAGER_REJECT_SELF_SIGNED_CERT_" + sslProvider); - setUpServer(eventLoopGroup, getServerInitializer(ssc.key(), ssc.cert()), localAddress); + nettyRule.setUpServer(localAddress, getServerHandler(ssc.key(), ssc.cert())); SslClientInitializer sslClientInitializer = new SslClientInitializer<>(sslProvider); - Channel channel = - setUpClient( - eventLoopGroup, getClientInitializer(sslClientInitializer), localAddress, PROTOCOL); - // Wait for handshake exception to throw. - echoHandler.waitTillReady(); - dumpHandler.waitTillReady(); + nettyRule.setUpClient(localAddress, PROTOCOL, sslClientInitializer); // The connection is now terminated, both the client side and the server side should get // exceptions. - assertThat(Throwables.getRootCause(dumpHandler.getCause())) - .isInstanceOf(SunCertPathBuilderException.class); - assertThat(Throwables.getRootCause(echoHandler.getCause())).isInstanceOf(SSLException.class); - assertThat(channel.isActive()).isFalse(); + nettyRule.assertThatClientRootCause().isInstanceOf(SunCertPathBuilderException.class); + nettyRule.assertThatServerRootCause().isInstanceOf(SSLException.class); + assertThat(nettyRule.getChannel().isActive()).isFalse(); } @Test @@ -208,16 +161,15 @@ public class SslClientInitializerTest { // Set up the server to use the signed cert and private key to perform handshake; PrivateKey privateKey = keyPair.getPrivate(); - setUpServer(eventLoopGroup, getServerInitializer(privateKey, cert), localAddress); + nettyRule.setUpServer(localAddress, getServerHandler(privateKey, cert)); // Set up the client to trust the self signed cert used to sign the cert that server provides. SslClientInitializer sslClientInitializer = new SslClientInitializer<>(sslProvider, new X509Certificate[] {ssc.cert()}); - Channel channel = - setUpClient( - eventLoopGroup, getClientInitializer(sslClientInitializer), localAddress, PROTOCOL); + nettyRule.setUpClient(localAddress, PROTOCOL, sslClientInitializer); - verifySslChannel(channel, ImmutableList.of(cert), echoHandler, dumpHandler); + setUpSslChannel(nettyRule.getChannel(), cert); + nettyRule.assertThatMessagesWork(); // Verify that the SNI extension is sent during handshake. assertThat(sniHostReceived).isEqualTo(SSL_HOST); @@ -237,24 +189,18 @@ public class SslClientInitializerTest { // Set up the server to use the signed cert and private key to perform handshake; PrivateKey privateKey = keyPair.getPrivate(); - setUpServer(eventLoopGroup, getServerInitializer(privateKey, cert), localAddress); + nettyRule.setUpServer(localAddress, getServerHandler(privateKey, cert)); // Set up the client to trust the self signed cert used to sign the cert that server provides. SslClientInitializer sslClientInitializer = new SslClientInitializer<>(sslProvider, new X509Certificate[] {ssc.cert()}); - Channel channel = - setUpClient( - eventLoopGroup, getClientInitializer(sslClientInitializer), localAddress, PROTOCOL); - - echoHandler.waitTillReady(); - dumpHandler.waitTillReady(); + nettyRule.setUpClient(localAddress, PROTOCOL, sslClientInitializer); // When the client rejects the server cert due to wrong hostname, both the client and server // should throw exceptions. - Throwable rootCause = Throwables.getRootCause(dumpHandler.getCause()); - assertThat(rootCause).isInstanceOf(CertificateException.class); - assertThat(rootCause).hasMessageThat().contains(SSL_HOST); - assertThat(Throwables.getRootCause(echoHandler.getCause())).isInstanceOf(SSLException.class); - assertThat(channel.isActive()).isFalse(); + nettyRule.assertThatClientRootCause().isInstanceOf(CertificateException.class); + nettyRule.assertThatClientRootCause().hasMessageThat().contains(SSL_HOST); + nettyRule.assertThatServerRootCause().isInstanceOf(SSLException.class); + assertThat(nettyRule.getChannel().isActive()).isFalse(); } } diff --git a/javatests/google/registry/proxy/handler/SslInitializerTestUtils.java b/javatests/google/registry/proxy/handler/SslInitializerTestUtils.java index 83e63958e..a992f39a2 100644 --- a/javatests/google/registry/proxy/handler/SslInitializerTestUtils.java +++ b/javatests/google/registry/proxy/handler/SslInitializerTestUtils.java @@ -15,29 +15,11 @@ package google.registry.proxy.handler; import static com.google.common.truth.Truth.assertThat; -import static google.registry.proxy.Protocol.PROTOCOL_KEY; -import static java.nio.charset.StandardCharsets.UTF_8; -import com.google.common.collect.ImmutableList; -import google.registry.proxy.Protocol.BackendProtocol; -import io.netty.bootstrap.Bootstrap; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.local.LocalServerChannel; import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.util.SelfSignedCertificate; -import io.netty.util.ReferenceCountUtil; import java.math.BigInteger; -import java.nio.charset.StandardCharsets; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.SecureRandom; @@ -46,7 +28,6 @@ import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; import java.util.Date; -import java.util.concurrent.CountDownLatch; import javax.net.ssl.SSLSession; import javax.security.auth.x500.X500Principal; import org.bouncycastle.jce.provider.BouncyCastleProvider; @@ -62,133 +43,6 @@ public class SslInitializerTestUtils { Security.addProvider(new BouncyCastleProvider()); } - /** Sets up a server channel bound to the given local address. */ - static void setUpServer( - EventLoopGroup eventLoopGroup, - ChannelInitializer serverInitializer, - LocalAddress localAddress) { - ServerBootstrap sb = - new ServerBootstrap() - .group(eventLoopGroup) - .channel(LocalServerChannel.class) - .childHandler(serverInitializer); - ChannelFuture unusedFuture = sb.bind(localAddress).syncUninterruptibly(); - } - - /** Sets up a client channel connecting to the give local address. */ - static Channel setUpClient( - EventLoopGroup eventLoopGroup, - ChannelInitializer clientInitializer, - LocalAddress localAddress, - BackendProtocol protocol) { - Bootstrap b = - new Bootstrap() - .group(eventLoopGroup) - .channel(LocalChannel.class) - .handler(clientInitializer) - .attr(PROTOCOL_KEY, protocol); - return b.connect(localAddress).syncUninterruptibly().channel(); - } - - /** - * A handler that echoes back its inbound message. The message is also saved in a promise for - * inspection later. - */ - static class EchoHandler extends ChannelInboundHandlerAdapter { - - private final CountDownLatch latch = new CountDownLatch(1); - private String request; - private Throwable cause; - - void waitTillReady() throws InterruptedException { - latch.await(); - } - - String getRequest() { - return request; - } - - Throwable getCause() { - return cause; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - // In the test we only send messages of type ByteBuf. - assertThat(msg).isInstanceOf(ByteBuf.class); - request = ((ByteBuf) msg).toString(UTF_8); - // After the message is written back to the client, fulfill the promise. - ChannelFuture unusedFuture = ctx.writeAndFlush(msg).addListener(f -> latch.countDown()); - } - - /** Saves any inbound error as the cause of the promise failure. */ - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - this.cause = cause; - ChannelFuture unusedFuture = - ctx.channel() - .closeFuture() - .addListener( - // Apparently the JDK SSL provider will call #exceptionCaught twice with the same - // exception when the handshake fails. In this case the second listener should not - // set the promise again. - f -> { - if (latch.getCount() == 1) { - latch.countDown(); - } - }); - } - } - - /** A handler that dumps its inbound message to a promise that can be inspected later. */ - static class DumpHandler extends ChannelInboundHandlerAdapter { - - private final CountDownLatch latch = new CountDownLatch(1); - private String response; - private Throwable cause; - - void waitTillReady() throws InterruptedException { - latch.await(); - } - - String getResponse() { - return response; - } - - Throwable getCause() { - return cause; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - // In the test we only send messages of type ByteBuf. - assertThat(msg).isInstanceOf(ByteBuf.class); - response = ((ByteBuf) msg).toString(UTF_8); - // There is no more use of this message, we should release its reference count so that it - // can be more effectively garbage collected by Netty. - ReferenceCountUtil.release(msg); - // Save the string in the promise and make it as complete. - latch.countDown(); - } - - /** Saves any inbound error into the failure cause of the promise. */ - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - this.cause = cause; - ctx.channel() - .closeFuture() - .addListener( - f -> { - // Apparently the JDK SSL provider will call #exceptionCaught twice with the same - // exception when the handshake fails. In this case the second listener should not - // set the promise again. - if (latch.getCount() == 1) { - latch.countDown(); - } - }); - } - } - public static KeyPair getKeyPair() throws Exception { KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA", "BC"); keyPairGenerator.initialize(2048, new SecureRandom()); @@ -221,11 +75,9 @@ public class SslInitializerTestUtils { * @param certs The certificate that the server should provide. * @return The SSL session in current channel, can be used for further validation. */ - static SSLSession verifySslChannel( + static SSLSession setUpSslChannel( Channel channel, - ImmutableList certs, - EchoHandler echoHandler, - DumpHandler dumpHandler) + X509Certificate... certs) throws Exception { SslHandler sslHandler = channel.pipeline().get(SslHandler.class); // Wait till the handshake is complete. @@ -236,27 +88,7 @@ public class SslInitializerTestUtils { assertThat(sslHandler.engine().getSession().isValid()).isTrue(); assertThat(sslHandler.engine().getSession().getPeerCertificates()) .asList() - .containsExactly(certs.toArray()); - - // Test that message can go through, bound inbound and outbound. - String inputString = "Hello, world!"; - // The client writes the message to the server, which echos it back and saves the string in its - // promise. The client receives the echo and saves it in its promise. All these activities - // happens in the I/O thread, and this call itself returns immediately. - ChannelFuture unusedFuture = - channel.writeAndFlush( - Unpooled.wrappedBuffer(inputString.getBytes(StandardCharsets.US_ASCII))); - - // Wait for both the server and the client to finish processing. - echoHandler.waitTillReady(); - dumpHandler.waitTillReady(); - - // Checks that the message is transmitted faithfully. - String requestReceived = echoHandler.getRequest(); - String responseReceived = dumpHandler.getResponse(); - assertThat(inputString).isEqualTo(requestReceived); - assertThat(inputString).isEqualTo(responseReceived); - + .containsExactlyElementsIn(certs); // Returns the SSL session for further assertion. return sslHandler.engine().getSession(); } diff --git a/javatests/google/registry/proxy/handler/SslServerInitializerTest.java b/javatests/google/registry/proxy/handler/SslServerInitializerTest.java index d52913f87..bfc33bee9 100644 --- a/javatests/google/registry/proxy/handler/SslServerInitializerTest.java +++ b/javatests/google/registry/proxy/handler/SslServerInitializerTest.java @@ -16,27 +16,19 @@ package google.registry.proxy.handler; import static com.google.common.truth.Truth.assertThat; import static google.registry.proxy.handler.SslInitializerTestUtils.getKeyPair; -import static google.registry.proxy.handler.SslInitializerTestUtils.setUpClient; -import static google.registry.proxy.handler.SslInitializerTestUtils.setUpServer; +import static google.registry.proxy.handler.SslInitializerTestUtils.setUpSslChannel; import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair; -import static google.registry.proxy.handler.SslInitializerTestUtils.verifySslChannel; import com.google.common.base.Suppliers; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import google.registry.proxy.Protocol; import google.registry.proxy.Protocol.BackendProtocol; -import google.registry.proxy.handler.SslInitializerTestUtils.DumpHandler; -import google.registry.proxy.handler.SslInitializerTestUtils.EchoHandler; -import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; -import io.netty.channel.EventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; -import io.netty.channel.nio.NioEventLoopGroup; import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslHandler; @@ -51,7 +43,7 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; -import org.junit.After; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -87,6 +79,9 @@ public class SslServerInitializerTest { .handlerProviders(ImmutableList.of()) .build(); + @Rule + public NettyRule nettyRule = new NettyRule(); + @Parameter(0) public SslProvider sslProvider; @@ -98,64 +93,40 @@ public class SslServerInitializerTest { : new SslProvider[] {SslProvider.JDK}; } - // All I/O operations are done inside the single thread within this event loop group, which is - // different from the main test thread. Therefore synchronizations are required to make sure that - // certain I/O activities are finished when assertions are performed. - private final EventLoopGroup eventLoopGroup = new NioEventLoopGroup(1); - - // Handler attached to server's channel to record the request received. - private final EchoHandler echoHandler = new EchoHandler(); - - // Handler attached to client's channel to record the response received. - private final DumpHandler dumpHandler = new DumpHandler(); - - @After - public void shutDown() { - eventLoopGroup.shutdownGracefully().getNow(); - } - - private ChannelInitializer getServerInitializer( + private ChannelHandler getServerHandler( boolean requireClientCert, PrivateKey privateKey, X509Certificate... certificates) { - return new ChannelInitializer() { - @Override - protected void initChannel(LocalChannel ch) { - ch.pipeline() - .addLast( - new SslServerInitializer( - requireClientCert, - sslProvider, - Suppliers.ofInstance(privateKey), - Suppliers.ofInstance(certificates)), - echoHandler); - } - }; + return new SslServerInitializer( + requireClientCert, + sslProvider, + Suppliers.ofInstance(privateKey), + Suppliers.ofInstance(certificates)); } - private ChannelInitializer getServerInitializer( - PrivateKey privateKey, X509Certificate... certificates) { - return getServerInitializer(true, privateKey, certificates); + private ChannelHandler getServerHandler(PrivateKey privateKey, X509Certificate... certificates) { + return getServerHandler(true, privateKey, certificates); } - private ChannelInitializer getClientInitializer( - X509Certificate trustedCertificate, PrivateKey privateKey, X509Certificate certificate) { + private ChannelHandler getClientHandler( + X509Certificate trustedCertificate, + PrivateKey privateKey, + X509Certificate certificate) { return new ChannelInitializer() { @Override protected void initChannel(LocalChannel ch) throws Exception { - SslContextBuilder sslContextBuilder = - SslContextBuilder.forClient().trustManager(trustedCertificate).sslProvider(sslProvider); - if (privateKey != null && certificate != null) { - sslContextBuilder.keyManager(privateKey, certificate); - } - SslHandler sslHandler = - sslContextBuilder.build().newHandler(ch.alloc(), SSL_HOST, SSL_PORT); + SslContextBuilder sslContextBuilder = + SslContextBuilder.forClient().trustManager(trustedCertificate).sslProvider(sslProvider); + if (privateKey != null && certificate != null) { + sslContextBuilder.keyManager(privateKey, certificate); + } + SslHandler sslHandler = sslContextBuilder.build().newHandler(ch.alloc(), SSL_HOST, SSL_PORT); - // Enable hostname verification. - SSLEngine sslEngine = sslHandler.engine(); - SSLParameters sslParameters = sslEngine.getSSLParameters(); - sslParameters.setEndpointIdentificationAlgorithm("HTTPS"); - sslEngine.setSSLParameters(sslParameters); + // Enable hostname verification. + SSLEngine sslEngine = sslHandler.engine(); + SSLParameters sslParameters = sslEngine.getSSLParameters(); + sslParameters.setEndpointIdentificationAlgorithm("HTTPS"); + sslEngine.setSSLParameters(sslParameters); - ch.pipeline().addLast(sslHandler, dumpHandler); + ch.pipeline().addLast(sslHandler); } }; } @@ -184,18 +155,16 @@ public class SslServerInitializerTest { SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST); LocalAddress localAddress = new LocalAddress("TRUST_ANY_CLIENT_CERT_" + sslProvider); - setUpServer( - eventLoopGroup, getServerInitializer(serverSsc.key(), serverSsc.cert()), localAddress); + nettyRule.setUpServer(localAddress, getServerHandler(serverSsc.key(), serverSsc.cert())); SelfSignedCertificate clientSsc = new SelfSignedCertificate(); - Channel channel = - setUpClient( - eventLoopGroup, - getClientInitializer(serverSsc.cert(), clientSsc.key(), clientSsc.cert()), - localAddress, - PROTOCOL); + nettyRule.setUpClient( + localAddress, + PROTOCOL, + getClientHandler(serverSsc.cert(), clientSsc.key(), clientSsc.cert())); + + SSLSession sslSession = setUpSslChannel(nettyRule.getChannel(), serverSsc.cert()); + nettyRule.assertThatMessagesWork(); - SSLSession sslSession = - verifySslChannel(channel, ImmutableList.of(serverSsc.cert()), echoHandler, dumpHandler); // Verify that the SSL session gets the client cert. Note that this SslSession is for the client // channel, therefore its local certificates are the remote certificates of the SslSession for // the server channel, and vice versa. @@ -208,19 +177,15 @@ public class SslServerInitializerTest { SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST); LocalAddress localAddress = new LocalAddress("DOES_NOT_REQUIRE_CLIENT_CERT_" + sslProvider); - setUpServer( - eventLoopGroup, - getServerInitializer(false, serverSsc.key(), serverSsc.cert()), - localAddress); - Channel channel = - setUpClient( - eventLoopGroup, - getClientInitializer(serverSsc.cert(), null, null), - localAddress, - PROTOCOL); + nettyRule.setUpServer( + localAddress, + getServerHandler(false, serverSsc.key(), serverSsc.cert())); + nettyRule.setUpClient( + localAddress, PROTOCOL, getClientHandler(serverSsc.cert(), null, null)); + + SSLSession sslSession = setUpSslChannel(nettyRule.getChannel(), serverSsc.cert()); + nettyRule.assertThatMessagesWork(); - SSLSession sslSession = - verifySslChannel(channel, ImmutableList.of(serverSsc.cert()), echoHandler, dumpHandler); // Verify that the SSL session does not contain any client cert. Note that this SslSession is // for the client channel, therefore its local certificates are the remote certificates of the // SslSession for the server channel, and vice versa. @@ -236,27 +201,23 @@ public class SslServerInitializerTest { X509Certificate serverCert = signKeyPair(caSsc, keyPair, SSL_HOST); LocalAddress localAddress = new LocalAddress("CERT_SIGNED_BY_OTHER_CA_" + sslProvider); - setUpServer( - eventLoopGroup, - getServerInitializer( + nettyRule.setUpServer( + localAddress, + getServerHandler( keyPair.getPrivate(), // Serving both the server cert, and the CA cert serverCert, - caSsc.cert()), - localAddress); + caSsc.cert())); SelfSignedCertificate clientSsc = new SelfSignedCertificate(); - Channel channel = - setUpClient( - eventLoopGroup, - getClientInitializer( + nettyRule.setUpClient( + localAddress, + PROTOCOL, + getClientHandler( // Client trusts the CA cert - caSsc.cert(), clientSsc.key(), clientSsc.cert()), - localAddress, - PROTOCOL); + caSsc.cert(), clientSsc.key(), clientSsc.cert())); - SSLSession sslSession = - verifySslChannel( - channel, ImmutableList.of(serverCert, caSsc.cert()), echoHandler, dumpHandler); + SSLSession sslSession = setUpSslChannel(nettyRule.getChannel(), serverCert, caSsc.cert()); + nettyRule.assertThatMessagesWork(); assertThat(sslSession.getLocalCertificates()).asList().containsExactly(clientSsc.cert()); assertThat(sslSession.getPeerCertificates()) @@ -270,28 +231,21 @@ public class SslServerInitializerTest { SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST); LocalAddress localAddress = new LocalAddress("REQUIRE_CLIENT_CERT_" + sslProvider); - setUpServer( - eventLoopGroup, getServerInitializer(serverSsc.key(), serverSsc.cert()), localAddress); - Channel channel = - setUpClient( - eventLoopGroup, - getClientInitializer( - serverSsc.cert(), - // No client cert/private key used. - null, - null), - localAddress, - PROTOCOL); - - echoHandler.waitTillReady(); - dumpHandler.waitTillReady(); + nettyRule.setUpServer(localAddress, getServerHandler(serverSsc.key(), serverSsc.cert())); + nettyRule.setUpClient( + localAddress, + PROTOCOL, + getClientHandler( + serverSsc.cert(), + // No client cert/private key used. + null, + null)); // When the server rejects the client during handshake due to lack of client certificate, both // should throw exceptions. - assertThat(Throwables.getRootCause(echoHandler.getCause())) - .isInstanceOf(SSLHandshakeException.class); - assertThat(Throwables.getRootCause(dumpHandler.getCause())).isInstanceOf(SSLException.class); - assertThat(channel.isActive()).isFalse(); + nettyRule.assertThatServerRootCause().isInstanceOf(SSLHandshakeException.class); + nettyRule.assertThatClientRootCause().isInstanceOf(SSLException.class); + assertThat(nettyRule.getChannel().isActive()).isFalse(); } @Test @@ -299,25 +253,18 @@ public class SslServerInitializerTest { SelfSignedCertificate serverSsc = new SelfSignedCertificate("wrong.com"); LocalAddress localAddress = new LocalAddress("WRONG_HOSTNAME_" + sslProvider); - setUpServer( - eventLoopGroup, getServerInitializer(serverSsc.key(), serverSsc.cert()), localAddress); + nettyRule.setUpServer(localAddress, getServerHandler(serverSsc.key(), serverSsc.cert())); SelfSignedCertificate clientSsc = new SelfSignedCertificate(); - Channel channel = - setUpClient( - eventLoopGroup, - getClientInitializer(serverSsc.cert(), clientSsc.key(), clientSsc.cert()), - localAddress, - PROTOCOL); - - echoHandler.waitTillReady(); - dumpHandler.waitTillReady(); + nettyRule.setUpClient( + localAddress, + PROTOCOL, + getClientHandler(serverSsc.cert(), clientSsc.key(), clientSsc.cert())); // When the client rejects the server cert due to wrong hostname, both the server and the client // throw exceptions. - Throwable rootCause = Throwables.getRootCause(dumpHandler.getCause()); - assertThat(rootCause).isInstanceOf(CertificateException.class); - assertThat(rootCause).hasMessageThat().contains(SSL_HOST); - assertThat(Throwables.getRootCause(echoHandler.getCause())).isInstanceOf(SSLException.class); - assertThat(channel.isActive()).isFalse(); + nettyRule.assertThatClientRootCause().isInstanceOf(CertificateException.class); + nettyRule.assertThatClientRootCause().hasMessageThat().contains(SSL_HOST); + nettyRule.assertThatServerRootCause().isInstanceOf(SSLException.class); + assertThat(nettyRule.getChannel().isActive()).isFalse(); } }