Remove SSL initializer from the prober (#378)

The prober now uses the common SSL initializer in the networking
subproject.

Also changed both initializers to take an ImmutableList of certificates
other than an array of those, for better immutability.

I have no idea where these lockfile changes are coming from. They seem
to be pure noise as far as code review is concerned.
This commit is contained in:
Lai Jiang 2019-11-22 17:46:06 -05:00 committed by GitHub
parent e318f47fc6
commit 05d56fe1a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 257 additions and 770 deletions

View file

@ -17,16 +17,20 @@ package google.registry.networking.handler;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.flogger.FluentLogger;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.inject.Singleton;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
@ -47,14 +51,33 @@ public class SslClientInitializer<C extends Channel> extends ChannelInitializer<
private final Function<Channel, String> hostProvider;
private final Function<Channel, Integer> portProvider;
private final SslProvider sslProvider;
private final X509Certificate[] trustedCertificates;
private final ImmutableList<X509Certificate> trustedCertificates;
// The following two suppliers only need be none-null when client authentication is required.
private final Supplier<PrivateKey> privateKeySupplier;
private final Supplier<ImmutableList<X509Certificate>> certificateChainSupplier;
public SslClientInitializer(
SslProvider sslProvider,
Function<Channel, String> hostProvider,
Function<Channel, Integer> portProvider) {
// null uses the system default trust store.
this(sslProvider, hostProvider, portProvider, null);
public static SslClientInitializer<NioSocketChannel>
createSslClientInitializerWithSystemTrustStore(
SslProvider sslProvider,
Function<Channel, String> hostProvider,
Function<Channel, Integer> portProvider) {
return new SslClientInitializer<>(sslProvider, hostProvider, portProvider, null, null, null);
}
public static SslClientInitializer<NioSocketChannel>
createSslClientInitializerWithSystemTrustStoreAndClientAuthentication(
SslProvider sslProvider,
Function<Channel, String> hostProvider,
Function<Channel, Integer> portProvider,
Supplier<PrivateKey> privateKeySupplier,
Supplier<ImmutableList<X509Certificate>> certificateChainSupplier) {
return new SslClientInitializer<>(
sslProvider,
hostProvider,
portProvider,
ImmutableList.of(),
privateKeySupplier,
certificateChainSupplier);
}
@VisibleForTesting
@ -62,22 +85,38 @@ public class SslClientInitializer<C extends Channel> extends ChannelInitializer<
SslProvider sslProvider,
Function<Channel, String> hostProvider,
Function<Channel, Integer> portProvider,
X509Certificate[] trustCertificates) {
ImmutableList<X509Certificate> trustedCertificates,
Supplier<PrivateKey> privateKeySupplier,
Supplier<ImmutableList<X509Certificate>> certificateChainSupplier) {
logger.atInfo().log("Client SSL Provider: %s", sslProvider);
this.sslProvider = sslProvider;
this.hostProvider = hostProvider;
this.portProvider = portProvider;
this.trustedCertificates = trustCertificates;
this.trustedCertificates = trustedCertificates;
this.privateKeySupplier = privateKeySupplier;
this.certificateChainSupplier = certificateChainSupplier;
}
@Override
protected void initChannel(C channel) throws Exception {
checkNotNull(hostProvider.apply(channel), "Cannot obtain SSL host for channel: %s", channel);
checkNotNull(portProvider.apply(channel), "Cannot obtain SSL port for channel: %s", channel);
SslHandler sslHandler =
SslContextBuilder sslContextBuilder =
SslContextBuilder.forClient()
.sslProvider(sslProvider)
.trustManager(trustedCertificates)
.trustManager(
trustedCertificates.isEmpty()
? null
: trustedCertificates.toArray(new X509Certificate[0]));
if (privateKeySupplier != null && certificateChainSupplier != null) {
sslContextBuilder.keyManager(
privateKeySupplier.get(), certificateChainSupplier.get().toArray(new X509Certificate[0]));
}
SslHandler sslHandler =
sslContextBuilder
.build()
.newHandler(channel.alloc(), hostProvider.apply(channel), portProvider.apply(channel));

View file

@ -14,6 +14,7 @@
package google.registry.networking.handler;
import com.google.common.collect.ImmutableList;
import com.google.common.flogger.FluentLogger;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
@ -39,7 +40,7 @@ import java.util.function.Supplier;
* come before this handler. The type parameter {@code C} is needed so that unit tests can construct
* this handler that works with {@link EmbeddedChannel};
*
* <p>The ssl handler added requires client authentication, but it uses an {@link
* <p>The ssl handler added can require client authentication, but it uses an {@link
* InsecureTrustManagerFactory}, which accepts any ssl certificate presented by the client, as long
* as the client uses the corresponding private key to establish SSL handshake. The client
* certificate hash will be passed along to GAE as an HTTP header for verification (not handled by
@ -58,14 +59,16 @@ public class SslServerInitializer<C extends Channel> extends ChannelInitializer<
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private final boolean requireClientCert;
private final SslProvider sslProvider;
// We use suppliers for the key/cert pair because they are fetched and cached from GCS, and can
// change when the artifacts on GCS changes.
private final Supplier<PrivateKey> privateKeySupplier;
private final Supplier<X509Certificate[]> certificatesSupplier;
private final Supplier<ImmutableList<X509Certificate>> certificatesSupplier;
public SslServerInitializer(
boolean requireClientCert,
SslProvider sslProvider,
Supplier<PrivateKey> privateKeySupplier,
Supplier<X509Certificate[]> certificatesSupplier) {
Supplier<ImmutableList<X509Certificate>> certificatesSupplier) {
logger.atInfo().log("Server SSL Provider: %s", sslProvider);
this.requireClientCert = requireClientCert;
this.sslProvider = sslProvider;
@ -76,7 +79,9 @@ public class SslServerInitializer<C extends Channel> extends ChannelInitializer<
@Override
protected void initChannel(C channel) throws Exception {
SslHandler sslHandler =
SslContextBuilder.forServer(privateKeySupplier.get(), certificatesSupplier.get())
SslContextBuilder.forServer(
privateKeySupplier.get(),
certificatesSupplier.get().toArray(new X509Certificate[0]))
.sslProvider(sslProvider)
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.clientAuth(requireClientCert ? ClientAuth.REQUIRE : ClientAuth.NONE)

View file

@ -48,7 +48,7 @@ import org.junit.rules.ExternalResource;
*
* <p>Used in {@link SslClientInitializerTest} and {@link SslServerInitializerTest}.
*/
final class NettyRule extends ExternalResource {
public final class NettyRule extends ExternalResource {
// All I/O operations are done inside the single thread within this event loop group, which is
// different from the main test thread. Therefore synchronizations are required to make sure that
@ -63,8 +63,12 @@ final class NettyRule extends ExternalResource {
private Channel channel;
public EventLoopGroup getEventLoopGroup() {
return eventLoopGroup;
}
/** Sets up a server channel bound to the given local address. */
void setUpServer(LocalAddress localAddress, ChannelHandler handler) {
public void setUpServer(LocalAddress localAddress, ChannelHandler... handlers) {
checkState(echoHandler == null, "Can't call setUpServer twice");
echoHandler = new EchoHandler();
ChannelInitializer<LocalChannel> serverInitializer =
@ -72,7 +76,7 @@ final class NettyRule extends ExternalResource {
@Override
protected void initChannel(LocalChannel ch) {
// Add the given handler
ch.pipeline().addLast(handler);
ch.pipeline().addLast(handlers);
// Add the "echoHandler" last to log the incoming message and send it back
ch.pipeline().addLast(echoHandler);
}
@ -147,6 +151,11 @@ final class NettyRule extends ExternalResource {
assertThrows(ExecutionException.class, () -> dumpHandler.getResponseFuture().get())));
}
// TODO(jianglai): find a way to remove this helper method.
public void assertReceivedMessage(String message) throws Exception {
assertThat(echoHandler.getRequestFuture().get()).isEqualTo(message);
}
/**
* A handler that echoes back its inbound message. The message is also saved in a promise for
* inspection later.

View file

@ -19,18 +19,21 @@ import static google.registry.networking.handler.SslInitializerTestUtils.getKeyP
import static google.registry.networking.handler.SslInitializerTestUtils.setUpSslChannel;
import static google.registry.networking.handler.SslInitializerTestUtils.signKeyPair;
import com.google.common.collect.ImmutableList;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.security.KeyPair;
import java.security.PrivateKey;
@ -39,6 +42,7 @@ import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.function.Function;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -86,9 +90,14 @@ public class SslClientInitializerTest {
/** Saves the SNI hostname received by the server, if sent by the client. */
private String sniHostReceived;
private ChannelHandler getServerHandler(PrivateKey privateKey, X509Certificate certificate)
private ChannelHandler getServerHandler(
boolean requireClientCert, PrivateKey privateKey, X509Certificate certificate)
throws Exception {
SslContext sslContext = SslContextBuilder.forServer(privateKey, certificate).build();
SslContext sslContext =
SslContextBuilder.forServer(privateKey, certificate)
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.clientAuth(requireClientCert ? ClientAuth.REQUIRE : ClientAuth.NONE)
.build();
return new SniHandler(
hostname -> {
sniHostReceived = hostname;
@ -99,7 +108,8 @@ public class SslClientInitializerTest {
@Test
public void testSuccess_swappedInitializerWithSslHandler() throws Exception {
SslClientInitializer<EmbeddedChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider, hostProvider, portProvider);
new SslClientInitializer<>(
sslProvider, hostProvider, portProvider, ImmutableList.of(), null, null);
EmbeddedChannel channel = new EmbeddedChannel();
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslClientInitializer);
@ -114,7 +124,8 @@ public class SslClientInitializerTest {
@Test
public void testSuccess_nullHost() {
SslClientInitializer<EmbeddedChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider, channel -> null, portProvider);
new SslClientInitializer<>(
sslProvider, channel -> null, portProvider, ImmutableList.of(), null, null);
EmbeddedChannel channel = new EmbeddedChannel();
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslClientInitializer);
@ -125,7 +136,8 @@ public class SslClientInitializerTest {
@Test
public void testSuccess_nullPort() {
SslClientInitializer<EmbeddedChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider, hostProvider, channel -> null);
new SslClientInitializer<>(
sslProvider, hostProvider, channel -> null, ImmutableList.of(), null, null);
EmbeddedChannel channel = new EmbeddedChannel();
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslClientInitializer);
@ -138,9 +150,10 @@ public class SslClientInitializerTest {
SelfSignedCertificate ssc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress =
new LocalAddress("DEFAULT_TRUST_MANAGER_REJECT_SELF_SIGNED_CERT_" + sslProvider);
nettyRule.setUpServer(localAddress, getServerHandler(ssc.key(), ssc.cert()));
nettyRule.setUpServer(localAddress, getServerHandler(false, ssc.key(), ssc.cert()));
SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider, hostProvider, portProvider);
new SslClientInitializer<>(
sslProvider, hostProvider, portProvider, ImmutableList.of(), null, null);
nettyRule.setUpClient(localAddress, sslClientInitializer);
// The connection is now terminated, both the client side and the server side should get
// exceptions.
@ -163,12 +176,12 @@ public class SslClientInitializerTest {
// Set up the server to use the signed cert and private key to perform handshake;
PrivateKey privateKey = keyPair.getPrivate();
nettyRule.setUpServer(localAddress, getServerHandler(privateKey, cert));
nettyRule.setUpServer(localAddress, getServerHandler(false, privateKey, cert));
// Set up the client to trust the self signed cert used to sign the cert that server provides.
SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(
sslProvider, hostProvider, portProvider, new X509Certificate[] {ssc.cert()});
sslProvider, hostProvider, portProvider, ImmutableList.of(ssc.cert()), null, null);
nettyRule.setUpClient(localAddress, sslClientInitializer);
setUpSslChannel(nettyRule.getChannel(), cert);
@ -178,6 +191,43 @@ public class SslClientInitializerTest {
assertThat(sniHostReceived).isEqualTo(SSL_HOST);
}
@Test
public void testSuccess_customTrustManager_acceptSelfSignedCert_clientCertRequired()
throws Exception {
LocalAddress localAddress =
new LocalAddress(
"CUSTOM_TRUST_MANAGER_ACCEPT_SELF_SIGNED_CERT_CLIENT_CERT_REQUIRED_" + sslProvider);
SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST);
SelfSignedCertificate clientSsc = new SelfSignedCertificate();
// Set up the server to require client certificate.
nettyRule.setUpServer(localAddress, getServerHandler(true, serverSsc.key(), serverSsc.cert()));
// Set up the client to trust the server certificate and use the client certificate.
SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(
sslProvider,
hostProvider,
portProvider,
ImmutableList.of(serverSsc.cert()),
() -> clientSsc.key(),
() -> ImmutableList.of(clientSsc.cert()));
nettyRule.setUpClient(localAddress, sslClientInitializer);
SSLSession sslSession = setUpSslChannel(nettyRule.getChannel(), serverSsc.cert());
nettyRule.assertThatMessagesWork();
// Verify that the SNI extension is sent during handshake.
assertThat(sniHostReceived).isEqualTo(SSL_HOST);
// Verify that the SSL session gets the client cert. Note that this SslSession is for the client
// channel, therefore its local certificates are the remote certificates of the SslSession for
// the server channel, and vice versa.
assertThat(sslSession.getLocalCertificates()).asList().containsExactly(clientSsc.cert());
assertThat(sslSession.getPeerCertificates()).asList().containsExactly(serverSsc.cert());
}
@Test
public void testFailure_customTrustManager_wrongHostnameInCertificate() throws Exception {
LocalAddress localAddress =
@ -192,12 +242,12 @@ public class SslClientInitializerTest {
// Set up the server to use the signed cert and private key to perform handshake;
PrivateKey privateKey = keyPair.getPrivate();
nettyRule.setUpServer(localAddress, getServerHandler(privateKey, cert));
nettyRule.setUpServer(localAddress, getServerHandler(false, privateKey, cert));
// Set up the client to trust the self signed cert used to sign the cert that server provides.
SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(
sslProvider, hostProvider, portProvider, new X509Certificate[] {ssc.cert()});
sslProvider, hostProvider, portProvider, ImmutableList.of(ssc.cert()), null, null);
nettyRule.setUpClient(localAddress, sslClientInitializer);
// When the client rejects the server cert due to wrong hostname, both the client and server

View file

@ -29,15 +29,23 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import javax.net.ssl.SSLSession;
import javax.security.auth.x500.X500Principal;
import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.X509v3CertificateBuilder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.crypto.util.PrivateKeyFactory;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.x509.X509V3CertificateGenerator;
import org.bouncycastle.operator.ContentSigner;
import org.bouncycastle.operator.DefaultDigestAlgorithmIdentifierFinder;
import org.bouncycastle.operator.DefaultSignatureAlgorithmIdentifierFinder;
import org.bouncycastle.operator.bc.BcRSAContentSignerBuilder;
/**
* Utility class that provides methods used by {@link SslClientInitializerTest} and {@link
* SslServerInitializerTest}.
*/
@SuppressWarnings("deprecation")
public final class SslInitializerTestUtils {
static {
@ -59,16 +67,26 @@ public final class SslInitializerTestUtils {
*/
public static X509Certificate signKeyPair(
SelfSignedCertificate ssc, KeyPair keyPair, String hostname) throws Exception {
X509V3CertificateGenerator certGen = new X509V3CertificateGenerator();
X500Principal dnName = new X500Principal("CN=" + hostname);
certGen.setSerialNumber(BigInteger.valueOf(System.currentTimeMillis()));
certGen.setSubjectDN(dnName);
certGen.setIssuerDN(ssc.cert().getSubjectX500Principal());
certGen.setNotBefore(Date.from(Instant.now().minus(Duration.ofDays(1))));
certGen.setNotAfter(Date.from(Instant.now().plus(Duration.ofDays(1))));
certGen.setPublicKey(keyPair.getPublic());
certGen.setSignatureAlgorithm("SHA256WithRSAEncryption");
return certGen.generate(ssc.key(), "BC");
X500Name subjectDnName = new X500Name("CN=" + hostname);
BigInteger serialNumber = BigInteger.valueOf(System.currentTimeMillis());
X500Name issuerDnName = new X500Name(ssc.cert().getIssuerDN().getName());
Date from = Date.from(Instant.now().minus(Duration.ofDays(1)));
Date to = Date.from(Instant.now().plus(Duration.ofDays(1)));
SubjectPublicKeyInfo subPubKeyInfo =
SubjectPublicKeyInfo.getInstance(keyPair.getPublic().getEncoded());
AlgorithmIdentifier sigAlgId =
new DefaultSignatureAlgorithmIdentifierFinder().find("SHA256WithRSAEncryption");
AlgorithmIdentifier digAlgId = new DefaultDigestAlgorithmIdentifierFinder().find(sigAlgId);
ContentSigner sigGen =
new BcRSAContentSignerBuilder(sigAlgId, digAlgId)
.build(PrivateKeyFactory.createKey(ssc.key().getEncoded()));
X509v3CertificateBuilder v3CertGen =
new X509v3CertificateBuilder(
issuerDnName, serialNumber, from, to, subjectDnName, subPubKeyInfo);
X509CertificateHolder certificateHolder = v3CertGen.build(sigGen);
return new JcaX509CertificateConverter().setProvider("BC").getCertificate(certificateHolder);
}
/**

View file

@ -20,6 +20,7 @@ import static google.registry.networking.handler.SslInitializerTestUtils.setUpSs
import static google.registry.networking.handler.SslInitializerTestUtils.signKeyPair;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
@ -86,7 +87,7 @@ public class SslServerInitializerTest {
requireClientCert,
sslProvider,
Suppliers.ofInstance(privateKey),
Suppliers.ofInstance(certificates));
Suppliers.ofInstance(ImmutableList.copyOf(certificates)));
}
private ChannelHandler getServerHandler(PrivateKey privateKey, X509Certificate... certificates) {
@ -125,7 +126,7 @@ public class SslServerInitializerTest {
true,
sslProvider,
Suppliers.ofInstance(ssc.key()),
Suppliers.ofInstance(new X509Certificate[] {ssc.cert()}));
Suppliers.ofInstance(ImmutableList.of(ssc.cert())));
EmbeddedChannel channel = new EmbeddedChannel();
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslServerInitializer);