Refactor to be more in line with a standard Gradle project structure

This commit is contained in:
Gus Brodman 2019-05-21 14:12:47 -04:00
parent 98f87bcc03
commit 38cfc9f693
3141 changed files with 99 additions and 100 deletions

View file

@ -0,0 +1,81 @@
# Description:
# This package contains the code for the binary that proxies TCP traffic from
# the GCE/GKE to AppEngine.
load("@io_bazel_rules_docker//container:container.bzl", "container_image", "container_push")
package(
default_visibility = ["//java/google/registry:registry_project"],
)
licenses(["notice"]) # Apache 2.0
java_library(
name = "proxy",
srcs = glob(["**/*.java"]),
resources = glob([
"resources/*",
"config/*.yaml",
]),
deps = [
"//java/google/registry/util",
"@com_beust_jcommander",
"@com_google_api_client",
"@com_google_apis_google_api_services_cloudkms",
"@com_google_apis_google_api_services_monitoring",
"@com_google_apis_google_api_services_storage",
"@com_google_auto_value",
"@com_google_code_findbugs_jsr305",
"@com_google_dagger",
"@com_google_flogger",
"@com_google_flogger_system_backend",
"@com_google_gson",
"@com_google_guava",
"@com_google_monitoring_client_metrics",
"@com_google_monitoring_client_stackdriver",
"@io_netty_buffer",
"@io_netty_codec",
"@io_netty_codec_http",
"@io_netty_common",
"@io_netty_handler",
"@io_netty_transport",
"@javax_inject",
"@joda_time",
"@org_bouncycastle_bcpkix_jdk15on",
],
)
java_binary(
name = "proxy_server",
main_class = "google.registry.proxy.ProxyServer",
runtime_deps = [
":proxy",
"@io_netty_tcnative",
],
)
container_image(
name = "proxy_image",
base = "@java_base//image",
entrypoint = [
"java",
"-jar",
"proxy_server_deploy.jar",
],
files = [":proxy_server_deploy.jar"],
ports = [
"30000",
"30001",
"30002",
"30010",
"30011",
],
)
container_push(
name = "proxy_push",
format = "Docker",
image = ":proxy_image",
registry = "gcr.io",
repository = "GCP_PROJECT/IMAGE_NAME",
)

View file

@ -0,0 +1,243 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Suppliers.memoizeWithExpiration;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;
import com.google.common.collect.ImmutableList;
import dagger.Lazy;
import dagger.Module;
import dagger.Provides;
import google.registry.proxy.ProxyConfig.Environment;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.PrivateKey;
import java.security.Security;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.inject.Named;
import javax.inject.Provider;
import javax.inject.Qualifier;
import javax.inject.Singleton;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMException;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
/**
* Dagger module that provides bindings needed to inject server certificate chain and private key.
*
* <p>The production certificates and private key are stored in a .pem file that is encrypted by
* Cloud KMS. The .pem file can be generated by concatenating the .crt certificate files on the
* chain and the .key private file.
*
* <p>The production certificates in the .pem file must be stored in order, where the next
* certificate's subject is the previous certificate's issuer.
*
* <p>When running the proxy locally or in test, a self signed certificate is used.
*
* @see <a href="https://cloud.google.com/kms/">Cloud Key Management Service</a>
*/
@Module
public class CertificateModule {
/** Dagger qualifier to provide bindings related to the certificates that the server provides. */
@Qualifier
private @interface ServerCertificates {}
/** Dagger qualifier to provide bindings when running locally. */
@Qualifier
private @interface Local {}
/**
* Dagger qualifier to provide bindings when running in production.
*
* <p>The "production" here means that the proxy runs on GKE, as apposed to on a local machine. It
* does not necessary mean the production environment.
*/
@Qualifier
@interface Prod {}
static {
Security.addProvider(new BouncyCastleProvider());
}
/**
* Select specific type from a given {@link ImmutableList} and convert them using the converter.
*
* @param objects the {@link ImmutableList} to filter from.
* @param clazz the class to filter.
* @param converter the converter function to act on the items in the filtered list.
*/
private static <T, E> ImmutableList<E> filterAndConvert(
ImmutableList<Object> objects, Class<T> clazz, Function<T, E> converter) {
return objects
.stream()
.filter(clazz::isInstance)
.map(clazz::cast)
.map(converter)
.collect(toImmutableList());
}
@Singleton
@Provides
static Supplier<PrivateKey> providePrivateKeySupplier(
@ServerCertificates Provider<PrivateKey> privateKeyProvider, ProxyConfig config) {
return memoizeWithExpiration(
privateKeyProvider::get, config.serverCertificateCacheSeconds, SECONDS);
}
@Singleton
@Provides
static Supplier<X509Certificate[]> provideCertificatesSupplier(
@ServerCertificates Provider<X509Certificate[]> certificatesProvider, ProxyConfig config) {
return memoizeWithExpiration(
certificatesProvider::get, config.serverCertificateCacheSeconds, SECONDS);
}
@Provides
@ServerCertificates
static X509Certificate[] provideCertificates(
Environment env,
@Local Lazy<X509Certificate[]> localCertificates,
@Prod Lazy<X509Certificate[]> prodCertificates) {
return (env == Environment.LOCAL) ? localCertificates.get() : prodCertificates.get();
}
@Provides
@ServerCertificates
static PrivateKey providePrivateKey(
Environment env,
@Local Lazy<PrivateKey> localPrivateKey,
@Prod Lazy<PrivateKey> prodPrivateKey) {
return (env == Environment.LOCAL) ? localPrivateKey.get() : prodPrivateKey.get();
}
@Singleton
@Provides
static SelfSignedCertificate provideSelfSignedCertificate() {
try {
return new SelfSignedCertificate();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Singleton
@Provides
@Local
static PrivateKey provideLocalPrivateKey(SelfSignedCertificate ssc) {
return ssc.key();
}
@Singleton
@Provides
@Local
static X509Certificate[] provideLocalCertificates(SelfSignedCertificate ssc) {
return new X509Certificate[] {ssc.cert()};
}
@Provides
@Named("pemObjects")
static ImmutableList<Object> providePemObjects(@Named("pemBytes") byte[] pemBytes) {
PEMParser pemParser =
new PEMParser(new InputStreamReader(new ByteArrayInputStream(pemBytes), UTF_8));
ImmutableList.Builder<Object> listBuilder = new ImmutableList.Builder<>();
Object obj;
// PEMParser returns an object (private key, certificate, etc) each time readObject() is called,
// until no more object is to be read from the file.
while (true) {
try {
obj = pemParser.readObject();
if (obj == null) {
break;
} else {
listBuilder.add(obj);
}
} catch (IOException e) {
throw new RuntimeException("Cannot parse PEM file correctly.", e);
}
}
return listBuilder.build();
}
// This binding should not be used directly. Use the supplier binding instead.
@Provides
@Prod
static PrivateKey provideProdPrivateKey(@Named("pemObjects") ImmutableList<Object> pemObjects) {
JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider("BC");
Function<PEMKeyPair, PrivateKey> privateKeyConverter =
pemKeyPair -> {
try {
return converter.getKeyPair(pemKeyPair).getPrivate();
} catch (PEMException e) {
throw new RuntimeException(
String.format("Error converting private key: %s", pemKeyPair), e);
}
};
ImmutableList<PrivateKey> privateKeys =
filterAndConvert(pemObjects, PEMKeyPair.class, privateKeyConverter);
checkState(
privateKeys.size() == 1,
"The pem file must contain exactly one private key, but %s keys are found",
privateKeys.size());
return privateKeys.get(0);
}
// This binding should not be used directly. Use the supplier binding instead.
@Provides
@Prod
static X509Certificate[] provideProdCertificates(
@Named("pemObjects") ImmutableList<Object> pemObject) {
JcaX509CertificateConverter converter = new JcaX509CertificateConverter().setProvider("BC");
Function<X509CertificateHolder, X509Certificate> certificateConverter =
certificateHolder -> {
try {
return converter.getCertificate(certificateHolder);
} catch (CertificateException e) {
throw new RuntimeException(
String.format("Error converting certificate: %s", certificateHolder), e);
}
};
ImmutableList<X509Certificate> certificates =
filterAndConvert(pemObject, X509CertificateHolder.class, certificateConverter);
checkState(certificates.size() != 0, "No certificates found in the pem file");
X509Certificate lastCert = null;
for (X509Certificate cert : certificates) {
if (lastCert != null) {
checkState(
lastCert.getIssuerX500Principal().equals(cert.getSubjectX500Principal()),
"Certificate chain error:\n%s\nis not signed by\n%s",
lastCert,
cert);
}
lastCert = cert;
}
X509Certificate[] certificateArray = new X509Certificate[certificates.size()];
certificates.toArray(certificateArray);
return certificateArray;
}
}

View file

@ -0,0 +1,181 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static google.registry.util.ResourceUtils.readResourceBytes;
import com.google.common.collect.ImmutableList;
import dagger.Module;
import dagger.Provides;
import dagger.multibindings.IntoSet;
import google.registry.proxy.HttpsRelayProtocolModule.HttpsRelayProtocol;
import google.registry.proxy.Protocol.BackendProtocol;
import google.registry.proxy.Protocol.FrontendProtocol;
import google.registry.proxy.handler.EppServiceHandler;
import google.registry.proxy.handler.ProxyProtocolHandler;
import google.registry.proxy.handler.QuotaHandler.EppQuotaHandler;
import google.registry.proxy.handler.RelayHandler.FullHttpRequestRelayHandler;
import google.registry.proxy.handler.SslServerInitializer;
import google.registry.proxy.metric.FrontendMetrics;
import google.registry.proxy.quota.QuotaConfig;
import google.registry.proxy.quota.QuotaManager;
import google.registry.proxy.quota.TokenStore;
import google.registry.util.Clock;
import io.netty.channel.ChannelHandler;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.timeout.ReadTimeoutHandler;
import java.io.IOException;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Supplier;
import javax.inject.Named;
import javax.inject.Provider;
import javax.inject.Qualifier;
import javax.inject.Singleton;
/** A module that provides the {@link FrontendProtocol} used for epp protocol. */
@Module
public class EppProtocolModule {
/** Dagger qualifier to provide epp protocol related handlers and other bindings. */
@Qualifier
public @interface EppProtocol {}
private static final String PROTOCOL_NAME = "epp";
@Singleton
@Provides
@IntoSet
static FrontendProtocol provideProtocol(
ProxyConfig config,
@EppProtocol int eppPort,
@EppProtocol ImmutableList<Provider<? extends ChannelHandler>> handlerProviders,
@HttpsRelayProtocol BackendProtocol.Builder backendProtocolBuilder) {
return Protocol.frontendBuilder()
.name(PROTOCOL_NAME)
.port(eppPort)
.handlerProviders(handlerProviders)
.relayProtocol(backendProtocolBuilder.host(config.epp.relayHost).build())
.build();
}
@Provides
@EppProtocol
static ImmutableList<Provider<? extends ChannelHandler>> provideHandlerProviders(
Provider<ProxyProtocolHandler> proxyProtocolHandlerProvider,
@EppProtocol Provider<SslServerInitializer<NioSocketChannel>> sslServerInitializerProvider,
@EppProtocol Provider<ReadTimeoutHandler> readTimeoutHandlerProvider,
Provider<LengthFieldBasedFrameDecoder> lengthFieldBasedFrameDecoderProvider,
Provider<LengthFieldPrepender> lengthFieldPrependerProvider,
Provider<EppServiceHandler> eppServiceHandlerProvider,
Provider<EppQuotaHandler> eppQuotaHandlerProvider,
Provider<FullHttpRequestRelayHandler> relayHandlerProvider) {
return ImmutableList.of(
proxyProtocolHandlerProvider,
sslServerInitializerProvider,
readTimeoutHandlerProvider,
lengthFieldBasedFrameDecoderProvider,
lengthFieldPrependerProvider,
eppServiceHandlerProvider,
eppQuotaHandlerProvider,
relayHandlerProvider);
}
@Provides
static LengthFieldBasedFrameDecoder provideLengthFieldBasedFrameDecoder(ProxyConfig config) {
return new LengthFieldBasedFrameDecoder(
// Max message length.
config.epp.maxMessageLengthBytes,
// Header field location offset.
0,
// Header field length.
config.epp.headerLengthBytes,
// Adjustment applied to the header field value in order to obtain message length.
-config.epp.headerLengthBytes,
// Initial bytes to strip (i. e. strip the length header).
config.epp.headerLengthBytes);
}
@Singleton
@Provides
static LengthFieldPrepender provideLengthFieldPrepender(ProxyConfig config) {
return new LengthFieldPrepender(
// Header field length.
config.epp.headerLengthBytes,
// Length includes header field length.
true);
}
@Provides
@EppProtocol
static ReadTimeoutHandler provideReadTimeoutHandler(ProxyConfig config) {
return new ReadTimeoutHandler(config.epp.readTimeoutSeconds);
}
@Singleton
@Provides
@Named("hello")
static byte[] provideHelloBytes() {
try {
return readResourceBytes(EppProtocolModule.class, "resources/hello.xml").read();
} catch (IOException e) {
throw new RuntimeException("Cannot read EPP <hello> message file.", e);
}
}
@Provides
static EppServiceHandler provideEppServiceHandler(
@Named("accessToken") Supplier<String> accessTokenSupplier,
@Named("hello") byte[] helloBytes,
FrontendMetrics metrics,
ProxyConfig config) {
return new EppServiceHandler(
config.epp.relayHost,
config.epp.relayPath,
accessTokenSupplier,
helloBytes,
metrics);
}
@Singleton
@Provides
@EppProtocol
static SslServerInitializer<NioSocketChannel> provideSslServerInitializer(
SslProvider sslProvider,
Supplier<PrivateKey> privateKeySupplier,
Supplier<X509Certificate[]> certificatesSupplier) {
return new SslServerInitializer<>(true, sslProvider, privateKeySupplier, certificatesSupplier);
}
@Provides
@EppProtocol
static TokenStore provideTokenStore(
ProxyConfig config, ScheduledExecutorService refreshExecutor, Clock clock) {
return new TokenStore(new QuotaConfig(config.epp.quota, PROTOCOL_NAME), refreshExecutor, clock);
}
@Provides
@Singleton
@EppProtocol
static QuotaManager provideQuotaManager(
@EppProtocol TokenStore tokenStore, ExecutorService executorService) {
return new QuotaManager(tokenStore, executorService);
}
}

View file

@ -0,0 +1,127 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import com.google.common.collect.ImmutableMap;
import com.google.gson.Gson;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.logging.Formatter;
import java.util.logging.Level;
import java.util.logging.LogRecord;
/**
* JUL formatter that formats log messages in a single-line JSON that Stackdriver logging can parse.
*
* <p>There is no clear documentation on how to achieve this or on the format of the JSON. This is
* much a trial and error process, plus a lot of searching. To summarize, if the logs are printed to
* {@code STDOUT} or {@code STDERR} in a single-line JSON, with the content in the {@code message}
* field and the log level in the {@code severity} field, it will be picked up by Stackdriver
* logging agent running in GKE containers and logged at correct level..
*
* @see <a
* href="https://medium.com/retailmenot-engineering/formatting-python-logs-for-stackdriver-5a5ddd80761c">
* Formatting Python Logs from Stackdriver</a> <a
* href="https://stackoverflow.com/questions/44164730/gke-stackdriver-java-logback-logging-format">
* GKE & Stackdriver: Java logback logging format?</a>
*/
class GcpJsonFormatter extends Formatter {
/** JSON field that determines the log level. */
private static final String SEVERITY = "severity";
/**
* JSON field that stores the calling class and function when the log occurs.
*
* <p>This field is not used by Stackdriver, but it is useful and can be found when the log
* entries are expanded
*/
private static final String SOURCE = "source";
/** JSON field that contains the content, this will show up as the main entry in a log. */
private static final String MESSAGE = "message";
private static final Gson gson = new Gson();
@Override
public String format(LogRecord record) {
// Add an extra newline before the message. Stackdriver does not show newlines correctly, and
// treats them as whitespace. If you want to see correctly formatted log message, expand the
// log and look for the jsonPayload.message field. This newline makes sure that the entire
// message starts on its own line, so that indentation within the message is correct.
String message = "\n" + record.getMessage();
String severity = severityFor(record.getLevel());
// The rest is mostly lifted from java.util.logging.SimpleFormatter.
String stacktrace = "";
if (record.getThrown() != null) {
StringWriter sw = new StringWriter();
try (PrintWriter pw = new PrintWriter(sw)) {
pw.println();
record.getThrown().printStackTrace(pw);
}
stacktrace = sw.toString();
}
String source;
if (record.getSourceClassName() != null) {
source = record.getSourceClassName();
if (record.getSourceMethodName() != null) {
source += " " + record.getSourceMethodName();
}
} else {
source = record.getLoggerName();
}
return gson.toJson(
ImmutableMap.of(SEVERITY, severity, SOURCE, source, MESSAGE, message + stacktrace))
+ '\n';
}
/**
* Map {@link Level} to a severity string that Stackdriver understands.
*
* @see <a
* href="https://github.com/googleapis/google-cloud-java/blob/master/google-cloud-clients/google-cloud-logging/src/main/java/com/google/cloud/logging/LoggingHandler.java#L325">{@code LoggingHandler}</a>
*/
private static String severityFor(Level level) {
switch (level.intValue()) {
// FINEST
case 300:
return "DEBUG";
// FINER
case 400:
return "DEBUG";
// FINE
case 500:
return "DEBUG";
// CONFIG
case 700:
return "INFO";
// INFO
case 800:
return "INFO";
// WARNING
case 900:
return "WARNING";
// SEVERE
case 1000:
return "ERROR";
default:
return "DEFAULT";
}
}
}

View file

@ -0,0 +1,76 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import com.google.common.collect.ImmutableList;
import dagger.Module;
import dagger.Provides;
import dagger.multibindings.IntoSet;
import google.registry.proxy.Protocol.FrontendProtocol;
import google.registry.proxy.handler.HealthCheckHandler;
import io.netty.channel.ChannelHandler;
import io.netty.handler.codec.FixedLengthFrameDecoder;
import javax.inject.Provider;
import javax.inject.Qualifier;
import javax.inject.Singleton;
/**
* Module that provides a {@link FrontendProtocol} used for GCP load balancer health checking.
*
* <p>The load balancer sends health checking messages to the GCE instances to assess whether they
* are ready to receive traffic. No relay channel needs to be established for this protocol.
*/
@Module
public class HealthCheckProtocolModule {
/** Dagger qualifier to provide health check protocol related handlers and other bindings. */
@Qualifier
@interface HealthCheckProtocol {}
private static final String PROTOCOL_NAME = "health_check";
@Singleton
@Provides
@IntoSet
static FrontendProtocol provideProtocol(
@HealthCheckProtocol int healthCheckPort,
@HealthCheckProtocol ImmutableList<Provider<? extends ChannelHandler>> handlerProviders) {
return Protocol.frontendBuilder()
.name(PROTOCOL_NAME)
.port(healthCheckPort)
.hasBackend(false)
.handlerProviders(handlerProviders)
.build();
}
@Provides
@HealthCheckProtocol
static ImmutableList<Provider<? extends ChannelHandler>> provideHandlerProviders(
Provider<FixedLengthFrameDecoder> fixedLengthFrameDecoderProvider,
Provider<HealthCheckHandler> healthCheckHandlerProvider) {
return ImmutableList.of(fixedLengthFrameDecoderProvider, healthCheckHandlerProvider);
}
@Provides
static FixedLengthFrameDecoder provideFixedLengthFrameDecoder(ProxyConfig config) {
return new FixedLengthFrameDecoder(config.healthCheck.checkRequest.length());
}
@Provides
static HealthCheckHandler provideHealthCheckHandler(ProxyConfig config) {
return new HealthCheckHandler(
config.healthCheck.checkRequest, config.healthCheck.checkResponse);
}
}

View file

@ -0,0 +1,96 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import com.google.common.collect.ImmutableList;
import dagger.Module;
import dagger.Provides;
import google.registry.proxy.Protocol.BackendProtocol;
import google.registry.proxy.handler.BackendMetricsHandler;
import google.registry.proxy.handler.RelayHandler.FullHttpResponseRelayHandler;
import google.registry.proxy.handler.SslClientInitializer;
import io.netty.channel.ChannelHandler;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.logging.LoggingHandler;
import java.security.cert.X509Certificate;
import javax.annotation.Nullable;
import javax.inject.Provider;
import javax.inject.Qualifier;
/**
* Module that provides a {@link BackendProtocol.Builder} for HTTPS protocol.
*
* <p>Only a builder is provided because the client protocol itself depends on the remote host
* address, which is provided in the server protocol module that relays to this client protocol
* module, e. g. {@link WhoisProtocolModule}.
*/
@Module
public class HttpsRelayProtocolModule {
/** Dagger qualifier to provide https relay protocol related handlers and other bindings. */
@Qualifier
public @interface HttpsRelayProtocol {}
private static final String PROTOCOL_NAME = "https_relay";
@Provides
@HttpsRelayProtocol
static BackendProtocol.Builder provideProtocolBuilder(
ProxyConfig config,
@HttpsRelayProtocol ImmutableList<Provider<? extends ChannelHandler>> handlerProviders) {
return Protocol.backendBuilder()
.name(PROTOCOL_NAME)
.port(config.httpsRelay.port)
.handlerProviders(handlerProviders);
}
@Provides
@HttpsRelayProtocol
static ImmutableList<Provider<? extends ChannelHandler>> provideHandlerProviders(
Provider<SslClientInitializer<NioSocketChannel>> sslClientInitializerProvider,
Provider<HttpClientCodec> httpClientCodecProvider,
Provider<HttpObjectAggregator> httpObjectAggregatorProvider,
Provider<BackendMetricsHandler> backendMetricsHandlerProvider,
Provider<LoggingHandler> loggingHandlerProvider,
Provider<FullHttpResponseRelayHandler> relayHandlerProvider) {
return ImmutableList.of(
sslClientInitializerProvider,
httpClientCodecProvider,
httpObjectAggregatorProvider,
backendMetricsHandlerProvider,
loggingHandlerProvider,
relayHandlerProvider);
}
@Provides
static HttpClientCodec provideHttpClientCodec() {
return new HttpClientCodec();
}
@Provides
static HttpObjectAggregator provideHttpObjectAggregator(ProxyConfig config) {
return new HttpObjectAggregator(config.httpsRelay.maxMessageLengthBytes);
}
@Nullable
@Provides
@HttpsRelayProtocol
public static X509Certificate[] provideTrustedCertificates() {
// null uses the system default trust store.
return null;
}
}

View file

@ -0,0 +1,101 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.util.Utils;
import com.google.api.services.monitoring.v3.Monitoring;
import com.google.api.services.monitoring.v3.model.MonitoredResource;
import com.google.common.collect.ImmutableMap;
import com.google.common.flogger.FluentLogger;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.monitoring.metrics.MetricReporter;
import com.google.monitoring.metrics.MetricWriter;
import com.google.monitoring.metrics.stackdriver.StackdriverWriter;
import dagger.Component;
import dagger.Module;
import dagger.Provides;
import google.registry.proxy.ProxyConfig.Environment;
import google.registry.proxy.metric.MetricParameters;
import javax.inject.Singleton;
/** Module that provides necessary bindings to instantiate a {@link MetricReporter} */
@Module
public class MetricsModule {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
@Singleton
@Provides
static Monitoring provideMonitoring(GoogleCredential credential, ProxyConfig config) {
return new Monitoring.Builder(
Utils.getDefaultTransport(), Utils.getDefaultJsonFactory(), credential)
.setApplicationName(config.projectId)
.build();
}
@Singleton
@Provides
static MetricWriter provideMetricWriter(
Monitoring monitoringClient, MonitoredResource monitoredResource, ProxyConfig config) {
return new StackdriverWriter(
monitoringClient,
config.projectId,
monitoredResource,
config.metrics.stackdriverMaxQps,
config.metrics.stackdriverMaxPointsPerRequest);
}
@Singleton
@Provides
static MetricReporter provideMetricReporter(MetricWriter metricWriter, ProxyConfig config) {
return new MetricReporter(
metricWriter,
config.metrics.writeIntervalSeconds,
new ThreadFactoryBuilder().setDaemon(true).build());
}
/**
* Provides a {@link MonitoredResource} appropriate for environment tha proxy runs in.
*
* <p>When running locally, the type of the monitored resource is set to {@code global}, otherwise
* it is {@code gke_container}.
*
* @see <a
* href="https://cloud.google.com/monitoring/custom-metrics/creating-metrics#which-resource">
* Choosing a monitored resource type</a>
*/
@Singleton
@Provides
static MonitoredResource provideMonitoredResource(
Environment env, ProxyConfig config, MetricParameters metricParameters) {
MonitoredResource monitoredResource = new MonitoredResource();
if (env == Environment.LOCAL) {
monitoredResource
.setType("global")
.setLabels(ImmutableMap.of("project_id", config.projectId));
} else {
monitoredResource.setType("gke_container").setLabels(metricParameters.makeLabelsMap());
}
logger.atInfo().log("Monitored resource: %s", monitoredResource);
return monitoredResource;
}
@Singleton
@Component(modules = {MetricsModule.class, ProxyModule.class})
interface MetricsComponent {
MetricReporter metricReporter();
}
}

View file

@ -0,0 +1,130 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import com.google.auto.value.AutoValue;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import javax.annotation.Nullable;
import javax.inject.Provider;
/** Value class that encapsulates parameters of a specific connection. */
public interface Protocol {
/** Key used to retrieve the {@link Protocol} from a {@link Channel}'s {@link Attribute}. */
AttributeKey<Protocol> PROTOCOL_KEY = AttributeKey.valueOf("PROTOCOL_KEY");
/** Protocol name. */
String name();
/**
* Port to bind to (for {@link FrontendProtocol}) or to connect to (for {@link BackendProtocol}).
*/
int port();
/** The {@link ChannelHandler} providers to use for the protocol, in order. */
ImmutableList<Provider<? extends ChannelHandler>> handlerProviders();
/** A builder for {@link FrontendProtocol}, by default there is a backend associated with it. */
static FrontendProtocol.Builder frontendBuilder() {
return new AutoValue_Protocol_FrontendProtocol.Builder().hasBackend(true);
}
static BackendProtocol.Builder backendBuilder() {
return new AutoValue_Protocol_BackendProtocol.Builder();
}
/**
* Generic builder enabling chaining for concrete implementations.
*
* @param <B> builder of the concrete subtype of {@link Protocol}.
* @param <P> type of the concrete subtype of {@link Protocol}.
*/
abstract class Builder<B extends Builder<B, P>, P extends Protocol> {
public abstract B name(String value);
public abstract B port(int port);
public abstract B handlerProviders(ImmutableList<Provider<? extends ChannelHandler>> value);
public abstract P build();
}
/**
* Connection parameters for a connection from the client to the proxy.
*
* <p>This protocol is associated to a {@link NioSocketChannel} established by remote peer
* connecting to the given {@code port} that the proxy is listening on.
*/
@AutoValue
abstract class FrontendProtocol implements Protocol {
/**
* The {@link BackendProtocol} used to establish a relay channel and relay the traffic to. Not
* required for health check protocol or HTTP(S) redirect.
*/
@Nullable
public abstract BackendProtocol relayProtocol();
/**
* Whether this {@code FrontendProtocol} relays to a {@code BackendProtocol}. All proxied
* traffic must be represented by a protocol that has a backend.
*/
public abstract boolean hasBackend();
@AutoValue.Builder
public abstract static class Builder extends Protocol.Builder<Builder, FrontendProtocol> {
public abstract Builder relayProtocol(BackendProtocol value);
public abstract Builder hasBackend(boolean value);
abstract FrontendProtocol autoBuild();
@Override
public FrontendProtocol build() {
FrontendProtocol frontendProtocol = autoBuild();
Preconditions.checkState(
!frontendProtocol.hasBackend() || frontendProtocol.relayProtocol() != null,
"Frontend protocol %s must define a relay protocol.",
frontendProtocol.name());
return frontendProtocol;
}
}
}
/**
* Connection parameters for a connection from the proxy to the GAE app.
*
* <p>This protocol is associated to a {@link NioSocketChannel} established by the proxy
* connecting to a remote peer.
*/
@AutoValue
abstract class BackendProtocol implements Protocol {
/** The hostname that the proxy connects to. */
public abstract String host();
/** Builder of {@link BackendProtocol}. */
@AutoValue.Builder
public abstract static class Builder extends Protocol.Builder<Builder, BackendProtocol> {
public abstract Builder host(String value);
}
}
}

View file

@ -0,0 +1,138 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static google.registry.util.ResourceUtils.readResourceUtf8;
import static google.registry.util.YamlUtils.getConfigSettings;
import com.google.common.base.Ascii;
import java.util.List;
/** The POJO that YAML config files are deserialized into. */
public class ProxyConfig {
enum Environment {
PRODUCTION,
PRODUCTION_CANARY,
SANDBOX,
SANDBOX_CANARY,
CRASH,
CRASH_CANARY,
ALPHA,
LOCAL,
}
private static final String DEFAULT_CONFIG = "config/default-config.yaml";
private static final String CUSTOM_CONFIG_FORMATTER = "config/proxy-config-%s.yaml";
public String projectId;
public List<String> gcpScopes;
public int accessTokenRefreshBeforeExpirationSeconds;
public int serverCertificateCacheSeconds;
public Gcs gcs;
public Kms kms;
public Epp epp;
public Whois whois;
public HealthCheck healthCheck;
public WebWhois webWhois;
public HttpsRelay httpsRelay;
public Metrics metrics;
/** Configuration options that apply to GCS. */
public static class Gcs {
public String bucket;
public String sslPemFilename;
}
/** Configuration options that apply to Cloud KMS. */
public static class Kms {
public String location;
public String keyRing;
public String cryptoKey;
}
/** Configuration options that apply to EPP protocol. */
public static class Epp {
public int port;
public String relayHost;
public String relayPath;
public int maxMessageLengthBytes;
public int headerLengthBytes;
public int readTimeoutSeconds;
public Quota quota;
}
/** Configuration options that apply to WHOIS protocol. */
public static class Whois {
public int port;
public String relayHost;
public String relayPath;
public int maxMessageLengthBytes;
public int readTimeoutSeconds;
public Quota quota;
}
/** Configuration options that apply to GCP load balancer health check protocol. */
public static class HealthCheck {
public int port;
public String checkRequest;
public String checkResponse;
}
/** Configuration options that apply to web WHOIS redirects. */
public static class WebWhois {
public int httpPort;
public int httpsPort;
public String redirectHost;
}
/** Configuration options that apply to HTTPS relay protocol. */
public static class HttpsRelay {
public int port;
public int maxMessageLengthBytes;
}
/** Configuration options that apply to Stackdriver monitoring metrics. */
public static class Metrics {
public int stackdriverMaxQps;
public int stackdriverMaxPointsPerRequest;
public int writeIntervalSeconds;
}
/** Configuration options that apply to quota management. */
public static class Quota {
/** Quota configuration for a specific set of users. */
public static class QuotaGroup {
public List<String> userId;
public int tokenAmount;
public int refillSeconds;
}
public int refreshSeconds;
public QuotaGroup defaultQuota;
public List<QuotaGroup> customQuota;
}
static ProxyConfig getProxyConfig(Environment env) {
String defaultYaml = readResourceUtf8(ProxyConfig.class, DEFAULT_CONFIG);
String customYaml =
readResourceUtf8(
ProxyConfig.class,
String.format(
CUSTOM_CONFIG_FORMATTER, Ascii.toLowerCase(env.name()).replace("_", "-")));
return getConfigSettings(defaultYaml, customYaml, ProxyConfig.class);
}
}

View file

@ -0,0 +1,355 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.base.Preconditions.checkArgument;
import static google.registry.proxy.ProxyConfig.getProxyConfig;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.util.Utils;
import com.google.api.services.cloudkms.v1.CloudKMS;
import com.google.api.services.cloudkms.v1.model.DecryptRequest;
import com.google.api.services.storage.Storage;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.common.flogger.LoggerConfig;
import com.google.monitoring.metrics.MetricReporter;
import dagger.Component;
import dagger.Module;
import dagger.Provides;
import google.registry.proxy.EppProtocolModule.EppProtocol;
import google.registry.proxy.HealthCheckProtocolModule.HealthCheckProtocol;
import google.registry.proxy.Protocol.FrontendProtocol;
import google.registry.proxy.ProxyConfig.Environment;
import google.registry.proxy.WebWhoisProtocolsModule.HttpWhoisProtocol;
import google.registry.proxy.WebWhoisProtocolsModule.HttpsWhoisProtocol;
import google.registry.proxy.WhoisProtocolModule.WhoisProtocol;
import google.registry.proxy.handler.ProxyProtocolHandler;
import google.registry.util.Clock;
import google.registry.util.SystemClock;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SslProvider;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Base64;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Supplier;
import java.util.logging.ConsoleHandler;
import java.util.logging.Handler;
import java.util.logging.Level;
import javax.inject.Named;
import javax.inject.Singleton;
/**
* A module that provides the port-to-protocol map and other configs that are used to bootstrap the
* server.
*/
@Module
public class ProxyModule {
@Parameter(names = "--whois", description = "Port for WHOIS")
private Integer whoisPort;
@Parameter(names = "--epp", description = "Port for EPP")
private Integer eppPort;
@Parameter(names = "--health_check", description = "Port for health check")
private Integer healthCheckPort;
@Parameter(names = "--http_whois", description = "Port for HTTP WHOIS")
private Integer httpWhoisPort;
@Parameter(names = "--https_whois", description = "Port for HTTPS WHOIS")
private Integer httpsWhoisPort;
@Parameter(names = "--env", description = "Environment to run the proxy in")
private Environment env = Environment.LOCAL;
@Parameter(
names = "--log",
description =
"Whether to log activities for debugging. "
+ "This cannot be enabled for production as logs contain PII.")
boolean log;
/**
* Configure logging parameters depending on the {@link Environment}.
*
* <p>If not running locally, set the logging formatter to {@link GcpJsonFormatter} that formats
* the log in a single-line json string printed to {@code STDOUT} or {@code STDERR}, will be
* correctly parsed by Stackdriver logging.
*
* @see <a href="https://cloud.google.com/kubernetes-engine/docs/how-to/logging#best_practices">
* Logging Best Practices</a>
*/
private void configureLogging() {
// Remove all other handlers on the root logger to avoid double logging.
LoggerConfig rootLoggerConfig = LoggerConfig.getConfig("");
Arrays.asList(rootLoggerConfig.getHandlers()).forEach(rootLoggerConfig::removeHandler);
// If running on in a non-local environment, use GCP JSON formatter.
Handler rootHandler = new ConsoleHandler();
rootHandler.setLevel(Level.FINE);
if (env != Environment.LOCAL) {
rootHandler.setFormatter(new GcpJsonFormatter());
}
rootLoggerConfig.addHandler(rootHandler);
if (log) {
// The LoggingHandler records logs at LogLevel.DEBUG (internal Netty log level), which
// corresponds to Level.FINE (JUL log level). It uses a JUL logger with the name
// "io.netty.handler.logging.LoggingHandler" to actually process the logs. This JUL logger is
// set to Level.FINE if the --log parameter is passed, so that it does not filter out logs
// that the LoggingHandler writes. Otherwise the logs are silently ignored because the default
// JUL logger level is Level.INFO.
LoggerConfig.getConfig(LoggingHandler.class).setLevel(Level.FINE);
// Log source IP information if --log parameter is passed. This is considered PII and should
// only be used in non-production environment for debugging purpose.
LoggerConfig.getConfig(ProxyProtocolHandler.class).setLevel(Level.FINE);
}
}
/**
* Parses command line arguments. Show usage if wrong arguments are given.
*
* @param args list of {@code String} arguments
* @return this {@code ProxyModule} object
*/
ProxyModule parse(String[] args) {
JCommander jCommander = new JCommander(this);
jCommander.setProgramName("proxy_server");
try {
jCommander.parse(args);
} catch (ParameterException e) {
jCommander.usage();
throw e;
}
checkArgument(
!log || (env != Environment.PRODUCTION && env != Environment.PRODUCTION_CANARY),
"Logging cannot be enabled for production environment");
configureLogging();
return this;
}
@Provides
@WhoisProtocol
int provideWhoisPort(ProxyConfig config) {
return Optional.ofNullable(whoisPort).orElse(config.whois.port);
}
@Provides
@EppProtocol
int provideEppPort(ProxyConfig config) {
return Optional.ofNullable(eppPort).orElse(config.epp.port);
}
@Provides
@HealthCheckProtocol
int provideHealthCheckPort(ProxyConfig config) {
return Optional.ofNullable(healthCheckPort).orElse(config.healthCheck.port);
}
@Provides
@HttpWhoisProtocol
int provideHttpWhoisProtocol(ProxyConfig config) {
return Optional.ofNullable(httpWhoisPort).orElse(config.webWhois.httpPort);
}
@Provides
@HttpsWhoisProtocol
int provideHttpsWhoisProtocol(ProxyConfig config) {
return Optional.ofNullable(httpsWhoisPort).orElse(config.webWhois.httpsPort);
}
@Provides
ImmutableMap<Integer, FrontendProtocol> providePortToProtocolMap(
Set<FrontendProtocol> protocolSet) {
return Maps.uniqueIndex(protocolSet, Protocol::port);
}
@Provides
Environment provideEnvironment() {
return env;
}
/**
* Provides shared logging handler.
*
* <p>Note that this handler always records logs at {@code LogLevel.DEBUG}, it is up to the JUL
* logger that it contains to decide if logs at this level should actually be captured. The log
* level of the JUL logger is configured in {@link #configureLogging()}.
*/
@Singleton
@Provides
LoggingHandler provideLoggingHandler() {
return new LoggingHandler(LogLevel.DEBUG);
}
@Singleton
@Provides
static GoogleCredential provideCredential(ProxyConfig config) {
try {
GoogleCredential credential = GoogleCredential.getApplicationDefault();
if (credential.createScopedRequired()) {
credential = credential.createScoped(config.gcpScopes);
}
return credential;
} catch (IOException e) {
throw new RuntimeException("Unable to obtain OAuth2 credential.", e);
}
}
/** Access token supplier that auto refreshes 1 minute before expiry. */
@Singleton
@Provides
@Named("accessToken")
static Supplier<String> provideAccessTokenSupplier(
GoogleCredential credential, ProxyConfig config) {
return () -> {
// If we never obtained an access token, the expiration time is null.
if (credential.getExpiresInSeconds() == null
// If we have an access token, make sure to refresh it ahead of time.
|| credential.getExpiresInSeconds() < config.accessTokenRefreshBeforeExpirationSeconds) {
try {
credential.refreshToken();
} catch (IOException e) {
throw new RuntimeException("Cannot refresh access token.", e);
}
}
return credential.getAccessToken();
};
}
@Singleton
@Provides
static CloudKMS provideCloudKms(GoogleCredential credential, ProxyConfig config) {
return new CloudKMS.Builder(
Utils.getDefaultTransport(), Utils.getDefaultJsonFactory(), credential)
.setApplicationName(config.projectId)
.build();
}
@Singleton
@Provides
static Storage provideStorage(GoogleCredential credential, ProxyConfig config) {
return new Storage.Builder(
Utils.getDefaultTransport(), Utils.getDefaultJsonFactory(), credential)
.setApplicationName(config.projectId)
.build();
}
// This binding should not be used directly. Use those provided in CertificateModule instead.
@Provides
@Named("encryptedPemBytes")
static byte[] provideEncryptedPemBytes(Storage storage, ProxyConfig config) {
try {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
storage
.objects()
.get(config.gcs.bucket, config.gcs.sslPemFilename)
.executeMediaAndDownloadTo(outputStream);
return Base64.getMimeDecoder().decode(outputStream.toByteArray());
} catch (IOException e) {
throw new RuntimeException(
String.format(
"Error reading encrypted PEM file %s from GCS bucket %s",
config.gcs.sslPemFilename, config.gcs.bucket),
e);
}
}
// This binding should not be used directly. Use those provided in CertificateModule instead.
@Provides
@Named("pemBytes")
static byte[] providePemBytes(
CloudKMS cloudKms, @Named("encryptedPemBytes") byte[] encryptedPemBytes, ProxyConfig config) {
String cryptoKeyUrl =
String.format(
"projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s",
config.projectId, config.kms.location, config.kms.keyRing, config.kms.cryptoKey);
try {
DecryptRequest decryptRequest = new DecryptRequest().encodeCiphertext(encryptedPemBytes);
return cloudKms
.projects()
.locations()
.keyRings()
.cryptoKeys()
.decrypt(cryptoKeyUrl, decryptRequest)
.execute()
.decodePlaintext();
} catch (IOException e) {
throw new RuntimeException(
String.format("PEM file decryption failed using CryptoKey: %s", cryptoKeyUrl), e);
}
}
@Provides
static SslProvider provideSslProvider() {
// Prefer OpenSSL.
return OpenSsl.isAvailable() ? SslProvider.OPENSSL : SslProvider.JDK;
}
@Provides
@Singleton
static Clock provideClock() {
return new SystemClock();
}
@Provides
static ExecutorService provideExecutorService() {
return Executors.newWorkStealingPool();
}
@Provides
static ScheduledExecutorService provideScheduledExecutorService() {
return Executors.newSingleThreadScheduledExecutor();
}
@Singleton
@Provides
ProxyConfig provideProxyConfig(Environment env) {
return getProxyConfig(env);
}
/** Root level component that exposes the port-to-protocol map. */
@Singleton
@Component(
modules = {
ProxyModule.class,
CertificateModule.class,
HttpsRelayProtocolModule.class,
WhoisProtocolModule.class,
WebWhoisProtocolsModule.class,
EppProtocolModule.class,
HealthCheckProtocolModule.class,
MetricsModule.class
})
interface ProxyComponent {
ImmutableMap<Integer, FrontendProtocol> portToProtocolMap();
MetricReporter metricReporter();
}
}

View file

@ -0,0 +1,343 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.handler.RelayHandler.RELAY_BUFFER_KEY;
import static google.registry.proxy.handler.RelayHandler.RELAY_CHANNEL_KEY;
import static google.registry.proxy.handler.RelayHandler.writeToRelayChannel;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.flogger.FluentLogger;
import com.google.monitoring.metrics.MetricReporter;
import google.registry.proxy.Protocol.BackendProtocol;
import google.registry.proxy.Protocol.FrontendProtocol;
import google.registry.proxy.ProxyConfig.Environment;
import google.registry.proxy.ProxyModule.ProxyComponent;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.JdkLoggerFactory;
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.inject.Provider;
/**
* A multi-protocol proxy server that listens on port(s) specified in {@link
* ProxyModule.ProxyComponent#portToProtocolMap()} }.
*/
public class ProxyServer implements Runnable {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
/** Maximum length of the queue of incoming connections. */
private static final int MAX_SOCKET_BACKLOG = 128;
private final ImmutableMap<Integer, FrontendProtocol> portToProtocolMap;
private final HashMap<Integer, Channel> portToChannelMap = new HashMap<>();
private final EventLoopGroup eventGroup = new NioEventLoopGroup();
ProxyServer(ProxyComponent proxyComponent) {
this.portToProtocolMap = proxyComponent.portToProtocolMap();
}
/**
* A {@link ChannelInitializer} for connections from a client of a certain protocol.
*
* <p>The {@link #initChannel} method does the following:
*
* <ol>
* <li>Determine the {@link FrontendProtocol} of the inbound {@link Channel} from its parent
* {@link Channel}, i. e. the {@link Channel} that binds to local port and listens.
* <li>Add handlers for the {@link FrontendProtocol} to the inbound {@link Channel}.
* <li>Establish an outbound {@link Channel} that serves as the relay channel of the inbound
* {@link Channel}, as specified by {@link FrontendProtocol#relayProtocol}.
* <li>After the outbound {@link Channel} connects successfully, enable {@link
* ChannelOption#AUTO_READ} on the inbound {@link Channel} to start reading.
* </ol>
*/
private static class ServerChannelInitializer extends ChannelInitializer<NioSocketChannel> {
@Override
protected void initChannel(NioSocketChannel inboundChannel) throws Exception {
// Add inbound channel handlers.
FrontendProtocol inboundProtocol =
(FrontendProtocol) inboundChannel.parent().attr(PROTOCOL_KEY).get();
inboundChannel.attr(PROTOCOL_KEY).set(inboundProtocol);
inboundChannel.attr(RELAY_BUFFER_KEY).set(new ArrayDeque<>());
addHandlers(inboundChannel.pipeline(), inboundProtocol.handlerProviders());
if (!inboundProtocol.hasBackend()) {
// If the frontend has no backend to relay to (health check, web WHOIS redirect, etc), start
// reading immediately.
inboundChannel.config().setAutoRead(true);
} else {
logger.atInfo().log(
"Connection established: %s %s", inboundProtocol.name(), inboundChannel);
// Connect to the relay (outbound) channel specified by the BackendProtocol.
BackendProtocol outboundProtocol = inboundProtocol.relayProtocol();
Bootstrap bootstrap =
new Bootstrap()
// Use the same thread to connect to the relay channel, therefore avoiding
// synchronization handling due to interactions between the two channels
.group(inboundChannel.eventLoop())
.channel(NioSocketChannel.class)
.handler(
new ChannelInitializer<NioSocketChannel>() {
@Override
protected void initChannel(NioSocketChannel outboundChannel)
throws Exception {
addHandlers(
outboundChannel.pipeline(), outboundProtocol.handlerProviders());
}
})
.option(ChannelOption.SO_KEEPALIVE, true)
// Outbound channel relays to inbound channel.
.attr(RELAY_CHANNEL_KEY, inboundChannel)
.attr(PROTOCOL_KEY, outboundProtocol);
connectOutboundChannel(bootstrap, inboundProtocol, outboundProtocol, inboundChannel);
// If the inbound connection is closed, close its outbound relay connection as well. There
// is no way to recover from an inbound connection termination, as the connection can only
// be initiated by the client.
ChannelFuture unusedChannelFuture =
inboundChannel
.closeFuture()
.addListener(
(future) -> {
logger.atInfo().log(
"Connection terminated: %s %s", inboundProtocol.name(), inboundChannel);
// Check if there's a relay connection. In case that the outbound connection
// is not successful, this attribute is not set.
Channel outboundChannel = inboundChannel.attr(RELAY_CHANNEL_KEY).get();
if (outboundChannel != null) {
ChannelFuture unusedChannelFuture2 = outboundChannel.close();
}
// If the frontend channel is closed and there are messages remaining in the
// buffer, we should make sure that they are released (if the messages are
// reference counted).
inboundChannel
.attr(RELAY_BUFFER_KEY)
.get()
.forEach(
msg -> {
logger.atWarning().log(
"Unfinished relay for connection %s\nHASH: %s",
inboundChannel, msg.hashCode());
ReferenceCountUtil.release(msg);
});
});
}
}
/**
* Establishes an outbound relay channel and sets the relevant metadata on both channels.
*
* <p>This method also adds a listener that is called when the established outbound connection
* is closed. The outbound connection to GAE is *not* guaranteed to persist. In case that the
* outbound connection closes but the inbound connection is still active, the listener calls
* this function again to re-establish another outbound connection. The metadata is also reset
* so that the inbound channel knows to relay to the new outbound channel.
*/
private static void connectOutboundChannel(
Bootstrap bootstrap,
FrontendProtocol inboundProtocol,
BackendProtocol outboundProtocol,
NioSocketChannel inboundChannel) {
ChannelFuture outboundChannelFuture =
bootstrap.connect(outboundProtocol.host(), outboundProtocol.port());
outboundChannelFuture.addListener(
(ChannelFuture future) -> {
if (future.isSuccess()) {
// Outbound connection is successful, now we can set the metadata to couple these two
// connections together.
Channel outboundChannel = future.channel();
// Inbound channel relays to outbound channel.
inboundChannel.attr(RELAY_CHANNEL_KEY).set(outboundChannel);
// Outbound channel established successfully, inbound channel can start reading.
// This setter also calls channel.read() to request read operation.
inboundChannel.config().setAutoRead(true);
logger.atInfo().log(
"Relay established: %s <-> %s\nFRONTEND: %s\nBACKEND: %s",
inboundProtocol.name(), outboundProtocol.name(), inboundChannel, outboundChannel);
// Now that we have a functional relay channel to the backend, if there's any
// buffered requests, send them off to the relay channel. We need to obtain a copy
// of the messages and clear the queue first, because if the relay is not successful,
// the message will be written back to the queue, causing an infinite loop.
Queue<Object> relayBuffer = inboundChannel.attr(RELAY_BUFFER_KEY).get();
Object[] messages = relayBuffer.toArray();
relayBuffer.clear();
for (Object msg : messages) {
logger.atInfo().log(
"Relay retried: %s <-> %s\nFRONTEND: %s\nBACKEND: %s\nHASH: %s",
inboundProtocol.name(),
outboundProtocol.name(),
inboundChannel,
outboundChannel,
msg.hashCode());
writeToRelayChannel(inboundChannel, outboundChannel, msg, true);
}
// When this outbound connection is closed, try reconnecting if the inbound connection
// is still active.
ChannelFuture unusedChannelFuture =
outboundChannel
.closeFuture()
.addListener(
(ChannelFuture future2) -> {
if (inboundChannel.isActive()) {
logger.atInfo().log(
"Relay interrupted: %s <-> %s\nFRONTEND: %s\nBACKEND: %s",
inboundProtocol.name(),
outboundProtocol.name(),
inboundChannel,
outboundChannel);
connectOutboundChannel(
bootstrap, inboundProtocol, outboundProtocol, inboundChannel);
} else {
logger.atInfo().log(
"Relay terminated: %s <-> %s\nFRONTEND: %s\nBACKEND: %s",
inboundProtocol.name(),
outboundProtocol.name(),
inboundChannel,
outboundChannel);
}
});
} else {
// We cannot connect to GAE for unknown reasons, no relay can be done so drop the
// inbound connection as well.
logger.atSevere().withCause(future.cause()).log(
"Cannot connect to relay channel for %s channel: %s.",
inboundProtocol.name(), inboundChannel);
ChannelFuture unusedFuture = inboundChannel.close();
}
});
}
private static void addHandlers(
ChannelPipeline channelPipeline,
ImmutableList<Provider<? extends ChannelHandler>> handlerProviders) {
for (Provider<? extends ChannelHandler> handlerProvider : handlerProviders) {
channelPipeline.addLast(handlerProvider.get());
}
}
}
@Override
public void run() {
try {
ServerBootstrap serverBootstrap =
new ServerBootstrap()
.group(eventGroup)
.channel(NioServerSocketChannel.class)
.childHandler(new ServerChannelInitializer())
.option(ChannelOption.SO_BACKLOG, MAX_SOCKET_BACKLOG)
.childOption(ChannelOption.SO_KEEPALIVE, true)
// Do not read before relay channel is established.
.childOption(ChannelOption.AUTO_READ, false);
// Bind to each port specified in portToHandlersMap.
portToProtocolMap.forEach(
(port, protocol) -> {
try {
// Wait for binding to be established for each listening port.
ChannelFuture serverChannelFuture = serverBootstrap.bind(port).sync();
if (serverChannelFuture.isSuccess()) {
logger.atInfo().log(
"Start listening on port %s for %s protocol.", port, protocol.name());
Channel serverChannel = serverChannelFuture.channel();
serverChannel.attr(PROTOCOL_KEY).set(protocol);
portToChannelMap.put(port, serverChannel);
}
} catch (InterruptedException e) {
logger.atSevere().withCause(e).log(
"Cannot listen on port %d for %s protocol.", port, protocol.name());
}
});
// Wait for all listening ports to close.
portToChannelMap.forEach(
(port, channel) -> {
try {
// Block until all server channels are closed.
ChannelFuture unusedFuture = channel.closeFuture().sync();
logger.atInfo().log(
"Stop listening on port %d for %s protocol.",
port, channel.attr(PROTOCOL_KEY).get().name());
} catch (InterruptedException e) {
logger.atSevere().withCause(e).log(
"Listening on port %d for %s protocol interrupted.",
port, channel.attr(PROTOCOL_KEY).get().name());
}
});
} finally {
logger.atInfo().log("Shutting down server...");
Future<?> unusedFuture = eventGroup.shutdownGracefully();
}
}
public static void main(String[] args) throws Exception {
// Use JDK logger for Netty's LoggingHandler,
// which is what Flogger uses under the hood.
InternalLoggerFactory.setDefaultFactory(JdkLoggerFactory.INSTANCE);
// Configure the components, this needs to run first so that the logging format is properly
// configured for each environment.
ProxyModule proxyModule = new ProxyModule().parse(args);
ProxyComponent proxyComponent =
DaggerProxyModule_ProxyComponent.builder().proxyModule(proxyModule).build();
// Do not write metrics when running locally.
if (proxyModule.provideEnvironment() != Environment.LOCAL) {
MetricReporter metricReporter = proxyComponent.metricReporter();
try {
metricReporter.startAsync().awaitRunning(10, TimeUnit.SECONDS);
logger.atInfo().log("Started up MetricReporter");
} catch (TimeoutException timeoutException) {
logger.atSevere().withCause(timeoutException).log(
"Failed to initialize MetricReporter: %s", timeoutException);
}
Runtime.getRuntime()
.addShutdownHook(
new Thread(
() -> {
try {
metricReporter.stopAsync().awaitTerminated(10, TimeUnit.SECONDS);
logger.atInfo().log("Shut down MetricReporter");
} catch (TimeoutException timeoutException) {
logger.atWarning().withCause(timeoutException).log(
"Failed to stop MetricReporter: %s", timeoutException);
}
}));
}
// Start the proxy.
new ProxyServer(proxyComponent).run();
}
}

View file

@ -0,0 +1,139 @@
// Copyright 2018 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import com.google.common.collect.ImmutableList;
import dagger.Module;
import dagger.Provides;
import dagger.multibindings.IntoSet;
import google.registry.proxy.Protocol.FrontendProtocol;
import google.registry.proxy.handler.SslServerInitializer;
import google.registry.proxy.handler.WebWhoisRedirectHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpServerExpectContinueHandler;
import io.netty.handler.ssl.SslProvider;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.function.Supplier;
import javax.inject.Provider;
import javax.inject.Qualifier;
import javax.inject.Singleton;
/** A module that provides the {@link FrontendProtocol}s to redirect HTTP(S) web WHOIS requests. */
@Module
public class WebWhoisProtocolsModule {
/** Dagger qualifier to provide HTTP whois protocol related handlers and other bindings. */
@Qualifier
@interface HttpWhoisProtocol {}
/** Dagger qualifier to provide HTTPS whois protocol related handlers and other bindings. */
@Qualifier
@interface HttpsWhoisProtocol {}
private static final String HTTP_PROTOCOL_NAME = "whois_http";
private static final String HTTPS_PROTOCOL_NAME = "whois_https";
@Singleton
@Provides
@IntoSet
static FrontendProtocol provideHttpWhoisProtocol(
@HttpWhoisProtocol int httpWhoisPort,
@HttpWhoisProtocol ImmutableList<Provider<? extends ChannelHandler>> handlerProviders) {
return google.registry.proxy.Protocol.frontendBuilder()
.name(HTTP_PROTOCOL_NAME)
.port(httpWhoisPort)
.hasBackend(false)
.handlerProviders(handlerProviders)
.build();
}
@Singleton
@Provides
@IntoSet
static FrontendProtocol provideHttpsWhoisProtocol(
@HttpsWhoisProtocol int httpsWhoisPort,
@HttpsWhoisProtocol ImmutableList<Provider<? extends ChannelHandler>> handlerProviders) {
return google.registry.proxy.Protocol.frontendBuilder()
.name(HTTPS_PROTOCOL_NAME)
.port(httpsWhoisPort)
.hasBackend(false)
.handlerProviders(handlerProviders)
.build();
}
@Provides
@HttpWhoisProtocol
static ImmutableList<Provider<? extends ChannelHandler>> providerHttpWhoisHandlerProviders(
Provider<HttpServerCodec> httpServerCodecProvider,
Provider<HttpServerExpectContinueHandler> httpServerExpectContinueHandlerProvider,
@HttpWhoisProtocol Provider<WebWhoisRedirectHandler> webWhoisRedirectHandlerProvides) {
return ImmutableList.of(
httpServerCodecProvider,
httpServerExpectContinueHandlerProvider,
webWhoisRedirectHandlerProvides);
}
@Provides
@HttpsWhoisProtocol
static ImmutableList<Provider<? extends ChannelHandler>> providerHttpsWhoisHandlerProviders(
@HttpsWhoisProtocol
Provider<SslServerInitializer<NioSocketChannel>> sslServerInitializerProvider,
Provider<HttpServerCodec> httpServerCodecProvider,
Provider<HttpServerExpectContinueHandler> httpServerExpectContinueHandlerProvider,
@HttpsWhoisProtocol Provider<WebWhoisRedirectHandler> webWhoisRedirectHandlerProvides) {
return ImmutableList.of(
sslServerInitializerProvider,
httpServerCodecProvider,
httpServerExpectContinueHandlerProvider,
webWhoisRedirectHandlerProvides);
}
@Provides
static HttpServerCodec provideHttpServerCodec() {
return new HttpServerCodec();
}
@Provides
@HttpWhoisProtocol
static WebWhoisRedirectHandler provideHttpRedirectHandler(
google.registry.proxy.ProxyConfig config) {
return new WebWhoisRedirectHandler(false, config.webWhois.redirectHost);
}
@Provides
@HttpsWhoisProtocol
static WebWhoisRedirectHandler provideHttpsRedirectHandler(
google.registry.proxy.ProxyConfig config) {
return new WebWhoisRedirectHandler(true, config.webWhois.redirectHost);
}
@Provides
static HttpServerExpectContinueHandler provideHttpServerExpectContinueHandler() {
return new HttpServerExpectContinueHandler();
}
@Singleton
@Provides
@HttpsWhoisProtocol
static SslServerInitializer<NioSocketChannel> provideSslServerInitializer(
SslProvider sslProvider,
Supplier<PrivateKey> privateKeySupplier,
Supplier<X509Certificate[]> certificatesSupplier) {
return new SslServerInitializer<>(false, sslProvider, privateKeySupplier, certificatesSupplier);
}
}

View file

@ -0,0 +1,123 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import com.google.common.collect.ImmutableList;
import dagger.Module;
import dagger.Provides;
import dagger.multibindings.IntoSet;
import google.registry.proxy.HttpsRelayProtocolModule.HttpsRelayProtocol;
import google.registry.proxy.Protocol.BackendProtocol;
import google.registry.proxy.Protocol.FrontendProtocol;
import google.registry.proxy.handler.ProxyProtocolHandler;
import google.registry.proxy.handler.QuotaHandler.WhoisQuotaHandler;
import google.registry.proxy.handler.RelayHandler.FullHttpRequestRelayHandler;
import google.registry.proxy.handler.WhoisServiceHandler;
import google.registry.proxy.metric.FrontendMetrics;
import google.registry.proxy.quota.QuotaConfig;
import google.registry.proxy.quota.QuotaManager;
import google.registry.proxy.quota.TokenStore;
import google.registry.util.Clock;
import io.netty.channel.ChannelHandler;
import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.handler.timeout.ReadTimeoutHandler;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Supplier;
import javax.inject.Named;
import javax.inject.Provider;
import javax.inject.Qualifier;
import javax.inject.Singleton;
/** A module that provides the {@link FrontendProtocol} used for whois protocol. */
@Module
public class WhoisProtocolModule {
/** Dagger qualifier to provide whois protocol related handlers and other bindings. */
@Qualifier
public @interface WhoisProtocol {}
private static final String PROTOCOL_NAME = "whois";
@Singleton
@Provides
@IntoSet
static FrontendProtocol provideProtocol(
ProxyConfig config,
@WhoisProtocol int whoisPort,
@WhoisProtocol ImmutableList<Provider<? extends ChannelHandler>> handlerProviders,
@HttpsRelayProtocol BackendProtocol.Builder backendProtocolBuilder) {
return Protocol.frontendBuilder()
.name(PROTOCOL_NAME)
.port(whoisPort)
.handlerProviders(handlerProviders)
.relayProtocol(backendProtocolBuilder.host(config.whois.relayHost).build())
.build();
}
@Provides
@WhoisProtocol
static ImmutableList<Provider<? extends ChannelHandler>> provideHandlerProviders(
Provider<ProxyProtocolHandler> proxyProtocolHandlerProvider,
@WhoisProtocol Provider<ReadTimeoutHandler> readTimeoutHandlerProvider,
Provider<LineBasedFrameDecoder> lineBasedFrameDecoderProvider,
Provider<WhoisServiceHandler> whoisServiceHandlerProvider,
Provider<WhoisQuotaHandler> whoisQuotaHandlerProvider,
Provider<FullHttpRequestRelayHandler> relayHandlerProvider) {
return ImmutableList.of(
proxyProtocolHandlerProvider,
readTimeoutHandlerProvider,
lineBasedFrameDecoderProvider,
whoisServiceHandlerProvider,
whoisQuotaHandlerProvider,
relayHandlerProvider);
}
@Provides
static WhoisServiceHandler provideWhoisServiceHandler(
ProxyConfig config,
@Named("accessToken") Supplier<String> accessTokenSupplier,
FrontendMetrics metrics) {
return new WhoisServiceHandler(
config.whois.relayHost, config.whois.relayPath, accessTokenSupplier, metrics);
}
@Provides
static LineBasedFrameDecoder provideLineBasedFrameDecoder(ProxyConfig config) {
return new LineBasedFrameDecoder(config.whois.maxMessageLengthBytes);
}
@Provides
@WhoisProtocol
static ReadTimeoutHandler provideReadTimeoutHandler(ProxyConfig config) {
return new ReadTimeoutHandler(config.whois.readTimeoutSeconds);
}
@Provides
@WhoisProtocol
static TokenStore provideTokenStore(
ProxyConfig config, ScheduledExecutorService refreshExecutor, Clock clock) {
return new TokenStore(
new QuotaConfig(config.whois.quota, PROTOCOL_NAME), refreshExecutor, clock);
}
@Provides
@Singleton
@WhoisProtocol
static QuotaManager provideQuotaManager(
@WhoisProtocol TokenStore tokenStore, ExecutorService executorService) {
return new QuotaManager(tokenStore, executorService);
}
}

View file

@ -0,0 +1,215 @@
# This is the default configuration file for the proxy. Do not make changes to
# it unless you are writing new features that requires you to. To customize an
# individual deployment or environment, create a proxy-config.yaml file in the
# same directory overriding only the values you wish to change. You may need
# to override some of these values to configure and enable some services used in
# production environments.
# GCP project ID
projectId: your-gcp-project-id
# OAuth scope that the GoogleCredential will be constructed with. This list
# should include all service scopes that the proxy depends on.
gcpScopes:
# The default OAuth scope granted to GCE instances. Local development instance
# needs this scope to mimic running on GCE. Currently it is used to access
# Cloud KMS and Stackdriver Monitoring APIs.
- https://www.googleapis.com/auth/cloud-platform
# The OAuth scope required to be included in the access token for the GAE app
# to authenticate.
- https://www.googleapis.com/auth/userinfo.email
# Refresh the access token 5 minutes before it expires.
#
# Depending on how the credential is obtained, its renewal behavior is
# different. A credential backed by a private key (like the ADC obtained
# locally) will get a different token when #refreshToken() is called. On GCE,
# the credential is just a wrapper around tokens sent from the metadata server,
# which is valid from 3599 seconds to 1699 seconds (this is no documentation on
# this, I got this number by logging in a GCE VM, calling curl on the metatdata
# server every minute, and check the expiration time of the response). Calling
# refreshToken() does *not* get a new token. The token is only refreshed by
# metadata server itself (every 3599 - 1699 = 1900 seconds).
#
# We refresh the token 5 minutes before it expires, which should work in both
# cases. This is better than caching the token for a pre-defined period, because
# even right after #refreshToken() is called on the client side, tokens obtained
# from GCE metadata server may not be valid for the entirety of 3599 seconds.
accessTokenRefreshBeforeExpirationSeconds: 300
# Server certificate is cached for 30 minutes.
#
# Encrypted server server certificate and private keys are stored on GCS. They
# are cached and shared for all connections for 30 minutes. We not not cache
# the certificate indefinitely because if we upload a new one to GCS, all
# existing instances need to be killed if they cache the old one indefinitely.
serverCertificateCacheSeconds: 1800
gcs:
# GCS bucket that stores the encrypted PEM file.
bucket: your-gcs-bucket
# Name of the encrypted PEM file.
sslPemFilename: your-pem-filename
# Strings used to construct the KMS crypto key URL.
# See: https://cloud.google.com/kms/docs/reference/rest/v1/projects.locations.keyRings.cryptoKeys
kms:
# Location where your key ring is stored (global, us-east1, etc).
location: your-kms-location
# Name of the KeyRing that contains the CryptoKey file.
keyRing: your-kms-keyRing
# Name of the CryptoKey used to encrypt the PEM file.
cryptoKey: your-kms-cryptoKey
epp:
port: 30002
relayHost: registry-project-id.appspot.com
relayPath: /_dr/epp
# Maximum input message length in bytes.
#
# The first 4 bytes in a message is the total length of message, in bytes.
#
# We accept a message up to 1 GB, which should be plentiful, if not over the
# top. In fact we should probably limit this to a more reasonable number, as a
# 1 GB message will likely cause the proxy to go out of memory.
#
# See also: RFC 5734 4 Data Unit Format
# (https://tools.ietf.org/html/rfc5734#section-4).
maxMessageLengthBytes: 1073741824
# Length of the header field in bytes.
#
# Note that value of the header field is the total length (in bytes) of the
# message, including the header itself, the length of the epp xml instance is
# therefore 4 bytes shorter than this value.
headerLengthBytes: 4
# Time after which an idle connection will be closed.
#
# The RFC gives registry discretionary power to set a timeout period. 1 hr
# should be reasonable enough for any registrar to login and submit their
# request.
readTimeoutSeconds: 3600
# Quota configuration for EPP
quota:
# Token database refresh period. Set to 0 to disable refresh.
#
# After the set time period, inactive userIds will be deleted.
refreshSeconds: 0
# Default quota for any userId not matched in customQuota.
defaultQuota:
# List of identifiers, e. g. IP address, certificate hash.
#
# userId for defaultQuota should always be an empty list. Any value
# in the list will be discarded.
#
# There should be no duplicate userIds, either within this list, or
# across quota groups within customQuota. Any duplication will result
# in an error when constructing QuotaConfig.
userId: []
# Number of tokens allotted to the matched user. Set to -1 to allow
# infinite quota.
tokenAmount: 100
# Token refill period. Set to 0 to disable refill.
#
# After the set time period, the token for the user will be
# reset to tokenAmount.
refillSeconds: 0
# List of custom quotas for specific userId. Use the same schema as
# defaultQuota for list entries.
customQuota: []
whois:
port: 30001
relayHost: registry-project-id.appspot.com
relayPath: /_dr/whois
# Maximum input message length in bytes.
#
# Domain name cannot be longer than 256 characters. 512-character message
# length should be safe for most cases, including registrar queries.
#
# See also: RFC 1035 2.3.4 Size limits
# (http://www.freesoft.org/CIE/RFC/1035/9.htm).
maxMessageLengthBytes: 512
# Whois protocol is transient, the client should not establish a long lasting
# idle connection.
readTimeoutSeconds: 60
# Quota configuration for WHOIS
quota:
# Token database refresh period. Set to 0 to disable refresh.
#
# After the set time period, inactive token buckets will be deleted.
refreshSeconds: 3600
# Default quota for any userId not matched in customQuota.
defaultQuota:
# List of identifiers, e. g. IP address, certificate hash.
#
# userId for defaultQuota should always be an empty list.
userId: []
# Number of tokens allotted to the matched user. Set to -1 to allow
# infinite quota.
tokenAmount: 100
# Token refill period. Set to 0 to disable refill.
#
# After the set time period, the token for the given user will be
# reset to tokenAmount.
refillSeconds: 600
# List of custom quotas for specific userId. Use the same schema as
# defaultQuota for list entries.
customQuota: []
healthCheck:
port: 30000
# Health checker request message, defined in GCP load balancer backend.
checkRequest: HEALTH_CHECK_REQUEST
# Health checker response message, defined in GCP load balancer backend.
checkResponse: HEALTH_CHECK_RESPONSE
httpsRelay:
port: 443
# Maximum size of an HTTP message in bytes.
maxMessageLengthBytes: 524288
webWhois:
httpPort: 30010
httpsPort: 30011
# The 302 redirect destination of HTTPS web WHOIS GET requests.
# HTTP web WHOIS GET requests will be 301 redirected to HTTPS first.
redirectHost: whois.yourdomain.tld
metrics:
# Max queries per second for the Google Cloud Monitoring V3 (aka Stackdriver)
# API. The limit can be adjusted by contacting Cloud Support.
stackdriverMaxQps: 30
# Max number of points that can be sent to Stackdriver in a single
# TimeSeries.Create API call.
stackdriverMaxPointsPerRequest: 200
# How often metrics are written.
writeIntervalSeconds: 60

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1 @@
# Add environment-specific proxy configuration here.

View file

@ -0,0 +1,134 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.handler.EppServiceHandler.CLIENT_CERTIFICATE_HASH_KEY;
import static google.registry.proxy.handler.RelayHandler.RELAY_CHANNEL_KEY;
import google.registry.proxy.handler.RelayHandler.FullHttpResponseRelayHandler;
import google.registry.proxy.metric.BackendMetrics;
import google.registry.util.Clock;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import java.util.ArrayDeque;
import java.util.Optional;
import java.util.Queue;
import javax.inject.Inject;
import org.joda.time.DateTime;
/**
* Handler that records metrics a backend channel.
*
* <p>This handler is added right before {@link FullHttpResponseRelayHandler} in the backend
* protocol handler provider method. {@link FullHttpRequest} outbound messages encounter this first
* before being handed over to HTTP related handler. {@link FullHttpResponse} inbound messages are
* first constructed (from plain bytes) by preceding handlers and then logged in this handler.
*/
public class BackendMetricsHandler extends ChannelDuplexHandler {
private final Clock clock;
private final BackendMetrics metrics;
private String relayedProtocolName;
private String clientCertHash;
private Channel relayedChannel;
/**
* A queue that saves the time at which a request is sent to the GAE app.
*
* <p>This queue is used to calculate HTTP request-response latency. HTTP 1.1 specification allows
* for pipelining, in which a client can sent multiple requests without waiting for each
* responses. Therefore a queue is needed to record all the requests that are sent but have not
* yet received a response.
*
* <p>A server must send its response in the same order it receives requests. This invariance
* guarantees that the request time at the head of the queue always corresponds to the response
* received in {@link #channelRead}.
*
* @see <a href="https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html">RFC 2616 8.1.2.2
* Pipelining</a>
*/
private final Queue<DateTime> requestSentTimeQueue = new ArrayDeque<>();
@Inject
BackendMetricsHandler(Clock clock, BackendMetrics metrics) {
this.clock = clock;
this.metrics = metrics;
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
// Backend channel is always established after a frontend channel is connected, so this call
// should always return a non-null relay channel.
relayedChannel = ctx.channel().attr(RELAY_CHANNEL_KEY).get();
checkNotNull(relayedChannel, "No frontend channel found.");
relayedProtocolName = relayedChannel.attr(PROTOCOL_KEY).get().name();
super.channelRegistered(ctx);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
checkArgument(msg instanceof FullHttpResponse, "Incoming response must be FullHttpResponse.");
checkState(!requestSentTimeQueue.isEmpty(), "Response received before request is sent.");
metrics.responseReceived(
relayedProtocolName,
clientCertHash,
(FullHttpResponse) msg,
clock.nowUtc().getMillis() - requestSentTimeQueue.remove().getMillis());
super.channelRead(ctx, msg);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
checkArgument(msg instanceof FullHttpRequest, "Outgoing request must be FullHttpRequest.");
// For WHOIS, client certificate hash is always set to "none".
// For EPP, the client hash attribute is set upon handshake completion, before the first HELLO
// is sent to the server. Therefore the first call to write() with HELLO payload has access to
// the hash in its channel attribute.
if (clientCertHash == null) {
clientCertHash =
Optional.ofNullable(relayedChannel.attr(CLIENT_CERTIFICATE_HASH_KEY).get())
.orElse("none");
}
FullHttpRequest request = (FullHttpRequest) msg;
// Record request size now because the content would have read by the time the listener is
// called and the readable bytes would be zero by then.
int bytes = request.content().readableBytes();
// Record sent time before write finishes allows us to take network latency into account.
DateTime sentTime = clock.nowUtc();
ChannelFuture unusedFuture =
ctx.write(msg, promise)
.addListener(
future -> {
if (future.isSuccess()) {
// Only instrument request metrics when the request is actually sent to GAE.
metrics.requestSent(relayedProtocolName, clientCertHash, bytes);
requestSentTimeQueue.add(sentTime);
}
});
}
}

View file

@ -0,0 +1,148 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY;
import static google.registry.proxy.handler.SslServerInitializer.CLIENT_CERTIFICATE_PROMISE_KEY;
import static google.registry.util.X509Utils.getCertificateHash;
import com.google.common.flogger.FluentLogger;
import google.registry.proxy.metric.FrontendMetrics;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Promise;
import java.security.cert.X509Certificate;
import java.util.function.Supplier;
/** Handler that processes EPP protocol logic. */
public class EppServiceHandler extends HttpsRelayServiceHandler {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
/**
* Attribute key to the client certificate hash whose value is set when the certificate promise is
* fulfilled.
*/
public static final AttributeKey<String> CLIENT_CERTIFICATE_HASH_KEY =
AttributeKey.valueOf("CLIENT_CERTIFICATE_HASH_KEY");
/** Name of the HTTP header that stores the client certificate hash. */
public static final String SSL_CLIENT_CERTIFICATE_HASH_FIELD = "X-SSL-Certificate";
/** Name of the HTTP header that stores the client IP address. */
public static final String FORWARDED_FOR_FIELD = "X-Forwarded-For";
/** Name of the HTTP header that indicates if the EPP session should be closed. */
public static final String EPP_SESSION_FIELD = "Epp-Session";
public static final String EPP_CONTENT_TYPE = "application/epp+xml";
private final byte[] helloBytes;
private String sslClientCertificateHash;
private String clientAddress;
public EppServiceHandler(
String relayHost,
String relayPath,
Supplier<String> accessTokenSupplier,
byte[] helloBytes,
FrontendMetrics metrics) {
super(relayHost, relayPath, accessTokenSupplier, metrics);
this.helloBytes = helloBytes;
}
/**
* Write <hello> to the server after SSL handshake completion to request <greeting>
*
* <p>When handling EPP over TCP, the server should issue a <greeting> to the client when a
* connection is established. Nomulus app however does not automatically sends the <greeting> upon
* connection. The proxy therefore first sends a <hello> to registry to request a <greeting>
* response.
*
* <p>The <hello> request is only sent after SSL handshake is completed between the client and the
* proxy so that the client certificate hash is available, which is needed to communicate with the
* server. Because {@link SslHandshakeCompletionEvent} is triggered before any calls to {@link
* #channelRead} are scheduled by the event loop executor, the <hello> request is guaranteed to be
* the first message sent to the server.
*
* @see <a href="https://tools.ietf.org/html/rfc5734">RFC 5732 EPP Transport over TCP</a>
* @see <a href="https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt">The Proxy
* Protocol</a>
*/
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
Promise<X509Certificate> unusedPromise =
ctx.channel()
.attr(CLIENT_CERTIFICATE_PROMISE_KEY)
.get()
.addListener(
(Promise<X509Certificate> promise) -> {
if (promise.isSuccess()) {
sslClientCertificateHash = getCertificateHash(promise.get());
// Set the client cert hash key attribute for both this channel,
// used for collecting metrics on specific clients.
ctx.channel().attr(CLIENT_CERTIFICATE_HASH_KEY).set(sslClientCertificateHash);
clientAddress = ctx.channel().attr(REMOTE_ADDRESS_KEY).get();
metrics.registerActiveConnection(
"epp", sslClientCertificateHash, ctx.channel());
channelRead(ctx, Unpooled.wrappedBuffer(helloBytes));
} else {
logger.atWarning().withCause(promise.cause()).log(
"Cannot finish handshake for channel %s, remote IP %s",
ctx.channel(), ctx.channel().attr(REMOTE_ADDRESS_KEY).get());
ChannelFuture unusedFuture = ctx.close();
}
});
super.channelActive(ctx);
}
@Override
protected FullHttpRequest decodeFullHttpRequest(ByteBuf byteBuf) {
checkNotNull(clientAddress, "Cannot obtain client address.");
checkNotNull(sslClientCertificateHash, "Cannot obtain client certificate hash.");
FullHttpRequest request = super.decodeFullHttpRequest(byteBuf);
request
.headers()
.set(SSL_CLIENT_CERTIFICATE_HASH_FIELD, sslClientCertificateHash)
.set(FORWARDED_FOR_FIELD, clientAddress)
.set(HttpHeaderNames.CONTENT_TYPE, EPP_CONTENT_TYPE)
.set(HttpHeaderNames.ACCEPT, EPP_CONTENT_TYPE);
return request;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
checkArgument(msg instanceof HttpResponse);
HttpResponse response = (HttpResponse) msg;
String sessionAliveValue = response.headers().get(EPP_SESSION_FIELD);
if (sessionAliveValue != null && sessionAliveValue.equals("close")) {
promise.addListener(ChannelFutureListener.CLOSE);
}
super.write(ctx, msg, promise);
}
}

View file

@ -0,0 +1,43 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.nio.charset.StandardCharsets;
/** A handler that responds to GCP load balancer health check message */
public class HealthCheckHandler extends ChannelInboundHandlerAdapter {
private final ByteBuf checkRequest;
private final ByteBuf checkResponse;
public HealthCheckHandler(String checkRequest, String checkResponse) {
this.checkRequest = Unpooled.wrappedBuffer(checkRequest.getBytes(StandardCharsets.US_ASCII));
this.checkResponse = Unpooled.wrappedBuffer(checkResponse.getBytes(StandardCharsets.US_ASCII));
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ByteBuf buf = (ByteBuf) msg;
if (buf.equals(checkRequest)) {
ChannelFuture unusedFuture = ctx.writeAndFlush(checkResponse);
}
buf.release();
}
}

View file

@ -0,0 +1,213 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.FluentLogger;
import google.registry.proxy.metric.FrontendMetrics;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.cookie.ClientCookieDecoder;
import io.netty.handler.codec.http.cookie.ClientCookieEncoder;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.timeout.ReadTimeoutException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import javax.net.ssl.SSLHandshakeException;
/**
* Handler that relays a single (framed) ByteBuf message to an HTTPS server.
*
* <p>This handler reads in a {@link ByteBuf}, converts it to an {@link FullHttpRequest}, and passes
* it to the {@code channelRead} method of the next inbound handler the channel pipeline, which is
* usually a {@link RelayHandler<FullHttpRequest>}. The relay handler writes the request to the
* relay channel, which is connected to an HTTPS endpoint. After the relay channel receives a {@link
* FullHttpResponse} back, its own relay handler writes the response back to this channel, which is
* the relay channel of the relay channel. This handler then handles write request by encoding the
* {@link FullHttpResponse} to a plain {@link ByteBuf}, and pass it down to the {@code write} method
* of the next outbound handler in the channel pipeline, which eventually writes the response bytes
* to the remote peer of this channel.
*
* <p>This handler is session aware and will store all the session cookies that the are contained in
* the HTTP response headers, which are added back to headers of subsequent HTTP requests.
*/
public abstract class HttpsRelayServiceHandler extends ByteToMessageCodec<FullHttpResponse> {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
protected static final ImmutableSet<Class<? extends Exception>> NON_FATAL_INBOUND_EXCEPTIONS =
ImmutableSet.of(ReadTimeoutException.class, SSLHandshakeException.class);
protected static final ImmutableSet<Class<? extends Exception>> NON_FATAL_OUTBOUND_EXCEPTIONS =
ImmutableSet.of(NonOkHttpResponseException.class);
private final Map<String, Cookie> cookieStore = new LinkedHashMap<>();
private final String relayHost;
private final String relayPath;
private final Supplier<String> accessTokenSupplier;
protected final FrontendMetrics metrics;
HttpsRelayServiceHandler(
String relayHost,
String relayPath,
Supplier<String> accessTokenSupplier,
FrontendMetrics metrics) {
this.relayHost = relayHost;
this.relayPath = relayPath;
this.accessTokenSupplier = accessTokenSupplier;
this.metrics = metrics;
}
/**
* Construct the {@link FullHttpRequest}.
*
* <p>This default method creates a bare-bone {@link FullHttpRequest} that may need to be
* modified, e. g. adding headers specific for each protocol.
*
* @param byteBuf inbound message.
*/
protected FullHttpRequest decodeFullHttpRequest(ByteBuf byteBuf) {
FullHttpRequest request =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, relayPath);
request
.headers()
.set(HttpHeaderNames.USER_AGENT, "Proxy")
.set(HttpHeaderNames.HOST, relayHost)
.set(HttpHeaderNames.AUTHORIZATION, "Bearer " + accessTokenSupplier.get())
.setInt(HttpHeaderNames.CONTENT_LENGTH, byteBuf.readableBytes());
request.content().writeBytes(byteBuf);
return request;
}
/**
* Load session cookies in the cookie store and write them in to the HTTP request.
*
* <p>Multiple cookies are folded into one {@code Cookie} header per RFC 6265.
*
* @see <a href="https://tools.ietf.org/html/rfc6265#section-5.4">RFC 6265 5.4.The Cookie
* Header</a>
*/
private void loadCookies(FullHttpRequest request) {
if (!cookieStore.isEmpty()) {
request
.headers()
.set(HttpHeaderNames.COOKIE, ClientCookieEncoder.STRICT.encode(cookieStore.values()));
}
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out)
throws Exception {
FullHttpRequest request = decodeFullHttpRequest(byteBuf);
loadCookies(request);
out.add(request);
}
/**
* Construct the {@link ByteBuf}
*
* <p>This default method puts all the response payload into the {@link ByteBuf}.
*
* @param fullHttpResponse outbound http response.
*/
ByteBuf encodeFullHttpResponse(FullHttpResponse fullHttpResponse) {
return fullHttpResponse.content();
}
/**
* Save session cookies from the HTTP response header to the cookie store.
*
* <p>Multiple cookies are </b>not</b> folded in to one {@code Set-Cookie} header per RFC 6265.
*
* @see <a href="https://tools.ietf.org/html/rfc6265#section-3">RFC 6265 3.Overview</a>
*/
private void saveCookies(FullHttpResponse response) {
for (String cookieString : response.headers().getAll(HttpHeaderNames.SET_COOKIE)) {
Cookie cookie = ClientCookieDecoder.STRICT.decode(cookieString);
cookieStore.put(cookie.name(), cookie);
}
}
@Override
protected void encode(ChannelHandlerContext ctx, FullHttpResponse response, ByteBuf byteBuf)
throws Exception {
if (!response.status().equals(HttpResponseStatus.OK)) {
throw new NonOkHttpResponseException(response, ctx.channel());
}
saveCookies(response);
byteBuf.writeBytes(encodeFullHttpResponse(response));
}
/** Terminates connection upon inbound exception. */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
if (NON_FATAL_INBOUND_EXCEPTIONS.contains(Throwables.getRootCause(cause).getClass())) {
logger.atWarning().withCause(cause).log(
"Inbound exception caught for channel %s", ctx.channel());
} else {
logger.atSevere().withCause(cause).log(
"Inbound exception caught for channel %s", ctx.channel());
}
ChannelFuture unusedFuture = ctx.close();
}
/** Terminates connection upon outbound exception. */
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
promise.addListener(
(ChannelFuture channelFuture) -> {
if (!channelFuture.isSuccess()) {
Throwable cause = channelFuture.cause();
if (NON_FATAL_OUTBOUND_EXCEPTIONS.contains(Throwables.getRootCause(cause).getClass())) {
logger.atWarning().withCause(channelFuture.cause()).log(
"Outbound exception caught for channel %s", channelFuture.channel());
} else {
logger.atSevere().withCause(channelFuture.cause()).log(
"Outbound exception caught for channel %s", channelFuture.channel());
}
ChannelFuture unusedFuture = channelFuture.channel().close();
}
});
super.write(ctx, msg, promise);
}
/** Exception thrown when the response status from GAE is not 200. */
public static class NonOkHttpResponseException extends Exception {
NonOkHttpResponseException(FullHttpResponse response, Channel channel) {
super(
String.format(
"Cannot relay HTTP response status \"%s\" in channel %s:\n%s",
response.status(), channel, response.content().toString(UTF_8)));
}
}
}

View file

@ -0,0 +1,199 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkState;
import static java.nio.charset.StandardCharsets.US_ASCII;
import com.google.common.flogger.FluentLogger;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AttributeKey;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import javax.inject.Inject;
/**
* Handler that processes possible existence of a PROXY protocol v1 header.
*
* <p>When an EPP client connects to the registry (through the proxy), the registry performs two
* validations to ensure that only known registrars are allowed. First it checks the sha265 hash of
* the client SSL certificate and match it to the hash stored in datastore for the registrar. It
* then checks if the connection is from an whitelisted IP address that belongs to that registrar.
*
* <p>The proxy receives client connects via the GCP load balancer, which results in the loss of
* original client IP from the channel. Luckily, the load balancer supports the PROXY protocol v1,
* which adds a header with source IP information, among other things, to the TCP request at the
* start of the connection.
*
* <p>This handler determines if a connection is proxied (PROXY protocol v1 header present) and
* correctly sets the source IP address to the channel's attribute regardless of whether it is
* proxied. After that it removes itself from the channel pipeline because the proxy header is only
* present at the beginning of the connection.
*
* <p>This handler must be the very first handler in a protocol, even before SSL handlers, because
* PROXY protocol header comes as the very first thing, even before SSL handshake request.
*
* @see <a href="https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt">The PROXY protocol</a>
*/
public class ProxyProtocolHandler extends ByteToMessageDecoder {
/** Key used to retrieve origin IP address from a channel's attribute. */
public static final AttributeKey<String> REMOTE_ADDRESS_KEY =
AttributeKey.valueOf("REMOTE_ADDRESS_KEY");
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
// The proxy header must start with this prefix.
// Sample header: "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n".
private static final byte[] HEADER_PREFIX = "PROXY".getBytes(US_ASCII);
private boolean finished = false;
private String proxyHeader = null;
@Inject
ProxyProtocolHandler() {}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
super.channelRead(ctx, msg);
if (finished) {
String remoteIP;
if (proxyHeader != null) {
logger.atFine().log("PROXIED CONNECTION: %s", ctx.channel());
logger.atFine().log("PROXY HEADER for channel %s: %s", ctx.channel(), proxyHeader);
String[] headerArray = proxyHeader.split(" ", -1);
if (headerArray.length == 6) {
remoteIP = headerArray[2];
logger.atFine().log(
"Header parsed, using %s as remote IP for channel %s", remoteIP, ctx.channel());
// If the header is "PROXY UNKNOWN"
// (see https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt), likely when the
// remote connection to the external load balancer is through special means, make it
// 0.0.0.0 so that it can be treated accordingly by the relevant quota configs.
} else if (headerArray.length == 2 && headerArray[1].equals("UNKNOWN")) {
logger.atFine().log(
"Header parsed, source IP unknown, using 0.0.0.0 as remote IP for channel %s",
ctx.channel());
remoteIP = "0.0.0.0";
} else {
logger.atFine().log(
"Cannot parse the header, using source IP as remote IP for channel %s",
ctx.channel());
remoteIP = getSourceIP(ctx);
}
} else {
logger.atFine().log(
"No header present, using source IP directly for channel %s", ctx.channel());
remoteIP = getSourceIP(ctx);
}
if (remoteIP != null) {
ctx.channel().attr(REMOTE_ADDRESS_KEY).set(remoteIP);
} else {
logger.atWarning().log("Not able to obtain remote IP for channel %s", ctx.channel());
}
// ByteToMessageDecoder automatically flushes unread bytes in the ByteBuf to the next handler
// when itself is being removed.
ctx.pipeline().remove(this);
}
}
private static String getSourceIP(ChannelHandlerContext ctx) {
SocketAddress remoteAddress = ctx.channel().remoteAddress();
return (remoteAddress instanceof InetSocketAddress)
? ((InetSocketAddress) remoteAddress).getAddress().getHostAddress()
: null;
}
/**
* Attempts to decode an internally accumulated buffer and find the proxy protocol header.
*
* <p>When the connection is not proxied (i. e. the initial bytes are not "PROXY"), simply set
* {@link #finished} to true and allow the handler to be removed. Otherwise the handler waits
* until there's enough bytes to parse the header, save the parsed header to {@link #proxyHeader},
* and then mark {@link #finished}.
*
* @param in internally accumulated buffer, newly arrived bytes are appended to it.
* @param out objects passed to the next handler, in this case nothing is ever passed because the
* header itself is processed and written to the attribute of the proxy, and the handler is
* then removed from the pipeline.
*/
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
// Wait until there are more bytes available than the header's length before processing.
if (in.readableBytes() >= HEADER_PREFIX.length) {
if (containsHeader(in)) {
// The inbound message contains the header, it must be a proxied connection. Note that
// currently proxied connection is only used for EPP protocol, which requires the connection
// to be SSL enabled. So the beginning of the inbound message upon connection can only be
// either the proxy header (when proxied), or SSL handshake request (when not proxied),
// which does not start with "PROXY". Therefore it is safe to assume that if the beginning
// of the message contains "PROXY", it must be proxied, and must contain \r\n.
int eol = findEndOfLine(in);
// If eol is not found, that is because that we do not yet have enough inbound message, do
// nothing and wait for more bytes to be readable. eol will eventually be positive because
// of the reasoning above: The connection starts with "PROXY", so it must be a proxied
// connection and contain \r\n.
if (eol >= 0) {
// ByteBuf.readBytes is called so that the header is processed and not passed to handlers
// further in the pipeline.
byte[] headerBytes = new byte[eol];
in.readBytes(headerBytes);
proxyHeader = new String(headerBytes, US_ASCII);
// Skip \r\n.
in.skipBytes(2);
// Proxy header processed, mark finished so that this handler is removed.
finished = true;
}
} else {
// The inbound message does not contain a proxy header, mark finished so that this handler
// is removed. Note that no inbound bytes are actually processed by this handler because we
// did not call ByteBuf.readBytes(), but ByteBuf.getByte(), which does not change reader
// index of the ByteBuf. So any inbound byte is then passed to the next handler to process.
finished = true;
}
}
}
/**
* Returns the index in the buffer of the end of line found. Returns -1 if no end of line was
* found in the buffer.
*/
private static int findEndOfLine(final ByteBuf buffer) {
final int n = buffer.writerIndex();
for (int i = buffer.readerIndex(); i < n; i++) {
final byte b = buffer.getByte(i);
if (b == '\r' && i < n - 1 && buffer.getByte(i + 1) == '\n') {
return i; // \r\n
}
}
return -1; // Not found.
}
/** Checks if the given buffer contains the proxy header prefix. */
private boolean containsHeader(ByteBuf buffer) {
// The readable bytes is always more or equal to the size of the header prefix because this
// method is only called when this condition is true.
checkState(buffer.readableBytes() >= HEADER_PREFIX.length);
for (int i = 0; i < HEADER_PREFIX.length; ++i) {
if (buffer.getByte(buffer.readerIndex() + i) != HEADER_PREFIX[i]) {
return false;
}
}
return true;
}
}

View file

@ -0,0 +1,167 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkNotNull;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.handler.EppServiceHandler.CLIENT_CERTIFICATE_HASH_KEY;
import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY;
import google.registry.proxy.EppProtocolModule.EppProtocol;
import google.registry.proxy.WhoisProtocolModule.WhoisProtocol;
import google.registry.proxy.metric.FrontendMetrics;
import google.registry.proxy.quota.QuotaManager;
import google.registry.proxy.quota.QuotaManager.QuotaRebate;
import google.registry.proxy.quota.QuotaManager.QuotaRequest;
import google.registry.proxy.quota.QuotaManager.QuotaResponse;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import java.util.concurrent.Future;
import javax.inject.Inject;
/**
* Handler that checks quota fulfillment and terminates connection if necessary.
*
* <p>This handler attempts to acquire quota during the first {@link #channelRead} operation, not
* when connection is established. The reason is that the {@code userId} used for acquiring quota is
* not always available when the connection is just open.
*/
public abstract class QuotaHandler extends ChannelInboundHandlerAdapter {
protected final QuotaManager quotaManager;
protected QuotaResponse quotaResponse;
protected final FrontendMetrics metrics;
protected QuotaHandler(QuotaManager quotaManager, FrontendMetrics metrics) {
this.quotaManager = quotaManager;
this.metrics = metrics;
}
abstract String getUserId(ChannelHandlerContext ctx);
/** Whether the user id is PII ans should not be logged. IP addresses are considered PII. */
abstract boolean isUserIdPii();
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (quotaResponse == null) {
String userId = getUserId(ctx);
checkNotNull(userId, "Cannot obtain User ID");
quotaResponse = quotaManager.acquireQuota(QuotaRequest.create(userId));
if (!quotaResponse.success()) {
String protocolName = ctx.channel().attr(PROTOCOL_KEY).get().name();
metrics.registerQuotaRejection(protocolName, isUserIdPii() ? "none" : userId);
throw new OverQuotaException(protocolName, isUserIdPii() ? "none" : userId);
}
}
ctx.fireChannelRead(msg);
}
/**
* Actions to take when the connection terminates.
*
* <p>Depending on the quota type, the handler either returns the tokens, or does nothing.
*/
@Override
public abstract void channelInactive(ChannelHandlerContext ctx);
static class OverQuotaException extends Exception {
OverQuotaException(String protocol, String userId) {
super(String.format("Quota exceeded for: PROTOCOL: %s, USER ID: %s", protocol, userId));
}
}
/** Quota Handler for WHOIS protocol. */
public static class WhoisQuotaHandler extends QuotaHandler {
@Inject
WhoisQuotaHandler(@WhoisProtocol QuotaManager quotaManager, FrontendMetrics metrics) {
super(quotaManager, metrics);
}
/**
* Reads user ID from channel attribute {@code REMOTE_ADDRESS_KEY}.
*
* <p>This attribute is set by {@link ProxyProtocolHandler} when the first frame of message is
* read.
*/
@Override
String getUserId(ChannelHandlerContext ctx) {
return ctx.channel().attr(REMOTE_ADDRESS_KEY).get();
}
@Override
boolean isUserIdPii() {
return true;
}
/**
* Do nothing when connection terminates.
*
* <p>WHOIS protocol is configured with a QPS type quota, there is no need to return the tokens
* back to the quota store because the quota store will auto-refill tokens based on the QPS.
*/
@Override
public void channelInactive(ChannelHandlerContext ctx) {
ctx.fireChannelInactive();
}
}
/** Quota Handler for EPP protocol. */
public static class EppQuotaHandler extends QuotaHandler {
@Inject
EppQuotaHandler(@EppProtocol QuotaManager quotaManager, FrontendMetrics metrics) {
super(quotaManager, metrics);
}
/**
* Reads user ID from channel attribute {@code CLIENT_CERTIFICATE_HASH_KEY}.
*
* <p>This attribute is set by {@link EppServiceHandler} when SSH handshake completes
* successfully. That handler subsequently simulates reading of an EPP HELLO request, in order
* to solicit an EPP GREETING response from the server. The {@link #channelRead} method of this
* handler is called afterward because it is the next handler in the channel pipeline,
* guaranteeing that the {@code CLIENT_CERTIFICATE_HASH_KEY} is always non-null.
*/
@Override
String getUserId(ChannelHandlerContext ctx) {
return ctx.channel().attr(CLIENT_CERTIFICATE_HASH_KEY).get();
}
@Override
boolean isUserIdPii() {
return false;
}
/**
* Returns the leased token (if available) back to the token store upon connection termination.
*
* <p>A connection with concurrent quota needs to do this in order to maintain its quota number
* invariance.
*/
@Override
public void channelInactive(ChannelHandlerContext ctx) {
// If no reads occurred before the connection is inactive (for example when the handshake
// is not successful), no quota is leased and therefore no return is needed.
// Note that the quota response can be a failure, in which case no token was leased to us from
// the token store. Consequently no return is necessary.
if (quotaResponse != null && quotaResponse.success()) {
Future<?> unusedFuture = quotaManager.releaseQuota(QuotaRebate.create(quotaResponse));
}
ctx.fireChannelInactive();
}
}
}

View file

@ -0,0 +1,163 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import com.google.common.flogger.FluentLogger;
import google.registry.proxy.handler.QuotaHandler.OverQuotaException;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.util.Deque;
import java.util.Queue;
import javax.inject.Inject;
/**
* Receives inbound massage of type {@code I}, and writes it to the {@code relayChannel} stored in
* the inbound channel's attribute.
*/
public class RelayHandler<I> extends SimpleChannelInboundHandler<I> {
/**
* A queue that saves messages that failed to be relayed.
*
* <p>This queue is null for channels that should not retry on failure, i. e. backend channels.
*
* <p>This queue does not need to be synchronised because it is only accessed by the I/O thread of
* the channel, or its relay channel. Since both channels use the same EventLoop, their I/O
* activities are handled by the same thread.
*/
public static final AttributeKey<Deque<Object>> RELAY_BUFFER_KEY =
AttributeKey.valueOf("RELAY_BUFFER_KEY");
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
/** Key used to retrieve the relay channel from a {@link Channel}'s {@link Attribute}. */
public static final AttributeKey<Channel> RELAY_CHANNEL_KEY =
AttributeKey.valueOf("RELAY_CHANNEL");
public RelayHandler(Class<? extends I> clazz) {
super(clazz, false);
}
/** Read message of type {@code I}, write it as-is into the relay channel. */
@Override
protected void channelRead0(ChannelHandlerContext ctx, I msg) throws Exception {
Channel channel = ctx.channel();
Channel relayChannel = channel.attr(RELAY_CHANNEL_KEY).get();
if (relayChannel == null) {
logger.atSevere().log("Relay channel not specified for channel: %s", channel);
ChannelFuture unusedFuture = channel.close();
} else {
writeToRelayChannel(channel, relayChannel, msg, false);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (cause instanceof OverQuotaException) {
logger.atWarning().withCause(cause).log(
"Channel %s closed due to quota exceeded.", ctx.channel());
} else {
logger.atWarning().withCause(cause).log(
"Channel %s closed due to unexpected exception.", ctx.channel());
}
ChannelFuture unusedFuture = ctx.close();
}
public static void writeToRelayChannel(
Channel channel, Channel relayChannel, Object msg, boolean retry) {
// If the message is reference counted, its internal buffer that holds the data will be freed by
// Netty when the reference count reduce to zero. When this message is written to the relay
// channel, regardless of whether it is successful or not, its reference count will be reduced
// to zero and its buffer will be freed. After the buffer is freed, the message cannot be used
// anymore, even if in Java's eye the object still exist, its content is gone. We increment a
// count here so that the message can be retried, in case the relay is not successful.
if (msg instanceof ReferenceCounted) {
((ReferenceCounted) msg).retain();
}
ChannelFuture unusedFuture =
relayChannel
.writeAndFlush(msg)
.addListener(
future -> {
if (!future.isSuccess()) {
logger.atWarning().withCause(future.cause()).log(
"Relay failed: %s --> %s\nINBOUND: %s\nOUTBOUND: %s\nHASH: %s",
channel.attr(PROTOCOL_KEY).get().name(),
relayChannel.attr(PROTOCOL_KEY).get().name(),
channel,
relayChannel,
msg.hashCode());
// If we cannot write to the relay channel and the originating channel has
// a relay buffer (i. e. we tried to relay the frontend to the backend), store
// the message in the buffer for retry later. The relay channel (backend) should
// be killed (if it is not already dead, usually the relay is unsuccessful
// because the connection is closed), and a new backend channel will re-connect
// as long as the frontend channel is open. Otherwise, we are relaying from the
// backend to the frontend, and this relay failure cannot be recovered from: we
// should just kill the relay (frontend) channel, which in turn will kill the
// backend channel.
Queue<Object> relayBuffer = channel.attr(RELAY_BUFFER_KEY).get();
if (relayBuffer != null) {
channel.attr(RELAY_BUFFER_KEY).get().add(msg);
} else {
// We are not going to retry, decrement a counter to allow the message to be
// freed by Netty, if the message is reference counted.
ReferenceCountUtil.release(msg);
}
ChannelFuture unusedFuture2 = relayChannel.close();
} else {
if (retry) {
logger.atInfo().log(
"Relay retry succeeded: %s --> %s\nINBOUND: %s\nOUTBOUND: %s\nHASH: %s",
channel.attr(PROTOCOL_KEY).get().name(),
relayChannel.attr(PROTOCOL_KEY).get().name(),
channel,
relayChannel,
msg.hashCode());
}
// If the write is successful, we know that no retry is needed. This function
// will decrement the reference count if the message is reference counted,
// allowing Netty to free the message's buffer.
ReferenceCountUtil.release(msg);
}
});
}
/** Specialized {@link RelayHandler} that takes a {@link FullHttpRequest} as inbound payload. */
public static class FullHttpRequestRelayHandler extends RelayHandler<FullHttpRequest> {
@Inject
public FullHttpRequestRelayHandler() {
super(FullHttpRequest.class);
}
}
/** Specialized {@link RelayHandler} that takes a {@link FullHttpResponse} as inbound payload. */
public static class FullHttpResponseRelayHandler extends RelayHandler<FullHttpResponse> {
@Inject
public FullHttpResponseRelayHandler() {
super(FullHttpResponse.class);
}
}
}

View file

@ -0,0 +1,84 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkNotNull;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.flogger.FluentLogger;
import google.registry.proxy.Protocol.BackendProtocol;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import java.security.cert.X509Certificate;
import javax.inject.Inject;
import javax.inject.Singleton;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
/**
* Adds a client side SSL handler to the channel pipeline.
*
* <p>This <b>must</b> be the first handler provided for any handler provider list, if it is
* provided. The type parameter {@code C} is needed so that unit tests can construct this handler
* that works with {@link EmbeddedChannel};
*/
@Singleton
@Sharable
public class SslClientInitializer<C extends Channel> extends ChannelInitializer<C> {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private final SslProvider sslProvider;
private final X509Certificate[] trustedCertificates;
@Inject
public SslClientInitializer(SslProvider sslProvider) {
// null uses the system default trust store.
this(sslProvider, null);
}
@VisibleForTesting
SslClientInitializer(SslProvider sslProvider, X509Certificate[] trustCertificates) {
logger.atInfo().log("Client SSL Provider: %s", sslProvider);
this.sslProvider = sslProvider;
this.trustedCertificates = trustCertificates;
}
@Override
protected void initChannel(C channel) throws Exception {
BackendProtocol protocol = (BackendProtocol) channel.attr(PROTOCOL_KEY).get();
checkNotNull(protocol, "Protocol is not set for channel: %s", channel);
SslHandler sslHandler =
SslContextBuilder.forClient()
.sslProvider(sslProvider)
.trustManager(trustedCertificates)
.build()
.newHandler(channel.alloc(), protocol.host(), protocol.port());
// Enable hostname verification.
SSLEngine sslEngine = sslHandler.engine();
SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
sslEngine.setSSLParameters(sslParameters);
channel.pipeline().addLast(sslHandler);
}
}

View file

@ -0,0 +1,106 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import com.google.common.flogger.FluentLogger;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.function.Supplier;
/**
* Adds a server side SSL handler to the channel pipeline.
*
* <p>This <b>should</b> be the first handler provided for any handler provider list, if it is
* provided. Unless you wish to first process the PROXY header with {@link ProxyProtocolHandler},
* which should come before this handler. The type parameter {@code C} is needed so that unit tests
* can construct this handler that works with {@link EmbeddedChannel};
*
* <p>The ssl handler added requires client authentication, but it uses an {@link
* InsecureTrustManagerFactory}, which accepts any ssl certificate presented by the client, as long
* as the client uses the corresponding private key to establish SSL handshake. The client
* certificate hash will be passed along to GAE as an HTTP header for verification (not handled by
* this handler).
*/
@Sharable
public class SslServerInitializer<C extends Channel> extends ChannelInitializer<C> {
/**
* Attribute key to the client certificate promise whose value is set when SSL handshake completes
* successfully.
*/
public static final AttributeKey<Promise<X509Certificate>> CLIENT_CERTIFICATE_PROMISE_KEY =
AttributeKey.valueOf("CLIENT_CERTIFICATE_PROMISE_KEY");
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private final boolean requireClientCert;
private final SslProvider sslProvider;
private final Supplier<PrivateKey> privateKeySupplier;
private final Supplier<X509Certificate[]> certificatesSupplier;
public SslServerInitializer(
boolean requireClientCert,
SslProvider sslProvider,
Supplier<PrivateKey> privateKeySupplier,
Supplier<X509Certificate[]> certificatesSupplier) {
logger.atInfo().log("Server SSL Provider: %s", sslProvider);
this.requireClientCert = requireClientCert;
this.sslProvider = sslProvider;
this.privateKeySupplier = privateKeySupplier;
this.certificatesSupplier = certificatesSupplier;
}
@Override
protected void initChannel(C channel) throws Exception {
SslHandler sslHandler =
SslContextBuilder.forServer(privateKeySupplier.get(), certificatesSupplier.get())
.sslProvider(sslProvider)
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.clientAuth(requireClientCert ? ClientAuth.REQUIRE : ClientAuth.NONE)
.build()
.newHandler(channel.alloc());
if (requireClientCert) {
Promise<X509Certificate> clientCertificatePromise = channel.eventLoop().newPromise();
Future<Channel> unusedFuture =
sslHandler
.handshakeFuture()
.addListener(
future -> {
if (future.isSuccess()) {
Promise<X509Certificate> unusedPromise =
clientCertificatePromise.setSuccess(
(X509Certificate)
sslHandler.engine().getSession().getPeerCertificates()[0]);
} else {
Promise<X509Certificate> unusedPromise =
clientCertificatePromise.setFailure(future.cause());
}
});
channel.attr(CLIENT_CERTIFICATE_PROMISE_KEY).set(clientCertificatePromise);
}
channel.pipeline().addLast(sslHandler);
}
}

View file

@ -0,0 +1,145 @@
// Copyright 2018 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static io.netty.handler.codec.http.HttpHeaderNames.HOST;
import static io.netty.handler.codec.http.HttpHeaderNames.LOCATION;
import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
import static io.netty.handler.codec.http.HttpHeaderValues.TEXT_PLAIN;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpMethod.HEAD;
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.FOUND;
import static io.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
import static io.netty.handler.codec.http.HttpResponseStatus.MOVED_PERMANENTLY;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.flogger.FluentLogger;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpUtil;
import java.time.Duration;
/**
* Handler that redirects web WHOIS requests to a canonical website.
*
* <p>ICANN requires that port 43 and web-based WHOIS are both available on whois.nic.TLD. Since we
* expose a single IPv4/IPv6 anycast external IP address for the proxy, we need the load balancer to
* router port 80/443 traffic to the proxy to support web WHOIS.
*
* <p>HTTP (port 80) traffic is simply upgraded to HTTPS (port 443) on the same host, while HTTPS
* requests are redirected to the {@code redirectHost}, which is the canonical website that provide
* the web WHOIS service.
*
* @see <a
* href="https://newgtlds.icann.org/sites/default/files/agreements/agreement-approved-31jul17-en.html">
* REGISTRY AGREEMENT</a>
*/
public class WebWhoisRedirectHandler extends SimpleChannelInboundHandler<HttpRequest> {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
/**
* HTTP health check sent by GCP HTTP load balancer is set to use this host name.
*
* <p>Status 200 must be returned in order for a health check to be considered successful.
*
* @see <a
* href="https://cloud.google.com/load-balancing/docs/health-check-concepts#http_https_and_http2_health_checks">
* HTTP, HTTPS, and HTTP/2 health checks</a>
*/
private static final String HEALTH_CHECK_HOST = "health-check.invalid";
private static final String HSTS_HEADER_NAME = "Strict-Transport-Security";
private static final Duration HSTS_MAX_AGE = Duration.ofDays(365);
private static final ImmutableList<HttpMethod> ALLOWED_METHODS = ImmutableList.of(GET, HEAD);
private final boolean isHttps;
private final String redirectHost;
public WebWhoisRedirectHandler(boolean isHttps, String redirectHost) {
this.isHttps = isHttps;
this.redirectHost = redirectHost;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, HttpRequest msg) {
FullHttpResponse response;
if (!ALLOWED_METHODS.contains(msg.method())) {
response = new DefaultFullHttpResponse(HTTP_1_1, METHOD_NOT_ALLOWED);
} else if (Strings.isNullOrEmpty(msg.headers().get(HOST))) {
response = new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST);
} else {
// All HTTP/1.1 request must contain a Host header with the format "host:[port]".
// See https://tools.ietf.org/html/rfc2616#section-14.23
String host = Splitter.on(':').split(msg.headers().get(HOST)).iterator().next();
if (host.equals(HEALTH_CHECK_HOST)) {
// The health check request should always be sent to the HTTP port.
response =
isHttps
? new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN)
: new DefaultFullHttpResponse(HTTP_1_1, OK);
} else {
// HTTP -> HTTPS is a 301 redirect, whereas HTTPS -> web WHOIS site is 302 redirect.
response = new DefaultFullHttpResponse(HTTP_1_1, isHttps ? FOUND : MOVED_PERMANENTLY);
String redirectUrl = String.format("https://%s/", isHttps ? redirectHost : host);
response.headers().set(LOCATION, redirectUrl);
// Add HSTS header to HTTPS response.
if (isHttps) {
response
.headers()
.set(HSTS_HEADER_NAME, String.format("max-age=%d", HSTS_MAX_AGE.getSeconds()));
}
}
}
// Common headers that need to be set on any response.
response
.headers()
.set(CONTENT_TYPE, TEXT_PLAIN)
.setInt(CONTENT_LENGTH, response.content().readableBytes());
// Close the connection if keep-alive is not set in the request.
if (!HttpUtil.isKeepAlive(msg)) {
ChannelFuture unusedFuture =
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} else {
response.headers().set(CONNECTION, KEEP_ALIVE);
ChannelFuture unusedFuture = ctx.writeAndFlush(response);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
logger.atWarning().withCause(cause).log(
(isHttps ? "HTTPS" : "HTTP") + " WHOIS inbound exception caught for channel %s",
ctx.channel());
ChannelFuture unusedFuture = ctx.close();
}
}

View file

@ -0,0 +1,66 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkArgument;
import google.registry.proxy.metric.FrontendMetrics;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponse;
import java.util.function.Supplier;
/** Handler that processes WHOIS protocol logic. */
public final class WhoisServiceHandler extends HttpsRelayServiceHandler {
public WhoisServiceHandler(
String relayHost,
String relayPath,
Supplier<String> accessTokenSupplier,
FrontendMetrics metrics) {
super(relayHost, relayPath, accessTokenSupplier, metrics);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
metrics.registerActiveConnection("whois", "none", ctx.channel());
super.channelActive(ctx);
}
@Override
protected FullHttpRequest decodeFullHttpRequest(ByteBuf byteBuf) {
FullHttpRequest request = super.decodeFullHttpRequest(byteBuf);
request
.headers()
.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN)
.set(HttpHeaderNames.ACCEPT, HttpHeaderValues.TEXT_PLAIN);
return request;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
// Close connection after a response is received, per RFC-3912
// https://tools.ietf.org/html/rfc3912
checkArgument(msg instanceof HttpResponse);
promise.addListener(ChannelFutureListener.CLOSE);
super.write(ctx, msg, promise);
}
}

View file

@ -0,0 +1,54 @@
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: default
name: proxy-deployment
labels:
app: proxy
spec:
replicas: 3
selector:
matchLabels:
app: proxy
template:
metadata:
labels:
app: proxy
spec:
containers:
- name: proxy
image: gcr.io/GCP_PROJECT/proxy
ports:
- containerPort: 30000
name: health-check
- containerPort: 30001
name: whois
- containerPort: 30002
name: epp
- containerPort: 30010
name: http-whois
- containerPort: 30011
name: https-whois
readinessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 15
periodSeconds: 20
imagePullPolicy: Always
args: ["--env", "alpha", "--log"]
env:
- name: POD_ID
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: NAMESPACE_ID
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CONTAINER_NAME
value: proxy

View file

@ -0,0 +1,54 @@
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: default
name: proxy-deployment-canary
labels:
app: proxy-canary
spec:
replicas: 3
selector:
matchLabels:
app: proxy-canary
template:
metadata:
labels:
app: proxy-canary
spec:
containers:
- name: proxy-canary
image: gcr.io/GCP_PROJECT/proxy
ports:
- containerPort: 30000
name: health-check
- containerPort: 30001
name: whois
- containerPort: 30002
name: epp
- containerPort: 30010
name: http-whois
- containerPort: 30011
name: https-whois
readinessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 15
periodSeconds: 20
imagePullPolicy: Always
args: ["--env", "crash_canary", "--log"]
env:
- name: POD_ID
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: NAMESPACE_ID
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CONTAINER_NAME
value: proxy-canary

View file

@ -0,0 +1,54 @@
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: default
name: proxy-deployment
labels:
app: proxy
spec:
replicas: 3
selector:
matchLabels:
app: proxy
template:
metadata:
labels:
app: proxy
spec:
containers:
- name: proxy
image: gcr.io/GCP_PROJECT/proxy
ports:
- containerPort: 30000
name: health-check
- containerPort: 30001
name: whois
- containerPort: 30002
name: epp
- containerPort: 30010
name: http-whois
- containerPort: 30011
name: https-whois
readinessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 15
periodSeconds: 20
imagePullPolicy: Always
args: ["--env", "crash", "--log"]
env:
- name: POD_ID
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: NAMESPACE_ID
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CONTAINER_NAME
value: proxy

View file

@ -0,0 +1,55 @@
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: default
name: proxy-deployment-canary
labels:
app: proxy-canary
spec:
replicas: 3
selector:
matchLabels:
app: proxy-canary
template:
metadata:
labels:
app: proxy-canary
spec:
containers:
- name: proxy-canary
image: gcr.io/GCP_PROJECT/proxy
ports:
- containerPort: 30000
name: health-check
- containerPort: 30001
name: whois
- containerPort: 30002
name: epp
- containerPort: 30010
name: http-whois
- containerPort: 30011
name: https-whois
readinessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 15
periodSeconds: 20
imagePullPolicy: Always
args: ["--env", "production_canary"]
env:
- name: POD_ID
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: NAMESPACE_ID
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CONTAINER_NAME
value: proxy-canary

View file

@ -0,0 +1,55 @@
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: default
name: proxy-deployment
labels:
app: proxy
spec:
replicas: 3
selector:
matchLabels:
app: proxy
template:
metadata:
labels:
app: proxy
spec:
containers:
- name: proxy
image: gcr.io/GCP_PROJECT/proxy
ports:
- containerPort: 30000
name: health-check
- containerPort: 30001
name: whois
- containerPort: 30002
name: epp
- containerPort: 30010
name: http-whois
- containerPort: 30011
name: https-whois
readinessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 15
periodSeconds: 20
imagePullPolicy: Always
args: ["--env", "production"]
env:
- name: POD_ID
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: NAMESPACE_ID
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CONTAINER_NAME
value: proxy

View file

@ -0,0 +1,55 @@
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: default
name: proxy-deployment-canary
labels:
app: proxy-canary
spec:
replicas: 3
selector:
matchLabels:
app: proxy-canary
template:
metadata:
labels:
app: proxy-canary
spec:
containers:
- name: proxy-canary
image: gcr.io/GCP_PROJECT/proxy
ports:
- containerPort: 30000
name: health-check
- containerPort: 30001
name: whois
- containerPort: 30002
name: epp
- containerPort: 30010
name: http-whois
- containerPort: 30011
name: https-whois
readinessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 15
periodSeconds: 20
imagePullPolicy: Always
args: ["--env", "sandbox_canary", "--log"]
env:
- name: POD_ID
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: NAMESPACE_ID
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CONTAINER_NAME
value: proxy-canary

View file

@ -0,0 +1,55 @@
apiVersion: apps/v1
kind: Deployment
metadata:
namespace: default
name: proxy-deployment
labels:
app: proxy
spec:
replicas: 3
selector:
matchLabels:
app: proxy
template:
metadata:
labels:
app: proxy
spec:
containers:
- name: proxy
image: gcr.io/GCP_PROJECT/proxy
ports:
- containerPort: 30000
name: health-check
- containerPort: 30001
name: whois
- containerPort: 30002
name: epp
- containerPort: 30010
name: http-whois
- containerPort: 30011
name: https-whois
readinessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
tcpSocket:
port: health-check
initialDelaySeconds: 15
periodSeconds: 20
imagePullPolicy: Always
args: ["--env", "sandbox", "--log"]
env:
- name: POD_ID
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: NAMESPACE_ID
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: CONTAINER_NAME
value: proxy

View file

@ -0,0 +1,50 @@
kind: Service
apiVersion: v1
metadata:
namespace: default
name: proxy-service-canary
spec:
selector:
app: proxy-canary
ports:
- protocol: TCP
port: 30000
nodePort: 31000
targetPort: health-check
name: health-check
- protocol: TCP
port: 30001
nodePort: 31001
targetPort: whois
name: whois
- protocol: TCP
port: 30002
nodePort: 31002
targetPort: epp
name: epp
- protocol: TCP
port: 30010
nodePort: 31010
targetPort: http-whois
name: http-whois
- protocol: TCP
port: 30011
nodePort: 31011
targetPort: https-whois
name: https-whois
type: NodePort
---
apiVersion: autoscaling/v2beta1
kind: HorizontalPodAutoscaler
metadata:
namespace: default
name: proxy-autoscale-canary
labels:
app: proxy-canary
spec:
scaleTargetRef:
apiVersion: extensions/v1beta1
kind: Deployment
name: proxy-deployment-canary
maxReplicas: 10
minReplicas: 1

View file

@ -0,0 +1,50 @@
kind: Service
apiVersion: v1
metadata:
namespace: default
name: proxy-service
spec:
selector:
app: proxy
ports:
- protocol: TCP
port: 30000
nodePort: 30000
targetPort: health-check
name: health-check
- protocol: TCP
port: 30001
nodePort: 30001
targetPort: whois
name: whois
- protocol: TCP
port: 30002
nodePort: 30002
targetPort: epp
name: epp
- protocol: TCP
port: 30010
nodePort: 30010
targetPort: http-whois
name: http-whois
- protocol: TCP
port: 30011
nodePort: 30011
targetPort: https-whois
name: https-whois
type: NodePort
---
apiVersion: autoscaling/v2beta1
kind: HorizontalPodAutoscaler
metadata:
namespace: default
name: proxy-autoscale
labels:
app: proxy
spec:
scaleTargetRef:
apiVersion: extensions/v1beta1
kind: Deployment
name: proxy-deployment
maxReplicas: 10
minReplicas: 1

View file

@ -0,0 +1,126 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.metric;
import com.google.common.collect.ImmutableSet;
import com.google.monitoring.metrics.CustomFitter;
import com.google.monitoring.metrics.EventMetric;
import com.google.monitoring.metrics.ExponentialFitter;
import com.google.monitoring.metrics.FibonacciFitter;
import com.google.monitoring.metrics.IncrementableMetric;
import com.google.monitoring.metrics.LabelDescriptor;
import com.google.monitoring.metrics.MetricRegistryImpl;
import google.registry.util.NonFinalForTesting;
import io.netty.handler.codec.http.FullHttpResponse;
import javax.inject.Inject;
import javax.inject.Singleton;
/** Backend metrics instrumentation. */
@Singleton
public class BackendMetrics {
// Maximum request size is defined in the config file, this is not realistic and we'd be out of
// memory when the size approach 1 GB.
private static final CustomFitter DEFAULT_SIZE_FITTER = FibonacciFitter.create(1073741824);
// Maximum 1 hour latency, this is not specified by the spec, but given we have a one hour idle
// timeout, it seems reasonable that maximum latency is set to 1 hour as well. If we are
// approaching anywhere near 1 hour latency, we'd be way out of SLO anyway.
private static final ExponentialFitter DEFAULT_LATENCY_FITTER =
ExponentialFitter.create(22, 2, 1.0);
private static final ImmutableSet<LabelDescriptor> LABELS =
ImmutableSet.of(
LabelDescriptor.create("protocol", "Name of the protocol."),
LabelDescriptor.create(
"client_cert_hash", "SHA256 hash of the client certificate, if available."));
static final IncrementableMetric requestsCounter =
MetricRegistryImpl.getDefault()
.newIncrementableMetric(
"/proxy/backend/requests",
"Total number of requests send to the backend.",
"Requests",
LABELS);
static final IncrementableMetric responsesCounter =
MetricRegistryImpl.getDefault()
.newIncrementableMetric(
"/proxy/backend/responses",
"Total number of responses received by the backend.",
"Responses",
ImmutableSet.<LabelDescriptor>builder()
.addAll(LABELS)
.add(LabelDescriptor.create("status", "HTTP status code."))
.build());
static final EventMetric requestBytes =
MetricRegistryImpl.getDefault()
.newEventMetric(
"/proxy/backend/request_bytes",
"Size of the backend requests sent.",
"Request Bytes",
LABELS,
DEFAULT_SIZE_FITTER);
static final EventMetric responseBytes =
MetricRegistryImpl.getDefault()
.newEventMetric(
"/proxy/backend/response_bytes",
"Size of the backend responses received.",
"Response Bytes",
LABELS,
DEFAULT_SIZE_FITTER);
static final EventMetric latencyMs =
MetricRegistryImpl.getDefault()
.newEventMetric(
"/proxy/backend/latency_ms",
"Round-trip time between a request sent and its corresponding response received.",
"Latency Milliseconds",
LABELS,
DEFAULT_LATENCY_FITTER);
@Inject
BackendMetrics() {}
/**
* Resets all backend metrics.
*
* <p>This should only used in tests to clear out states. No production code should call this
* function.
*/
void resetMetric() {
requestBytes.reset();
requestsCounter.reset();
responseBytes.reset();
responsesCounter.reset();
latencyMs.reset();
}
@NonFinalForTesting
public void requestSent(String protocol, String certHash, int bytes) {
requestsCounter.increment(protocol, certHash);
requestBytes.record(bytes, protocol, certHash);
}
@NonFinalForTesting
public void responseReceived(
String protocol, String certHash, FullHttpResponse response, long latency) {
latencyMs.record(latency, protocol, certHash);
responseBytes.record(response.content().readableBytes(), protocol, certHash);
responsesCounter.increment(protocol, certHash, response.status().toString());
}
}

View file

@ -0,0 +1,125 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.metric;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.monitoring.metrics.IncrementableMetric;
import com.google.monitoring.metrics.LabelDescriptor;
import com.google.monitoring.metrics.Metric;
import com.google.monitoring.metrics.MetricRegistryImpl;
import google.registry.util.NonFinalForTesting;
import io.netty.channel.Channel;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import javax.inject.Inject;
import javax.inject.Singleton;
/** Frontend metrics instrumentation. */
@Singleton
public class FrontendMetrics {
/**
* Labels to register front metrics with.
*
* <p>The client certificate hash value is only used for EPP metrics. For WHOIS metrics, it will
* always be {@code "none"}. In order to get the actual registrar name, one can use the {@code
* nomulus} tool:
*
* <pre>
* nomulus -e production list_registrars -f clientCertificateHash | grep $HASH
* </pre>
*/
private static final ImmutableSet<LabelDescriptor> LABELS =
ImmutableSet.of(
LabelDescriptor.create("protocol", "Name of the protocol."),
LabelDescriptor.create(
"client_cert_hash", "SHA256 hash of the client certificate, if available."));
private static final ConcurrentMap<ImmutableList<String>, ChannelGroup> activeConnections =
new ConcurrentHashMap<>();
static final Metric<Long> activeConnectionsGauge =
MetricRegistryImpl.getDefault()
.newGauge(
"/proxy/frontend/active_connections",
"Number of active connections from clients to the proxy.",
"Active Connections",
LABELS,
() ->
activeConnections
.entrySet()
.stream()
.collect(
ImmutableMap.toImmutableMap(
Map.Entry::getKey, entry -> (long) entry.getValue().size())),
Long.class);
static final IncrementableMetric totalConnectionsCounter =
MetricRegistryImpl.getDefault()
.newIncrementableMetric(
"/proxy/frontend/total_connections",
"Total number connections ever made from clients to the proxy.",
"Total Connections",
LABELS);
static final IncrementableMetric quotaRejectionsCounter =
MetricRegistryImpl.getDefault()
.newIncrementableMetric(
"/proxy/frontend/quota_rejections",
"Total number rejected quota request made by proxy for each connection.",
"Quota Rejections",
LABELS);
@Inject
public FrontendMetrics() {}
/**
* Resets all frontend metrics.
*
* <p>This should only be used in tests to reset states. Production code should not call this
* method.
*/
@VisibleForTesting
void resetMetrics() {
totalConnectionsCounter.reset();
activeConnections.clear();
}
@NonFinalForTesting
public void registerActiveConnection(String protocol, String certHash, Channel channel) {
totalConnectionsCounter.increment(protocol, certHash);
ImmutableList<String> labels = ImmutableList.of(protocol, certHash);
ChannelGroup channelGroup;
if (activeConnections.containsKey(labels)) {
channelGroup = activeConnections.get(labels);
} else {
channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
activeConnections.put(labels, channelGroup);
}
channelGroup.add(channel);
}
@NonFinalForTesting
public void registerQuotaRejection(String protocol, String certHash) {
quotaRejectionsCounter.increment(protocol, certHash);
}
}

View file

@ -0,0 +1,143 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.metric;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.api.services.monitoring.v3.model.MonitoredResource;
import com.google.common.collect.ImmutableMap;
import com.google.common.flogger.FluentLogger;
import com.google.common.io.CharStreams;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.Map;
import java.util.function.Function;
import javax.inject.Inject;
/**
* Utility class to obtain labels for monitored resource of type {@code gke_container}.
*
* <p>Custom metrics collected by the proxy need to be associated with a {@link MonitoredResource}.
* When running on GKE, the type is {@code gke_container}. The labels for this type are used to
* group related metrics together, and to avoid out-of-order metrics writes. This class provides a
* map of the labels where the values are either read from environment variables (pod and container
* related labels) or queried from GCE metadata server (cluster and instance related labels).
*
* @see <a
* href="https://cloud.google.com/monitoring/custom-metrics/creating-metrics#which-resource">
* Creating Custom Metrics - Choosing a monitored resource type</a>
* @see <a href="https://cloud.google.com/monitoring/api/resources#tag_gke_container">Monitored
* Resource Types - gke_container</a>
* @see <a href="https://cloud.google.com/compute/docs/storing-retrieving-metadata#querying">Storing
* and Retrieving Instance Metadata - Getting metadata</a>
* @see <a
* href="https://kubernetes.io/docs/tasks/inject-data-application/environment-variable-expose-pod-information/">
* Expose Pod Information to Containers Through Environment Variables </a>
*/
public class MetricParameters {
// Environment variable names, defined in the GKE deployment pod spec.
static final String NAMESPACE_ID_ENV = "NAMESPACE_ID";
static final String POD_ID_ENV = "POD_ID";
static final String CONTAINER_NAME_ENV = "CONTAINER_NAME";
// GCE metadata server URLs to retrieve instance related information.
private static final String GCE_METADATA_URL_BASE = "http://metadata.google.internal/";
static final String PROJECT_ID_PATH = "computeMetadata/v1/project/project-id";
static final String CLUSTER_NAME_PATH = "computeMetadata/v1/instance/attributes/cluster-name";
static final String INSTANCE_ID_PATH = "computeMetadata/v1/instance/id";
static final String ZONE_PATH = "computeMetadata/v1/instance/zone";
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private final Map<String, String> envVarMap;
private final Function<String, HttpURLConnection> connectionFactory;
MetricParameters(
Map<String, String> envVarMap, Function<String, HttpURLConnection> connectionFactory) {
this.envVarMap = envVarMap;
this.connectionFactory = connectionFactory;
}
@Inject
MetricParameters() {
this(ImmutableMap.copyOf(System.getenv()), MetricParameters::gceConnectionFactory);
}
private static HttpURLConnection gceConnectionFactory(String path) {
String url = GCE_METADATA_URL_BASE + path;
try {
HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();
connection.setRequestMethod("GET");
// The metadata server requires this header to be set when querying from a GCE instance.
connection.setRequestProperty("Metadata-Flavor", "Google");
connection.setDoOutput(true);
return connection;
} catch (IOException e) {
throw new RuntimeException(String.format("Incorrect GCE metadata server URL: %s", url), e);
}
}
private String readEnvVar(String envVar) {
return envVarMap.getOrDefault(envVar, "");
}
private String readGceMetadata(String path) {
String value = "";
HttpURLConnection connection = connectionFactory.apply(path);
try {
connection.connect();
int responseCode = connection.getResponseCode();
if (responseCode < 200 || responseCode > 299) {
logger.atWarning().log(
"Got an error response: %d\n%s",
responseCode,
CharStreams.toString(new InputStreamReader(connection.getErrorStream(), UTF_8)));
} else {
value = CharStreams.toString(new InputStreamReader(connection.getInputStream(), UTF_8));
}
} catch (IOException e) {
logger.atWarning().withCause(e).log("Cannot obtain GCE metadata from path %s", path);
}
return value;
}
public ImmutableMap<String, String> makeLabelsMap() {
// The zone metadata is in the form of "projects/<PROJECT_NUMERICAL_ID>/zones/<ZONE_NAME>".
// We only need the last part after the slash.
String fullZone = readGceMetadata(ZONE_PATH);
String zone;
String[] fullZoneArray = fullZone.split("/", -1);
if (fullZoneArray.length < 4) {
logger.atWarning().log("Zone %s is valid.", fullZone);
// This will make the metric report throw, but it happens in a different thread and will not
// kill the whole application.
zone = "";
} else {
zone = fullZoneArray[3];
}
return new ImmutableMap.Builder<String, String>()
.put("project_id", readGceMetadata(PROJECT_ID_PATH))
.put("cluster_name", readGceMetadata(CLUSTER_NAME_PATH))
.put("namespace_id", readEnvVar(NAMESPACE_ID_ENV))
.put("instance_id", readGceMetadata(INSTANCE_ID_PATH))
.put("pod_id", readEnvVar(POD_ID_ENV))
.put("container_name", readEnvVar(CONTAINER_NAME_ENV))
.put("zone", zone)
.build();
}
}

View file

@ -0,0 +1,87 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.quota;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import google.registry.proxy.ProxyConfig.Quota;
import google.registry.proxy.ProxyConfig.Quota.QuotaGroup;
import org.joda.time.Duration;
/** Value class that stores the quota configuration for a protocol. */
public class QuotaConfig {
/** A special value of token amount that indicates unlimited tokens. */
public static final int SENTINEL_UNLIMITED_TOKENS = -1;
private final String protocolName;
private final int refreshSeconds;
private final QuotaGroup defaultQuota;
private final ImmutableMap<String, QuotaGroup> customQuotaMap;
/**
* Constructs a {@link QuotaConfig} from a {@link Quota}.
*
* <p>Each {@link QuotaGroup} is keyed to all the {@code userId}s it contains. This allows for
* fast lookup with a {@code userId}.
*/
public QuotaConfig(Quota quota, String protocolName) {
this.protocolName = protocolName;
refreshSeconds = quota.refreshSeconds;
defaultQuota = quota.defaultQuota;
ImmutableMap.Builder<String, QuotaGroup> mapBuilder = new ImmutableMap.Builder<>();
quota.customQuota.forEach(
quotaGroup -> quotaGroup.userId.forEach(userId -> mapBuilder.put(userId, quotaGroup)));
customQuotaMap = mapBuilder.build();
}
@VisibleForTesting
QuotaGroup findQuotaGroup(String userId) {
return customQuotaMap.getOrDefault(userId, defaultQuota);
}
/**
* Returns if the given user ID is provisioned with unlimited tokens.
*
* <p>This is configured by setting {@code tokenAmount} to {@code -1} in the config file.
*/
boolean hasUnlimitedTokens(String userId) {
return findQuotaGroup(userId).tokenAmount == SENTINEL_UNLIMITED_TOKENS;
}
/** Returns the token amount for the given {@code userId}. */
int getTokenAmount(String userId) {
checkState(
!hasUnlimitedTokens(userId), "User ID %s is provisioned with unlimited tokens", userId);
return findQuotaGroup(userId).tokenAmount;
}
/** Returns the refill period for the given {@code userId}. */
Duration getRefillPeriod(String userId) {
return Duration.standardSeconds(findQuotaGroup(userId).refillSeconds);
}
/** Returns the refresh period for this quota config. */
Duration getRefreshPeriod() {
return Duration.standardSeconds(refreshSeconds);
}
/** Returns the name of the protocol for which this quota config is made. */
String getProtocolName() {
return protocolName;
}
}

View file

@ -0,0 +1,102 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.quota;
import com.google.auto.value.AutoValue;
import google.registry.proxy.quota.TokenStore.TimestampedInteger;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import javax.annotation.concurrent.ThreadSafe;
import org.joda.time.DateTime;
/**
* A thread-safe quota manager that schedules background refresh if necessary.
*
* <p>This class abstracts away details about the {@link TokenStore}. It:
*
* <ul>
* <li>Translates a {@link QuotaRequest} to taking one token from the store, blocks the caller,
* and responds with a {@link QuotaResponse}.
* <li>Translates a {@link QuotaRebate} to putting the token to the store asynchronously, and
* immediately returns.
* <li>Periodically refreshes the token records asynchronously to purge stale recodes.
* </ul>
*
* <p>There should be one {@link QuotaManager} per protocol.
*/
@ThreadSafe
public class QuotaManager {
/** Value class representing a quota request. */
@AutoValue
public abstract static class QuotaRequest {
public static QuotaRequest create(String userId) {
return new AutoValue_QuotaManager_QuotaRequest(userId);
}
abstract String userId();
}
/** Value class representing a quota response. */
@AutoValue
public abstract static class QuotaResponse {
public static QuotaResponse create(
boolean success, String userId, DateTime grantedTokenRefillTime) {
return new AutoValue_QuotaManager_QuotaResponse(success, userId, grantedTokenRefillTime);
}
public abstract boolean success();
abstract String userId();
abstract DateTime grantedTokenRefillTime();
}
/** Value class representing a quota rebate. */
@AutoValue
public abstract static class QuotaRebate {
public static QuotaRebate create(QuotaResponse response) {
return new AutoValue_QuotaManager_QuotaRebate(
response.userId(), response.grantedTokenRefillTime());
}
abstract String userId();
abstract DateTime grantedTokenRefillTime();
}
private final TokenStore tokenStore;
private final ExecutorService backgroundExecutor;
public QuotaManager(TokenStore tokenStore, ExecutorService backgroundExecutor) {
this.tokenStore = tokenStore;
this.backgroundExecutor = backgroundExecutor;
tokenStore.scheduleRefresh();
}
/** Attempts to acquire requested quota, synchronously. */
public QuotaResponse acquireQuota(QuotaRequest request) {
TimestampedInteger tokens = tokenStore.take(request.userId());
return QuotaResponse.create(tokens.value() != 0, request.userId(), tokens.timestamp());
}
/** Returns granted quota to the token store, asynchronously. */
public Future<?> releaseQuota(QuotaRebate rebate) {
return backgroundExecutor.submit(
() -> tokenStore.put(rebate.userId(), rebate.grantedTokenRefillTime()));
}
}

View file

@ -0,0 +1,222 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.quota;
import static google.registry.proxy.quota.QuotaConfig.SENTINEL_UNLIMITED_TOKENS;
import static java.lang.StrictMath.max;
import static java.lang.StrictMath.min;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.flogger.FluentLogger;
import google.registry.util.Clock;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.concurrent.ThreadSafe;
import org.joda.time.DateTime;
import org.joda.time.Duration;
/**
* A thread-safe token store that supports concurrent {@link #take}, {@link #put}, and {@link
* #refresh} operations.
*
* <p>The tokens represent quota allocated to each user, which needs to be leased to the user upon
* connection and optionally returned to the store upon termination. Failure to acquire tokens
* results in quota fulfillment failure, leading to automatic connection termination. For details on
* tokens, see {@code config/default-config.yaml}.
*
* <p>The store also lazily refills tokens for a {@code userId} when a {@link #take} or a {@link
* #put} takes place. It also exposes a {@link #refresh} method that goes through each entry in the
* store and purges stale entries, in order to prevent the token store from growing too large.
*
* <p>There should be one token store for each protocol.
*/
@ThreadSafe
public class TokenStore {
/** Value class representing a timestamped integer. */
@AutoValue
abstract static class TimestampedInteger {
static TimestampedInteger create(int value, DateTime timestamp) {
return new AutoValue_TokenStore_TimestampedInteger(value, timestamp);
}
abstract int value();
abstract DateTime timestamp();
}
/**
* A wrapper to get around Java lambda's closure limitation.
*
* <p>Use the class to modify the value of a local variable captured by an lambda.
*/
private static class Wrapper<T> {
T value;
}
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
/** A map of {@code userId} to available tokens, timestamped at last refill time. */
private final ConcurrentHashMap<String, TimestampedInteger> tokensMap = new ConcurrentHashMap<>();
private final QuotaConfig config;
private final ScheduledExecutorService refreshExecutor;
private final Clock clock;
public TokenStore(QuotaConfig config, ScheduledExecutorService refreshExecutor, Clock clock) {
this.config = config;
this.refreshExecutor = refreshExecutor;
this.clock = clock;
}
/**
* Attempts to take one token from the token store.
*
* <p>This method first check if the user already has an existing entry in the tokens map, and if
* that entry has been last refilled before the refill period. In either case it will reset the
* token amount to the allotted to the user.
*
* <p>The request can be partially fulfilled or all-or-nothing, meaning if there are fewer tokens
* available than requested, we can grant all available ones, or grant nothing, depending on the
* {@code partialGrant} parameter.
*
* @param userId the identifier of the user requesting the token.
* @return the number of token granted, timestamped at refill time of the pool of tokens from
* which the granted one is taken.
*/
TimestampedInteger take(String userId) {
Wrapper<TimestampedInteger> grantedToken = new Wrapper<>();
tokensMap.compute(
userId,
(user, availableTokens) -> {
DateTime now = clock.nowUtc();
int currentTokenCount;
DateTime refillTime;
// Checks if the user is provisioned with unlimited tokens.
if (config.hasUnlimitedTokens(user)) {
grantedToken.value = TimestampedInteger.create(1, now);
return TimestampedInteger.create(SENTINEL_UNLIMITED_TOKENS, now);
}
// Checks if the entry exists.
if (availableTokens == null
// Or if refill is enabled and the entry needs to be refilled.
|| (!config.getRefillPeriod(user).isEqual(Duration.ZERO)
&& !new Duration(availableTokens.timestamp(), now)
.isShorterThan(config.getRefillPeriod(user)))) {
currentTokenCount = config.getTokenAmount(user);
refillTime = now;
} else {
currentTokenCount = availableTokens.value();
refillTime = availableTokens.timestamp();
}
int newTokenCount = max(0, currentTokenCount - 1);
grantedToken.value =
TimestampedInteger.create(currentTokenCount - newTokenCount, refillTime);
return TimestampedInteger.create(newTokenCount, refillTime);
});
return grantedToken.value;
}
/**
* Attempts to return the granted token to the token store.
*
* <p>The method first check if a refill is needed, and do it accordingly. It then checks if the
* returned token are from the current pool (i. e. has the same refill timestamp as the current
* pool), and returns the token, capped at the allotted amount for the {@code userId}.
*
* @param userId the identifier of the user returning the token.
* @param returnedTokenRefillTime The refill time of the pool of tokens from which the returned
* one is taken from.
*/
void put(String userId, DateTime returnedTokenRefillTime) {
tokensMap.computeIfPresent(
userId,
(user, availableTokens) -> {
DateTime now = clock.nowUtc();
int currentTokenCount = availableTokens.value();
DateTime refillTime = availableTokens.timestamp();
int newTokenCount;
// Check if quota is unlimited.
if (!config.hasUnlimitedTokens(userId)) {
// Check if refill is enabled and a refill is needed.
if (!config.getRefillPeriod(user).isEqual(Duration.ZERO)
&& !new Duration(availableTokens.timestamp(), now)
.isShorterThan(config.getRefillPeriod(user))) {
currentTokenCount = config.getTokenAmount(user);
refillTime = now;
}
// If the returned token comes from the current pool, add it back, otherwise discard it.
newTokenCount =
returnedTokenRefillTime.equals(refillTime)
? min(currentTokenCount + 1, config.getTokenAmount(userId))
: currentTokenCount;
} else {
newTokenCount = SENTINEL_UNLIMITED_TOKENS;
}
return TimestampedInteger.create(newTokenCount, refillTime);
});
}
/**
* Refreshes the token store and deletes any entry that has not been refilled for longer than the
* refresh period.
*
* <p>Strictly speaking it should delete the entries that have not been updated (put, taken,
* refill) for longer than the refresh period. But the last update time is not recorded. Typically
* the refill period is much shorter than the refresh period, so the last refill time should serve
* as a good proxy for last update time as the actual update time cannot be one refill period
* later from the refill time, otherwise another refill would have been performed.
*/
void refresh() {
tokensMap.forEach(
(user, availableTokens) -> {
if (!new Duration(availableTokens.timestamp(), clock.nowUtc())
.isShorterThan(config.getRefreshPeriod())) {
tokensMap.remove(user);
}
});
}
/** Schedules token store refresh if enabled. */
void scheduleRefresh() {
// Only schedule refresh if the refresh period is not zero.
if (!config.getRefreshPeriod().isEqual(Duration.ZERO)) {
Future<?> unusedFuture =
refreshExecutor.scheduleWithFixedDelay(
() -> {
refresh();
logger.atInfo().log("Refreshing quota for protocol %s", config.getProtocolName());
},
config.getRefreshPeriod().getStandardSeconds(),
config.getRefreshPeriod().getStandardSeconds(),
TimeUnit.SECONDS);
}
}
/**
* Helper method to retrieve the timestamped token value for a {@code userId} for testing.
*
* <p>This non-mutating method is exposed solely for testing, so that the {@link #tokensMap} can
* stay private and not be altered unintentionally.
*/
@VisibleForTesting
TimestampedInteger getTokenForTests(String userId) {
return tokensMap.get(userId);
}
}

View file

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<epp xmlns="urn:ietf:params:xml:ns:epp-1.0">
<hello/>
</epp>

View file

@ -0,0 +1,31 @@
terraform {
backend "gcs" {
# The name of the GCS bucket that stores the terraform.tfstate file.
bucket = "YOUR_GCS_BUCKET"
prefix = "terraform/state"
}
}
module "proxy" {
source = "../../modules"
proxy_project_name = "YOUR_PROXY_PROJECT"
gcr_project_name = "YOUR_GCR_PROJECT"
proxy_domain_name = "YOUR_PROXY_DOMAIN"
proxy_certificate_bucket = "YOU_CERTIFICATE_BUCKET"
}
output "proxy_service_account" {
value = "${module.proxy.proxy_service_account}"
}
output "proxy_name_servers" {
value = "${module.proxy.proxy_name_servers}"
}
output "proxy_instance_groups" {
value = "${module.proxy.proxy_instance_groups}"
}
output "proxy_ip_addresses" {
value = "${module.proxy.proxy_ip_addresses}"
}

View file

@ -0,0 +1,4 @@
provider "google" {
version = ">= 1.13.0"
project = "${var.proxy_project_name}"
}

View file

@ -0,0 +1,10 @@
resource "google_storage_bucket" "proxy_certificate" {
name = "${var.proxy_certificate_bucket}"
storage_class = "MULTI_REGIONAL"
}
resource "google_storage_bucket_iam_member" "member" {
bucket = "${google_storage_bucket.proxy_certificate.name}"
role = "roles/storage.objectViewer"
member = "serviceAccount:${google_service_account.proxy_service_account.email}"
}

View file

@ -0,0 +1,25 @@
module "proxy_gke_americas" {
source = "./gke"
proxy_cluster_region = "americas"
proxy_service_account_email = "${google_service_account.proxy_service_account.email}"
}
module "proxy_gke_emea" {
source = "./gke"
proxy_cluster_region = "emea"
proxy_service_account_email = "${google_service_account.proxy_service_account.email}"
}
module "proxy_gke_apac" {
source = "./gke"
proxy_cluster_region = "apac"
proxy_service_account_email = "${google_service_account.proxy_service_account.email}"
}
locals {
"proxy_instance_groups" = {
americas = "${module.proxy_gke_americas.proxy_instance_group}"
emea = "${module.proxy_gke_emea.proxy_instance_group}"
apac = "${module.proxy_gke_apac.proxy_instance_group}"
}
}

View file

@ -0,0 +1,40 @@
locals {
proxy_cluster_zone = "${lookup(var.proxy_cluster_zones, var.proxy_cluster_region)}"
}
resource "google_container_cluster" "proxy_cluster" {
name = "proxy-cluster-${var.proxy_cluster_region}"
zone = "${local.proxy_cluster_zone}"
timeouts {
update = "30m"
}
node_pool {
name = "proxy-node-pool"
initial_node_count = 1
node_config {
tags = [
"proxy-cluster",
]
service_account = "${var.proxy_service_account_email}"
oauth_scopes = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
]
}
autoscaling {
max_node_count = 5
min_node_count = 1
}
management {
auto_repair = "true"
auto_upgrade = "true"
}
}
}

View file

@ -0,0 +1,3 @@
output "proxy_instance_group" {
value = "${google_container_cluster.proxy_cluster.instance_group_urls[0]}"
}

View file

@ -0,0 +1,13 @@
variable "proxy_service_account_email" {}
variable "proxy_cluster_region" {}
variable "proxy_cluster_zones" {
type = "map"
default = {
americas = "us-east4-a"
emea = "europe-west4-b"
apac = "asia-northeast1-c"
}
}

View file

@ -0,0 +1,20 @@
resource "google_service_account" "proxy_service_account" {
account_id = "proxy-service-account"
display_name = "Nomulus proxy service account"
}
resource "google_project_iam_member" "gcr_storage_viewer" {
project = "${var.gcr_project_name}"
role = "roles/storage.objectViewer"
member = "serviceAccount:${google_service_account.proxy_service_account.email}"
}
resource "google_project_iam_member" "metric_writer" {
role = "roles/monitoring.metricWriter"
member = "serviceAccount:${google_service_account.proxy_service_account.email}"
}
resource "google_project_iam_member" "log_writer" {
role = "roles/logging.logWriter"
member = "serviceAccount:${google_service_account.proxy_service_account.email}"
}

View file

@ -0,0 +1,15 @@
resource "google_kms_key_ring" "proxy_key_ring" {
name = "${var.proxy_key_ring}"
location = "global"
}
resource "google_kms_crypto_key" "proxy_key" {
name = "${var.proxy_key}"
key_ring = "${google_kms_key_ring.proxy_key_ring.id}"
}
resource "google_kms_crypto_key_iam_member" "ssl_key_decrypter" {
crypto_key_id = "${google_kms_crypto_key.proxy_key.id}"
role = "roles/cloudkms.cryptoKeyDecrypter"
member = "serviceAccount:${google_service_account.proxy_service_account.email}"
}

View file

@ -0,0 +1,21 @@
resource "google_dns_managed_zone" "proxy_domain" {
name = "proxy-domain"
dns_name = "${var.proxy_domain_name}."
}
module "proxy_networking" {
source = "./networking"
proxy_instance_groups = "${local.proxy_instance_groups}"
proxy_ports = "${var.proxy_ports}"
proxy_domain = "${google_dns_managed_zone.proxy_domain.name}"
proxy_domain_name = "${google_dns_managed_zone.proxy_domain.dns_name}"
}
module "proxy_networking_canary" {
source = "./networking"
proxy_instance_groups = "${local.proxy_instance_groups}"
suffix = "-canary"
proxy_ports = "${var.proxy_ports_canary}"
proxy_domain = "${google_dns_managed_zone.proxy_domain.name}"
proxy_domain_name = "${google_dns_managed_zone.proxy_domain.dns_name}"
}

View file

@ -0,0 +1,31 @@
resource "google_dns_record_set" "proxy_epp_a_record" {
name = "epp${var.suffix}.${var.proxy_domain_name}"
type = "A"
ttl = 300
managed_zone = "${var.proxy_domain}"
rrdatas = ["${google_compute_global_address.proxy_ipv4_address.address}"]
}
resource "google_dns_record_set" "proxy_epp_aaaa_record" {
name = "epp${var.suffix}.${var.proxy_domain_name}"
type = "AAAA"
ttl = 300
managed_zone = "${var.proxy_domain}"
rrdatas = ["${google_compute_global_address.proxy_ipv6_address.address}"]
}
resource "google_dns_record_set" "proxy_whois_a_record" {
name = "whois${var.suffix}.${var.proxy_domain_name}"
type = "A"
ttl = 300
managed_zone = "${var.proxy_domain}"
rrdatas = ["${google_compute_global_address.proxy_ipv4_address.address}"]
}
resource "google_dns_record_set" "proxy_whois_aaaa_record" {
name = "whois${var.suffix}.${var.proxy_domain_name}"
type = "AAAA"
ttl = 300
managed_zone = "${var.proxy_domain}"
rrdatas = ["${google_compute_global_address.proxy_ipv6_address.address}"]
}

View file

@ -0,0 +1,230 @@
resource "google_compute_global_address" "proxy_ipv4_address" {
name = "proxy-ipv4-address${var.suffix}"
ip_version = "IPV4"
}
resource "google_compute_global_address" "proxy_ipv6_address" {
name = "proxy-ipv6-address${var.suffix}"
ip_version = "IPV6"
}
resource "google_compute_firewall" "proxy_firewall" {
name = "proxy-firewall${var.suffix}"
network = "default"
allow {
protocol = "tcp"
ports = [
"${var.proxy_ports["epp"]}",
"${var.proxy_ports["whois"]}",
"${var.proxy_ports["health_check"]}",
"${var.proxy_ports["http-whois"]}",
"${var.proxy_ports["https-whois"]}",
]
}
source_ranges = [
"130.211.0.0/22",
"35.191.0.0/16",
]
target_tags = [
"proxy-cluster",
]
}
resource "google_compute_health_check" "proxy_health_check" {
name = "proxy-health-check${var.suffix}"
tcp_health_check {
port = "${var.proxy_ports["health_check"]}"
request = "HEALTH_CHECK_REQUEST"
response = "HEALTH_CHECK_RESPONSE"
}
}
resource "google_compute_health_check" "proxy_http_health_check" {
name = "proxy-http-health-check${var.suffix}"
http_health_check {
host = "health-check.invalid"
port = "${var.proxy_ports["http-whois"]}"
request_path = "/"
}
}
resource "google_compute_url_map" "proxy_url_map" {
name = "proxy-url-map${var.suffix}"
default_service = "${google_compute_backend_service.http_whois_backend_service.self_link}"
}
resource "google_compute_backend_service" "epp_backend_service" {
name = "epp-backend-service${var.suffix}"
protocol = "TCP"
timeout_sec = 3600
port_name = "epp${var.suffix}"
backend {
group = "${var.proxy_instance_groups["americas"]}"
}
backend {
group = "${var.proxy_instance_groups["emea"]}"
}
backend {
group = "${var.proxy_instance_groups["apac"]}"
}
health_checks = [
"${google_compute_health_check.proxy_health_check.self_link}",
]
}
resource "google_compute_backend_service" "whois_backend_service" {
name = "whois-backend-service${var.suffix}"
protocol = "TCP"
timeout_sec = 60
port_name = "whois${var.suffix}"
backend {
group = "${var.proxy_instance_groups["americas"]}"
}
backend {
group = "${var.proxy_instance_groups["emea"]}"
}
backend {
group = "${var.proxy_instance_groups["apac"]}"
}
health_checks = [
"${google_compute_health_check.proxy_health_check.self_link}",
]
}
resource "google_compute_backend_service" "https_whois_backend_service" {
name = "https-whois-backend-service${var.suffix}"
protocol = "TCP"
timeout_sec = 60
port_name = "https-whois${var.suffix}"
backend {
group = "${var.proxy_instance_groups["americas"]}"
}
backend {
group = "${var.proxy_instance_groups["emea"]}"
}
backend {
group = "${var.proxy_instance_groups["apac"]}"
}
health_checks = [
"${google_compute_health_check.proxy_health_check.self_link}",
]
}
resource "google_compute_backend_service" "http_whois_backend_service" {
name = "http-whois-backend-service${var.suffix}"
protocol = "HTTP"
timeout_sec = 60
port_name = "http-whois${var.suffix}"
backend {
group = "${var.proxy_instance_groups["americas"]}"
}
backend {
group = "${var.proxy_instance_groups["emea"]}"
}
backend {
group = "${var.proxy_instance_groups["apac"]}"
}
health_checks = [
"${google_compute_health_check.proxy_http_health_check.self_link}",
]
}
resource "google_compute_target_tcp_proxy" "epp_tcp_proxy" {
name = "epp-tcp-proxy${var.suffix}"
proxy_header = "PROXY_V1"
backend_service = "${google_compute_backend_service.epp_backend_service.self_link}"
}
resource "google_compute_target_tcp_proxy" "whois_tcp_proxy" {
name = "whois-tcp-proxy${var.suffix}"
proxy_header = "PROXY_V1"
backend_service = "${google_compute_backend_service.whois_backend_service.self_link}"
}
resource "google_compute_target_tcp_proxy" "https_whois_tcp_proxy" {
name = "https-whois-tcp-proxy${var.suffix}"
backend_service = "${google_compute_backend_service.https_whois_backend_service.self_link}"
}
resource "google_compute_target_http_proxy" "http_whois_http_proxy" {
name = "http-whois-tcp-proxy${var.suffix}"
url_map = "${google_compute_url_map.proxy_url_map.self_link}"
}
resource "google_compute_global_forwarding_rule" "epp_ipv4_forwarding_rule" {
name = "epp-ipv4-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv4_address.address}"
target = "${google_compute_target_tcp_proxy.epp_tcp_proxy.self_link}"
port_range = "700"
}
resource "google_compute_global_forwarding_rule" "epp_ipv6_forwarding_rule" {
name = "epp-ipv6-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv6_address.address}"
target = "${google_compute_target_tcp_proxy.epp_tcp_proxy.self_link}"
port_range = "700"
}
resource "google_compute_global_forwarding_rule" "whois_ipv4_forwarding_rule" {
name = "whois-ipv4-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv4_address.address}"
target = "${google_compute_target_tcp_proxy.whois_tcp_proxy.self_link}"
port_range = "43"
}
resource "google_compute_global_forwarding_rule" "whois_ipv6_forwarding_rule" {
name = "whois-ipv6-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv6_address.address}"
target = "${google_compute_target_tcp_proxy.whois_tcp_proxy.self_link}"
port_range = "43"
}
resource "google_compute_global_forwarding_rule" "https_whois_ipv4_forwarding_rule" {
name = "https-whois-ipv4-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv4_address.address}"
target = "${google_compute_target_tcp_proxy.https_whois_tcp_proxy.self_link}"
port_range = "443"
}
resource "google_compute_global_forwarding_rule" "https_whois_ipv6_forwarding_rule" {
name = "https-whois-ipv6-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv6_address.address}"
target = "${google_compute_target_tcp_proxy.https_whois_tcp_proxy.self_link}"
port_range = "443"
}
resource "google_compute_global_forwarding_rule" "http_whois_ipv4_forwarding_rule" {
name = "http-whois-ipv4-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv4_address.address}"
target = "${google_compute_target_http_proxy.http_whois_http_proxy.self_link}"
port_range = "80"
}
resource "google_compute_global_forwarding_rule" "http_whois_ipv6_forwarding_rule" {
name = "http-whois-ipv6-forwarding-rule${var.suffix}"
ip_address = "${google_compute_global_address.proxy_ipv6_address.address}"
target = "${google_compute_target_http_proxy.http_whois_http_proxy.self_link}"
port_range = "80"
}

View file

@ -0,0 +1,7 @@
output "proxy_ipv4_address" {
value = "${google_compute_global_address.proxy_ipv4_address.address}"
}
output "proxy_ipv6_address" {
value = "${google_compute_global_address.proxy_ipv6_address.address}"
}

View file

@ -0,0 +1,20 @@
# Instance groups that the load balancer forwards traffic to.
variable "proxy_instance_groups" {
type = "map"
}
# Suffix (such as "-canary") added to the resource names.
variable "suffix" {
default = ""
}
# Node ports exposed by the proxy.
variable "proxy_ports" {
type = "map"
}
# DNS zone for the proxy domain.
variable "proxy_domain" {}
# domain name of the zone.
variable "proxy_domain_name" {}

View file

@ -0,0 +1,23 @@
output "proxy_name_servers" {
value = "${google_dns_managed_zone.proxy_domain.name_servers}"
}
output "proxy_instance_groups" {
value = "${local.proxy_instance_groups}"
}
output "proxy_service_account" {
value = {
email = "${google_service_account.proxy_service_account.email}"
client_id = "${google_service_account.proxy_service_account.unique_id}"
}
}
output "proxy_ip_addresses" {
value = {
ipv4 = "${module.proxy_networking.proxy_ipv4_address}"
ipv6 = "${module.proxy_networking.proxy_ipv6_address}"
ipv4_canary = "${module.proxy_networking_canary.proxy_ipv4_address}"
ipv6_canary = "${module.proxy_networking_canary.proxy_ipv6_address}"
}
}

View file

@ -0,0 +1,47 @@
# GCP project in which the proxy runs.
variable "proxy_project_name" {}
# GCP project from which the proxy image is pulled.
variable "gcr_project_name" {}
# The base domain name of the proxy, without the whois. or epp. part.
variable "proxy_domain_name" {}
# The GCS bucket that stores the encrypted SSL certificate.
variable "proxy_certificate_bucket" {}
# Cloud KMS keyring name
variable "proxy_key_ring" {
default = "proxy-key-ring"
}
# Cloud KMS key name
variable "proxy_key" {
default = "proxy-key"
}
# Node ports exposed by the proxy.
variable "proxy_ports" {
type = "map"
default = {
health_check = 30000
whois = 30001
epp = 30002
http-whois = 30010
https-whois = 30011
}
}
# Node ports exposed by the canary proxy.
variable "proxy_ports_canary" {
type = "map"
default = {
health_check = 31000
whois = 31001
epp = 31002
http-whois = 31010
https-whois = 31011
}
}

View file

@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2018 The Nomulus Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Terraform currently cannot set named ports on the instance groups underlying
# the gke instances it creates. Here we output the instance group URL, extract
# the project, zone and instance group names, and then call gcloud to add the
# named ports.
PROD_PORTS="whois:30001,epp:30002,http-whois:30010,https-whois:30011"
CANARY_PORTS="whois-canary:31001,epp-canary:31002,"\
"http-whois-canary:31010,https-whois-canary:31011"
while read line
do
gcloud compute instance-groups set-named-ports --named-ports \
"${PROD_PORTS}","${CANARY_PORTS}" "$line"
done < <(terraform output proxy_instance_groups | awk '{print $3}' | \
awk -F '/' '{print "--project", $7, "--zone", $9, $11}')

View file

@ -0,0 +1,53 @@
package(
default_testonly = 1,
default_visibility = ["//java/google/registry:registry_project"],
)
licenses(["notice"]) # Apache 2.0
load("//java/com/google/testing/builddefs:GenTestRules.bzl", "GenTestRules")
java_library(
name = "proxy",
srcs = glob(["**/*.java"]),
resources = glob([
"testdata/*.xml",
"quota/testdata/*.yaml",
]),
runtime_deps = [
"@io_netty_tcnative_boringssl_static",
],
deps = [
"//java/google/registry/proxy",
"//java/google/registry/util",
"//javatests/google/registry/testing",
"@com_beust_jcommander",
"@com_google_dagger",
"@com_google_guava",
"@com_google_monitoring_client_contrib",
"@com_google_monitoring_client_metrics",
"@com_google_truth",
"@com_google_truth_extensions_truth_java8_extension",
"@io_netty_buffer",
"@io_netty_codec",
"@io_netty_codec_http",
"@io_netty_common",
"@io_netty_handler",
"@io_netty_transport",
"@javax_inject",
"@joda_time",
"@junit",
"@org_bouncycastle_bcpkix_jdk15on",
"@org_mockito_core",
"@org_yaml_snakeyaml",
],
)
GenTestRules(
name = "GeneratedTestRules",
test_files = glob(
["**/*Test.java"],
exclude = ["ProtocolModuleTest.java"],
),
deps = [":proxy"],
)

View file

@ -0,0 +1,161 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.handler.SslInitializerTestUtils.getKeyPair;
import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair;
import static google.registry.testing.JUnitBackports.assertThrows;
import static java.nio.charset.StandardCharsets.UTF_8;
import dagger.Component;
import dagger.Module;
import dagger.Provides;
import google.registry.proxy.CertificateModule.Prod;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.ByteArrayOutputStream;
import java.io.OutputStreamWriter;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import javax.inject.Named;
import javax.inject.Singleton;
import org.bouncycastle.openssl.PEMWriter;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link CertificateModule}. */
@RunWith(JUnit4.class)
public class CertificateModuleTest {
private SelfSignedCertificate ssc;
private PrivateKey key;
private Certificate cert;
private TestComponent component;
private static byte[] getPemBytes(Object... objects) throws Exception {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
try (PEMWriter pemWriter =
new PEMWriter(new OutputStreamWriter(byteArrayOutputStream, UTF_8))) {
for (Object object : objects) {
pemWriter.writeObject(object);
}
}
return byteArrayOutputStream.toByteArray();
}
/** Create a component with bindings to the given bytes[] as the contents from a PEM file. */
private TestComponent createComponent(byte[] pemBytes) {
return DaggerCertificateModuleTest_TestComponent.builder()
.pemBytesModule(new PemBytesModule(pemBytes))
.build();
}
@Before
public void setUp() throws Exception {
ssc = new SelfSignedCertificate();
KeyPair keyPair = getKeyPair();
key = keyPair.getPrivate();
cert = signKeyPair(ssc, keyPair, "example.tld");
}
@Test
public void testSuccess() throws Exception {
byte[] pemBytes = getPemBytes(cert, ssc.cert(), key);
component = createComponent(pemBytes);
assertThat(component.privateKey()).isEqualTo(key);
assertThat(component.certificates()).asList().containsExactly(cert, ssc.cert()).inOrder();
}
@Test
public void testSuccess_certificateChainNotContinuous() throws Exception {
byte[] pemBytes = getPemBytes(cert, key, ssc.cert());
component = createComponent(pemBytes);
assertThat(component.privateKey()).isEqualTo(key);
assertThat(component.certificates()).asList().containsExactly(cert, ssc.cert()).inOrder();
}
@Test
public void testFailure_noPrivateKey() throws Exception {
byte[] pemBytes = getPemBytes(cert, ssc.cert());
component = createComponent(pemBytes);
IllegalStateException thrown =
assertThrows(IllegalStateException.class, () -> component.privateKey());
assertThat(thrown).hasMessageThat().contains("0 keys are found");
}
@Test
public void testFailure_twoPrivateKeys() throws Exception {
byte[] pemBytes = getPemBytes(cert, ssc.cert(), key, ssc.key());
component = createComponent(pemBytes);
IllegalStateException thrown =
assertThrows(IllegalStateException.class, () -> component.privateKey());
assertThat(thrown).hasMessageThat().contains("2 keys are found");
}
@Test
public void testFailure_certificatesOutOfOrder() throws Exception {
byte[] pemBytes = getPemBytes(ssc.cert(), cert, key);
component = createComponent(pemBytes);
IllegalStateException thrown =
assertThrows(IllegalStateException.class, () -> component.certificates());
assertThat(thrown).hasMessageThat().contains("is not signed by");
}
@Test
public void testFailure_noCertificates() throws Exception {
byte[] pemBytes = getPemBytes(key);
component = createComponent(pemBytes);
IllegalStateException thrown =
assertThrows(IllegalStateException.class, () -> component.certificates());
assertThat(thrown).hasMessageThat().contains("No certificates");
}
@Module
static class PemBytesModule {
private final byte[] pemBytes;
PemBytesModule(byte[] pemBytes) {
this.pemBytes = pemBytes;
}
@Provides
@Named("pemBytes")
byte[] providePemBytes() {
return pemBytes;
}
}
/**
* Test component that exposes prod certificate and key.
*
* <p>Local certificate and key are not tested because they are directly extracted from a
* self-signed certificate. Here we want to test that we can correctly parse and create
* certificate and keys from a .pem file.
*/
@Singleton
@Component(modules = {CertificateModule.class, PemBytesModule.class})
interface TestComponent {
@Prod
PrivateKey privateKey();
@Prod
X509Certificate[] certificates();
}
}

View file

@ -0,0 +1,280 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY;
import static google.registry.proxy.handler.SslServerInitializer.CLIENT_CERTIFICATE_PROMISE_KEY;
import static google.registry.testing.JUnitBackports.assertThrows;
import static google.registry.util.ResourceUtils.readResourceBytes;
import static google.registry.util.X509Utils.getCertificateHash;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Throwables;
import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException;
import google.registry.testing.FakeClock;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.DefaultCookie;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.Promise;
import java.security.cert.X509Certificate;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** End-to-end tests for {@link EppProtocolModule}. */
@RunWith(JUnit4.class)
public class EppProtocolModuleTest extends ProtocolModuleTest {
private static final int HEADER_LENGTH = 4;
private static final String CLIENT_ADDRESS = "epp.client.tld";
private static final byte[] HELLO_BYTES =
("<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n"
+ "<epp xmlns=\"urn:ietf:params:xml:ns:epp-1.0\">\n"
+ " <hello/>\n"
+ "</epp>\n")
.getBytes(UTF_8);
private X509Certificate certificate;
public EppProtocolModuleTest() {
super(TestComponent::eppHandlers);
}
/** Verifies that the epp message content is represented by the buffers. */
private static void assertBufferRepresentsContent(ByteBuf buffer, byte[] expectedContents) {
// First make sure that buffer length is expected content length plus header length.
assertThat(buffer.readableBytes()).isEqualTo(expectedContents.length + HEADER_LENGTH);
// Then check if the header value is indeed expected content length plus header length.
assertThat(buffer.readInt()).isEqualTo(expectedContents.length + HEADER_LENGTH);
// Finally check the buffer contains the expected contents.
byte[] actualContents = new byte[expectedContents.length];
buffer.readBytes(actualContents);
assertThat(actualContents).isEqualTo(expectedContents);
}
/**
* Read all available outbound frames and make a composite {@link ByteBuf} consisting all of them.
*
* <p>This is needed because {@link io.netty.handler.codec.LengthFieldPrepender} does not
* necessary output only one {@link ByteBuf} from one input message. We need to reassemble the
* frames together in order to obtain the processed message (prepended with length header).
*/
private static ByteBuf getAllOutboundFrames(EmbeddedChannel channel) {
ByteBuf combinedBuffer = Unpooled.buffer();
ByteBuf buffer;
while ((buffer = channel.readOutbound()) != null) {
combinedBuffer.writeBytes(buffer);
}
return combinedBuffer;
}
/** Get a {@link ByteBuf} that represents the raw epp request with the given content. */
private ByteBuf getByteBufFromContent(byte[] content) {
ByteBuf buffer = Unpooled.buffer();
buffer.writeInt(content.length + HEADER_LENGTH);
buffer.writeBytes(content);
return buffer;
}
private FullHttpRequest makeEppHttpRequest(byte[] content, Cookie... cookies) {
return TestUtils.makeEppHttpRequest(
new String(content, UTF_8),
PROXY_CONFIG.epp.relayHost,
PROXY_CONFIG.epp.relayPath,
TestModule.provideFakeAccessToken().get(),
getCertificateHash(certificate),
CLIENT_ADDRESS,
cookies);
}
private FullHttpResponse makeEppHttpResponse(byte[] content, Cookie... cookies) {
return makeEppHttpResponse(content, HttpResponseStatus.OK, cookies);
}
private FullHttpResponse makeEppHttpResponse(
byte[] content, HttpResponseStatus status, Cookie... cookies) {
return TestUtils.makeEppHttpResponse(new String(content, UTF_8), status, cookies);
}
@Override
@Before
public void setUp() throws Exception {
testComponent = makeTestComponent(new FakeClock());
certificate = new SelfSignedCertificate().cert();
initializeChannel(
ch -> {
ch.attr(REMOTE_ADDRESS_KEY).set(CLIENT_ADDRESS);
ch.attr(CLIENT_CERTIFICATE_PROMISE_KEY).set(ch.eventLoop().newPromise());
addAllTestableHandlers(ch);
});
Promise<X509Certificate> unusedPromise =
channel.attr(CLIENT_CERTIFICATE_PROMISE_KEY).get().setSuccess(certificate);
}
@Test
public void testSuccess_singleFrameInboundMessage() throws Exception {
// First inbound message is hello.
assertThat((FullHttpRequest) channel.readInbound()).isEqualTo(makeEppHttpRequest(HELLO_BYTES));
byte[] inputBytes = readResourceBytes(getClass(), "testdata/login.xml").read();
// Verify inbound message is as expected.
assertThat(channel.writeInbound(getByteBufFromContent(inputBytes))).isTrue();
assertThat((FullHttpRequest) channel.readInbound()).isEqualTo(makeEppHttpRequest(inputBytes));
// Nothing more to read.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_SingleFrame_MultipleInboundMessages() throws Exception {
// First inbound message is hello.
channel.readInbound();
byte[] inputBytes1 = readResourceBytes(getClass(), "testdata/login.xml").read();
byte[] inputBytes2 = readResourceBytes(getClass(), "testdata/logout.xml").read();
// Verify inbound messages are as expected.
assertThat(
channel.writeInbound(
Unpooled.wrappedBuffer(
getByteBufFromContent(inputBytes1), getByteBufFromContent(inputBytes2))))
.isTrue();
assertThat((FullHttpRequest) channel.readInbound()).isEqualTo(makeEppHttpRequest(inputBytes1));
assertThat((FullHttpRequest) channel.readInbound()).isEqualTo(makeEppHttpRequest(inputBytes2));
// Nothing more to read.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_MultipleFrames_MultipleInboundMessages() throws Exception {
// First inbound message is hello.
channel.readInbound();
byte[] inputBytes1 = readResourceBytes(getClass(), "testdata/login.xml").read();
byte[] inputBytes2 = readResourceBytes(getClass(), "testdata/logout.xml").read();
ByteBuf inputBuffer =
Unpooled.wrappedBuffer(
getByteBufFromContent(inputBytes1), getByteBufFromContent(inputBytes2));
// The first frame does not contain the entire first message because it is missing 4 byte of
// header length.
assertThat(channel.writeInbound(inputBuffer.readBytes(inputBytes1.length))).isFalse();
// The second frame contains the first message, and part of the second message.
assertThat(channel.writeInbound(inputBuffer.readBytes(inputBytes2.length))).isTrue();
assertThat((FullHttpRequest) channel.readInbound()).isEqualTo(makeEppHttpRequest(inputBytes1));
// The third frame contains the rest of the second message.
assertThat(channel.writeInbound(inputBuffer)).isTrue();
assertThat((FullHttpRequest) channel.readInbound()).isEqualTo(makeEppHttpRequest(inputBytes2));
// Nothing more to read.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_simpleOutboundMessage() throws Exception {
// First inbound message is hello.
channel.readInbound();
byte[] outputBytes = readResourceBytes(getClass(), "testdata/login_response.xml").read();
// Verify outbound message is as expected.
assertThat(channel.writeOutbound(makeEppHttpResponse(outputBytes))).isTrue();
assertBufferRepresentsContent(getAllOutboundFrames(channel), outputBytes);
// Nothing more to write.
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testFailure_nonOkOutboundMessage() throws Exception {
// First inbound message is hello.
channel.readInbound();
byte[] outputBytes = readResourceBytes(getClass(), "testdata/login_response.xml").read();
// Verify outbound message is not written to the peer as the response is not OK.
EncoderException thrown =
assertThrows(
EncoderException.class,
() ->
channel.writeOutbound(
makeEppHttpResponse(outputBytes, HttpResponseStatus.UNAUTHORIZED)));
assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class);
assertThat(thrown).hasMessageThat().contains("401 Unauthorized");
assertThat((Object) channel.readOutbound()).isNull();
// Channel is closed.
assertThat(channel.isActive()).isFalse();
}
@Test
public void testSuccess_setAndReadCookies() throws Exception {
// First inbound message is hello.
channel.readInbound();
byte[] outputBytes1 = readResourceBytes(getClass(), "testdata/login_response.xml").read();
Cookie cookie1 = new DefaultCookie("name1", "value1");
Cookie cookie2 = new DefaultCookie("name2", "value2");
// Verify outbound message is as expected.
assertThat(channel.writeOutbound(makeEppHttpResponse(outputBytes1, cookie1, cookie2))).isTrue();
assertBufferRepresentsContent(getAllOutboundFrames(channel), outputBytes1);
// Verify inbound message contains cookies.
byte[] inputBytes1 = readResourceBytes(getClass(), "testdata/logout.xml").read();
assertThat(channel.writeInbound(getByteBufFromContent(inputBytes1))).isTrue();
assertThat((FullHttpRequest) channel.readInbound())
.isEqualTo(makeEppHttpRequest(inputBytes1, cookie1, cookie2));
// Second outbound message change cookies.
byte[] outputBytes2 = readResourceBytes(getClass(), "testdata/logout_response.xml").read();
Cookie cookie3 = new DefaultCookie("name3", "value3");
cookie2 = new DefaultCookie("name2", "newValue2");
// Verify outbound message is as expected.
assertThat(channel.writeOutbound(makeEppHttpResponse(outputBytes2, cookie2, cookie3))).isTrue();
assertBufferRepresentsContent(getAllOutboundFrames(channel), outputBytes2);
// Verify inbound message contains updated cookies.
byte[] inputBytes2 = readResourceBytes(getClass(), "testdata/login.xml").read();
assertThat(channel.writeInbound(getByteBufFromContent(inputBytes2))).isTrue();
assertThat((FullHttpRequest) channel.readInbound())
.isEqualTo(makeEppHttpRequest(inputBytes2, cookie1, cookie2, cookie3));
// Nothing more to write or read.
assertThat((Object) channel.readOutbound()).isNull();
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
}

View file

@ -0,0 +1,105 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import com.google.common.base.Joiner;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link google.registry.proxy.GcpJsonFormatter}. */
@RunWith(JUnit4.class)
public class GcpJsonFormatterTest {
private static final String LOGGER_NAME = "example.company.app.logger";
private static final String SOURCE_CLASS_NAME = "example.company.app.component.Doer";
private static final String SOURCE_METHOD_NAME = "doStuff";
private static final String MESSAGE = "Something I have to say";
private final GcpJsonFormatter formatter = new GcpJsonFormatter();
private final LogRecord logRecord = new LogRecord(Level.WARNING, MESSAGE);
private static String makeJson(String severity, String source, String message) {
return "{"
+ Joiner.on(",")
.join(
makeJsonField("severity", severity),
makeJsonField("source", source),
makeJsonField("message", "\\n" + message))
+ "}\n";
}
private static String makeJsonField(String name, String content) {
return Joiner.on(":").join(addQuoteAndReplaceNewline(name), addQuoteAndReplaceNewline(content));
}
private static String addQuoteAndReplaceNewline(String content) {
// This quadruple escaping is hurting my eyes.
return "\"" + content.replaceAll("\n", "\\\\n") + "\"";
}
@Before
public void setUp() {
logRecord.setLoggerName(LOGGER_NAME);
}
@Test
public void testSuccess() {
String actual = formatter.format(logRecord);
String expected = makeJson("WARNING", LOGGER_NAME, MESSAGE);
assertThat(actual).isEqualTo(expected);
}
@Test
public void testSuccess_sourceClassAndMethod() {
logRecord.setSourceClassName(SOURCE_CLASS_NAME);
logRecord.setSourceMethodName(SOURCE_METHOD_NAME);
String actual = formatter.format(logRecord);
String expected = makeJson("WARNING", SOURCE_CLASS_NAME + " " + SOURCE_METHOD_NAME, MESSAGE);
assertThat(actual).isEqualTo(expected);
}
@Test
public void testSuccess_multilineMessage() {
String multilineMessage = "First line message\nSecond line message\n";
logRecord.setMessage(multilineMessage);
String actual = formatter.format(logRecord);
String expected = makeJson("WARNING", LOGGER_NAME, multilineMessage);
assertThat(actual).isEqualTo(expected);
}
@Test
public void testSuccess_withCause() {
Throwable throwable = new Throwable("Some reason");
StackTraceElement[] stacktrace = {
new StackTraceElement("class1", "method1", "file1", 5),
new StackTraceElement("class2", "method2", "file2", 10),
};
String stacktraceString =
"java.lang.Throwable: Some reason\\n"
+ "\\tat class1.method1(file1:5)\\n"
+ "\\tat class2.method2(file2:10)\\n";
throwable.setStackTrace(stacktrace);
logRecord.setThrown(throwable);
String actual = formatter.format(logRecord);
String expected = makeJson("WARNING", LOGGER_NAME, MESSAGE + "\\n" + stacktraceString);
assertThat(actual).isEqualTo(expected);
}
}

View file

@ -0,0 +1,88 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.US_ASCII;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** End-to-end tests for {@link HealthCheckProtocolModule}. */
@RunWith(JUnit4.class)
public class HealthCheckProtocolModuleTest extends ProtocolModuleTest {
public HealthCheckProtocolModuleTest() {
super(TestComponent::healthCheckHandlers);
}
@Test
public void testSuccess_expectedInboundMessage() {
// no inbound message passed along.
assertThat(
channel.writeInbound(
Unpooled.wrappedBuffer(PROXY_CONFIG.healthCheck.checkRequest.getBytes(US_ASCII))))
.isFalse();
ByteBuf outputBuffer = channel.readOutbound();
// response written to channel.
assertThat(outputBuffer.toString(US_ASCII)).isEqualTo(PROXY_CONFIG.healthCheck.checkResponse);
assertThat(channel.isActive()).isTrue();
// nothing more to write.
assertThat((Object) channel.readOutbound()).isNull();
}
@Test
public void testSuccess_InboundMessageTooShort() {
String shortRequest = "HEALTH_CHECK";
// no inbound message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(shortRequest.getBytes(US_ASCII))))
.isFalse();
// nothing to write.
assertThat(channel.isActive()).isTrue();
assertThat((Object) channel.readOutbound()).isNull();
}
@Test
public void testSuccess_InboundMessageTooLong() {
String longRequest = "HEALTH_CHECK_REQUEST HELLO";
// no inbound message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(longRequest.getBytes(US_ASCII))))
.isFalse();
ByteBuf outputBuffer = channel.readOutbound();
// The fixed length frame decoder will decode the first inbound message as "HEALTH_CHECK_
// REQUEST", which is what this handler expects. So it will respond with the pre-defined
// response message. This is an acceptable false-positive because the GCP health checker will
// only send the pre-defined request message. As long as the health check can receive the
// request it expects, we do not care if the protocol also respond to other requests.
assertThat(outputBuffer.toString(US_ASCII)).isEqualTo(PROXY_CONFIG.healthCheck.checkResponse);
assertThat(channel.isActive()).isTrue();
// nothing more to write.
assertThat((Object) channel.readOutbound()).isNull();
}
@Test
public void testSuccess_InboundMessageNotMatch() {
String invalidRequest = "HEALTH_CHECK_REQUESX";
// no inbound message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(invalidRequest.getBytes(US_ASCII))))
.isFalse();
// nothing to write.
assertThat(channel.isActive()).isTrue();
assertThat((Object) channel.readOutbound()).isNull();
}
}

View file

@ -0,0 +1,103 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.TestUtils.assertHttpRequestEquivalent;
import static google.registry.proxy.TestUtils.assertHttpResponseEquivalent;
import static google.registry.proxy.TestUtils.makeHttpPostRequest;
import static google.registry.proxy.TestUtils.makeHttpResponse;
import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerCodec;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
* End-to-end tests for {@link HttpsRelayProtocolModule}.
*
* <p>This protocol defines a connection in which the proxy behaves as a standard http client (sans
* the relay operation which is excluded in end-to-end testing). Because non user-defined handlers
* are used, the tests here focus on verifying that the request written to the network socket by the
* client is reconstructed faithfully by a server, and vice versa, that the response the client
* decoded from incoming bytes is equivalent to the response sent by the server.
*
* <p>These tests only ensure that the client represented by this protocol is compatible with a
* server implementation provided by Netty itself. They test the self-consistency of various Netty
* handlers that deal with HTTP protocol, but not whether the handlers converts between bytes and
* HTTP messages correctly, which is presumed correct.
*/
@RunWith(JUnit4.class)
public class HttpsRelayProtocolModuleTest extends ProtocolModuleTest {
private static final String HOST = "test.tld";
private static final String PATH = "/path/to/test";
private static final String CONTENT = "content to test\nnext line\n";
private final EmbeddedChannel serverChannel =
new EmbeddedChannel(new HttpServerCodec(), new HttpObjectAggregator(512 * 1024));
public HttpsRelayProtocolModuleTest() {
super(TestComponent::httpsRelayHandlers);
}
/**
* Tests that the client converts given {@link FullHttpRequest} to bytes, which is sent to the
* server and reconstructed to a {@link FullHttpRequest} that is equivalent to the original. Then
* test that the server converts given {@link FullHttpResponse} to bytes, which is sent to the
* client and reconstructed to a {@link FullHttpResponse} that is equivalent to the original.
*
* <p>The request and response equivalences are tested in the same method because the client codec
* tries to pair the response it receives with the request it sends. Receiving a response without
* sending a request first will cause the {@link HttpObjectAggregator} to fail to aggregate
* properly.
*/
private void requestAndRespondWithStatus(HttpResponseStatus status) {
ByteBuf buffer;
FullHttpRequest requestSent = makeHttpPostRequest(CONTENT, HOST, PATH);
// Need to send a copy as the content read index will advance after the request is written to
// the outbound of client channel, making comparison with requestReceived fail.
assertThat(channel.writeOutbound(requestSent.copy())).isTrue();
buffer = channel.readOutbound();
assertThat(serverChannel.writeInbound(buffer)).isTrue();
FullHttpRequest requestReceived = serverChannel.readInbound();
// Verify that the request received is the same as the request sent.
assertHttpRequestEquivalent(requestSent, requestReceived);
FullHttpResponse responseSent = makeHttpResponse(CONTENT, status);
assertThat(serverChannel.writeOutbound(responseSent.copy())).isTrue();
buffer = serverChannel.readOutbound();
assertThat(channel.writeInbound(buffer)).isTrue();
FullHttpResponse responseReceived = channel.readInbound();
// Verify that the request received is the same as the request sent.
assertHttpResponseEquivalent(responseSent, responseReceived);
}
@Test
public void testSuccess_OkResponse() {
requestAndRespondWithStatus(HttpResponseStatus.OK);
}
@Test
public void testSuccess_NonOkResponse() {
requestAndRespondWithStatus(HttpResponseStatus.BAD_REQUEST);
}
}

View file

@ -0,0 +1,291 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static google.registry.proxy.ProxyConfig.Environment.LOCAL;
import static google.registry.proxy.ProxyConfig.getProxyConfig;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
import dagger.Component;
import dagger.Module;
import dagger.Provides;
import google.registry.proxy.EppProtocolModule.EppProtocol;
import google.registry.proxy.HealthCheckProtocolModule.HealthCheckProtocol;
import google.registry.proxy.HttpsRelayProtocolModule.HttpsRelayProtocol;
import google.registry.proxy.ProxyConfig.Environment;
import google.registry.proxy.WebWhoisProtocolsModule.HttpWhoisProtocol;
import google.registry.proxy.WhoisProtocolModule.WhoisProtocol;
import google.registry.proxy.handler.BackendMetricsHandler;
import google.registry.proxy.handler.ProxyProtocolHandler;
import google.registry.proxy.handler.QuotaHandler.EppQuotaHandler;
import google.registry.proxy.handler.QuotaHandler.WhoisQuotaHandler;
import google.registry.proxy.handler.RelayHandler.FullHttpRequestRelayHandler;
import google.registry.proxy.handler.RelayHandler.FullHttpResponseRelayHandler;
import google.registry.proxy.handler.SslClientInitializer;
import google.registry.proxy.handler.SslServerInitializer;
import google.registry.proxy.handler.WebWhoisRedirectHandler;
import google.registry.testing.FakeClock;
import google.registry.util.Clock;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.timeout.ReadTimeoutHandler;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.inject.Named;
import javax.inject.Provider;
import javax.inject.Singleton;
import org.junit.Before;
/**
* Base class for end-to-end tests of a {@link Protocol}.
*
* <p>The end-to-end tests ensures that the business logic that a {@link Protocol} defines are
* correctly performed by various handlers attached to its pipeline. Non-business essential handlers
* should be excluded.
*
* <p>Subclass should implement an no-arg constructor that calls constructors of this class,
* providing the method reference of the {@link TestComponent} method to call to obtain the list of
* {@link ChannelHandler} providers for the {@link Protocol} to test, and optionally a set of {@link
* ChannelHandler} classes to exclude from testing.
*/
public abstract class ProtocolModuleTest {
protected static final ProxyConfig PROXY_CONFIG = getProxyConfig(LOCAL);
protected TestComponent testComponent;
/**
* Default list of handler classes that are not of interest in end-to-end testing of the {@link
* Protocol}.
*/
private static final ImmutableSet<Class<? extends ChannelHandler>> DEFAULT_EXCLUDED_HANDLERS =
ImmutableSet.of(
// The PROXY protocol is only used when the proxy is behind the GCP load balancer. It is
// not part of any business logic.
ProxyProtocolHandler.class,
// SSL is part of the business logic for some protocol (EPP for example), but its
// impact is isolated. Including it makes tests much more complicated. It should be tested
// separately in its own unit tests.
SslClientInitializer.class,
SslServerInitializer.class,
// These two handlers provide essential functionalities for the proxy to operate, but they
// do not directly implement the business logic of a well-defined protocol. They should be
// tested separately in their respective unit tests.
FullHttpRequestRelayHandler.class,
FullHttpResponseRelayHandler.class,
// This handler is tested in its own unit tests. It is installed in web whois redirect
// protocols. The end-to-end tests for the rest of the handlers in its pipeline need to
// be able to emit incoming requests out of the channel for assertions. Therefore this
// handler is removed from the pipeline.
WebWhoisRedirectHandler.class,
// The rest are not part of business logic and do not need to be tested, obviously.
LoggingHandler.class,
// Metrics instrumentation is tested separately.
BackendMetricsHandler.class,
// Quota management is tested separately.
WhoisQuotaHandler.class,
EppQuotaHandler.class,
ReadTimeoutHandler.class);
protected EmbeddedChannel channel;
/**
* Method reference to the component method that exposes the list of handler providers for the
* specific {@link Protocol} in interest.
*/
protected final Function<TestComponent, ImmutableList<Provider<? extends ChannelHandler>>>
handlerProvidersMethod;
protected final ImmutableSet<Class<? extends ChannelHandler>> excludedHandlers;
protected ProtocolModuleTest(
Function<TestComponent, ImmutableList<Provider<? extends ChannelHandler>>>
handlerProvidersMethod,
ImmutableSet<Class<? extends ChannelHandler>> excludedHandlers) {
this.handlerProvidersMethod = handlerProvidersMethod;
this.excludedHandlers = excludedHandlers;
}
protected ProtocolModuleTest(
Function<TestComponent, ImmutableList<Provider<? extends ChannelHandler>>>
handlerProvidersMethod) {
this(handlerProvidersMethod, DEFAULT_EXCLUDED_HANDLERS);
}
/** Excludes handler providers that are not of interested for testing. */
private ImmutableList<Provider<? extends ChannelHandler>> excludeHandlerProvidersForTesting(
ImmutableList<Provider<? extends ChannelHandler>> handlerProviders) {
return handlerProviders
.stream()
.filter(handlerProvider -> !excludedHandlers.contains(handlerProvider.get().getClass()))
.collect(toImmutableList());
}
protected void initializeChannel(Consumer<Channel> initializer) {
channel =
new EmbeddedChannel(
new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
initializer.accept(ch);
}
});
}
/** Adds handlers to the channel pipeline, excluding any one in {@link #excludedHandlers}. */
void addAllTestableHandlers(Channel ch) {
for (Provider<? extends ChannelHandler> handlerProvider :
excludeHandlerProvidersForTesting(handlerProvidersMethod.apply(testComponent))) {
ch.pipeline().addLast(handlerProvider.get());
}
}
static TestComponent makeTestComponent(FakeClock fakeClock) {
return DaggerProtocolModuleTest_TestComponent.builder()
.testModule(new TestModule(new FakeClock()))
.build();
}
@Before
public void setUp() throws Exception {
testComponent = makeTestComponent(new FakeClock());
initializeChannel(this::addAllTestableHandlers);
}
/**
* Component used to obtain the list of {@link ChannelHandler} providers for each {@link
* Protocol}.
*/
@Singleton
@Component(
modules = {
TestModule.class,
CertificateModule.class,
WhoisProtocolModule.class,
WebWhoisProtocolsModule.class,
EppProtocolModule.class,
HealthCheckProtocolModule.class,
HttpsRelayProtocolModule.class
})
interface TestComponent {
@WhoisProtocol
ImmutableList<Provider<? extends ChannelHandler>> whoisHandlers();
@EppProtocol
ImmutableList<Provider<? extends ChannelHandler>> eppHandlers();
@HealthCheckProtocol
ImmutableList<Provider<? extends ChannelHandler>> healthCheckHandlers();
@HttpsRelayProtocol
ImmutableList<Provider<? extends ChannelHandler>> httpsRelayHandlers();
@HttpWhoisProtocol
ImmutableList<Provider<? extends ChannelHandler>> httpWhoisHandlers();
}
/**
* Module that provides bindings used in tests.
*
* <p>Most of the binding provided in this module should be either a fake, or a {@link
* ChannelHandler} that is excluded, and annotated with {@code @Singleton}. This module acts as a
* replacement for {@link ProxyModule} used in production component. Providing a handler that is
* part of the business logic of a {@link Protocol} from this module is a sign that the binding
* should be provided in the respective {@code ProtocolModule} instead.
*/
@Module
static class TestModule {
/**
* A fake clock that is explicitly provided. Users can construct a module with a controller
* clock.
*/
private final FakeClock fakeClock;
TestModule(FakeClock fakeClock) {
this.fakeClock = fakeClock;
}
@Singleton
@Provides
static ProxyConfig provideProxyConfig() {
return getProxyConfig(LOCAL);
}
@Singleton
@Provides
static SslProvider provideSslProvider() {
return SslProvider.JDK;
}
@Singleton
@Provides
@Named("accessToken")
static Supplier<String> provideFakeAccessToken() {
return Suppliers.ofInstance("fake.test.token");
}
@Singleton
@Provides
static LoggingHandler provideLoggingHandler() {
return new LoggingHandler();
}
@Singleton
@Provides
Clock provideFakeClock() {
return fakeClock;
}
@Singleton
@Provides
static ExecutorService provideExecutorService() {
return MoreExecutors.newDirectExecutorService();
}
@Singleton
@Provides
static ScheduledExecutorService provideScheduledExecutorService() {
return Executors.newSingleThreadScheduledExecutor();
}
@Singleton
@Provides
static Environment provideEnvironment() {
return Environment.LOCAL;
}
// This method is only here to satisfy Dagger binding, but is never used. In test environment,
// it is the self-signed certificate and its key that end up being used.
@Singleton
@Provides
@Named("pemBytes")
static byte[] providePemBytes() {
return new byte[0];
}
}
}

View file

@ -0,0 +1,133 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.ProxyConfig.Environment.LOCAL;
import static google.registry.proxy.ProxyConfig.getProxyConfig;
import static google.registry.testing.JUnitBackports.assertThrows;
import com.beust.jcommander.ParameterException;
import google.registry.proxy.ProxyConfig.Environment;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link ProxyModule}. */
@RunWith(JUnit4.class)
public class ProxyModuleTest {
private static final ProxyConfig PROXY_CONFIG = getProxyConfig(LOCAL);
private final ProxyModule proxyModule = new ProxyModule();
@Test
public void testSuccess_parseArgs_defaultArgs() {
String[] args = {};
proxyModule.parse(args);
assertThat(proxyModule.provideWhoisPort(PROXY_CONFIG)).isEqualTo(PROXY_CONFIG.whois.port);
assertThat(proxyModule.provideEppPort(PROXY_CONFIG)).isEqualTo(PROXY_CONFIG.epp.port);
assertThat(proxyModule.provideHealthCheckPort(PROXY_CONFIG))
.isEqualTo(PROXY_CONFIG.healthCheck.port);
assertThat(proxyModule.provideHttpWhoisProtocol(PROXY_CONFIG))
.isEqualTo(PROXY_CONFIG.webWhois.httpPort);
assertThat(proxyModule.provideHttpsWhoisProtocol(PROXY_CONFIG))
.isEqualTo(PROXY_CONFIG.webWhois.httpsPort);
assertThat(proxyModule.provideEnvironment()).isEqualTo(LOCAL);
assertThat(proxyModule.log).isFalse();
}
@Test
public void testFailure_parseArgs_loggingInProduction() {
String[] args = {"--env", "production", "--log"};
IllegalArgumentException e =
assertThrows(
IllegalArgumentException.class,
() -> {
proxyModule.parse(args);
});
assertThat(e)
.hasMessageThat()
.isEqualTo("Logging cannot be enabled for production environment");
}
@Test
public void testFailure_parseArgs_wrongArguments() {
String[] args = {"--wrong_flag", "some_value"};
ParameterException thrown =
assertThrows(ParameterException.class, () -> proxyModule.parse(args));
assertThat(thrown).hasMessageThat().contains("--wrong_flag");
}
@Test
public void testSuccess_parseArgs_log() {
String[] args = {"--log"};
proxyModule.parse(args);
assertThat(proxyModule.log).isTrue();
}
@Test
public void testSuccess_parseArgs_customWhoisPort() {
String[] args = {"--whois", "12345"};
proxyModule.parse(args);
assertThat(proxyModule.provideWhoisPort(PROXY_CONFIG)).isEqualTo(12345);
}
@Test
public void testSuccess_parseArgs_customEppPort() {
String[] args = {"--epp", "22222"};
proxyModule.parse(args);
assertThat(proxyModule.provideEppPort(PROXY_CONFIG)).isEqualTo(22222);
}
@Test
public void testSuccess_parseArgs_customHealthCheckPort() {
String[] args = {"--health_check", "23456"};
proxyModule.parse(args);
assertThat(proxyModule.provideHealthCheckPort(PROXY_CONFIG)).isEqualTo(23456);
}
@Test
public void testSuccess_parseArgs_customhttpWhoisPort() {
String[] args = {"--http_whois", "12121"};
proxyModule.parse(args);
assertThat(proxyModule.provideHttpWhoisProtocol(PROXY_CONFIG)).isEqualTo(12121);
}
@Test
public void testSuccess_parseArgs_customhttpsWhoisPort() {
String[] args = {"--https_whois", "21212"};
proxyModule.parse(args);
assertThat(proxyModule.provideHttpsWhoisProtocol(PROXY_CONFIG)).isEqualTo(21212);
}
@Test
public void testSuccess_parseArgs_customEnvironment() {
String[] args = {"--env", "ALpHa"};
proxyModule.parse(args);
assertThat(proxyModule.provideEnvironment()).isEqualTo(Environment.ALPHA);
}
@Test
public void testFailure_parseArgs_wrongEnvironment() {
ParameterException e =
assertThrows(
ParameterException.class,
() -> {
String[] args = {"--env", "beta"};
proxyModule.parse(args);
});
assertThat(e).hasMessageThat().contains("Invalid value for --env parameter");
}
}

View file

@ -0,0 +1,147 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.US_ASCII;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.cookie.ClientCookieEncoder;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
/** Utility class for various helper methods used in testing. */
public class TestUtils {
public static FullHttpRequest makeHttpPostRequest(String content, String host, String path) {
ByteBuf buf = Unpooled.wrappedBuffer(content.getBytes(US_ASCII));
FullHttpRequest request =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, path, buf);
request
.headers()
.set("user-agent", "Proxy")
.set("host", host)
.setInt("content-length", buf.readableBytes());
return request;
}
public static FullHttpRequest makeHttpGetRequest(String host, String path) {
FullHttpRequest request =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
request.headers().set("host", host).setInt("content-length", 0);
return request;
}
public static FullHttpResponse makeHttpResponse(String content, HttpResponseStatus status) {
ByteBuf buf = Unpooled.wrappedBuffer(content.getBytes(US_ASCII));
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buf);
response.headers().setInt("content-length", buf.readableBytes());
return response;
}
public static FullHttpResponse makeHttpResponse(HttpResponseStatus status) {
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
response.headers().setInt("content-length", 0);
return response;
}
public static FullHttpRequest makeWhoisHttpRequest(
String content, String host, String path, String accessToken) {
FullHttpRequest request = makeHttpPostRequest(content, host, path);
request
.headers()
.set("authorization", "Bearer " + accessToken)
.set("content-type", "text/plain")
.set("accept", "text/plain");
return request;
}
public static FullHttpRequest makeEppHttpRequest(
String content,
String host,
String path,
String accessToken,
String sslClientCertificateHash,
String clientAddress,
Cookie... cookies) {
FullHttpRequest request = makeHttpPostRequest(content, host, path);
request
.headers()
.set("authorization", "Bearer " + accessToken)
.set("content-type", "application/epp+xml")
.set("accept", "application/epp+xml")
.set("X-SSL-Certificate", sslClientCertificateHash)
.set("X-Forwarded-For", clientAddress);
if (cookies.length != 0) {
request.headers().set("cookie", ClientCookieEncoder.STRICT.encode(cookies));
}
return request;
}
public static FullHttpResponse makeWhoisHttpResponse(String content, HttpResponseStatus status) {
FullHttpResponse response = makeHttpResponse(content, status);
response.headers().set("content-type", "text/plain");
return response;
}
public static FullHttpResponse makeEppHttpResponse(
String content, HttpResponseStatus status, Cookie... cookies) {
FullHttpResponse response = makeHttpResponse(content, status);
response.headers().set("content-type", "application/epp+xml");
for (Cookie cookie : cookies) {
response.headers().add("set-cookie", ServerCookieEncoder.STRICT.encode(cookie));
}
return response;
}
/**
* Compares two {@link FullHttpMessage} for equivalency.
*
* <p>This method is needed because an HTTP message decoded and aggregated from inbound {@link
* ByteBuf} is of a different class than the one written to the outbound {@link ByteBuf}, and The
* {@link ByteBuf} implementations that hold the content of the HTTP messages are different, even
* though the actual content, headers, etc are the same.
*
* <p>This method is not type-safe, msg1 & msg2 can be a request and a response, respectively. Do
* not use this method directly.
*/
private static void assertHttpMessageEquivalent(HttpMessage msg1, HttpMessage msg2) {
assertThat(msg1.protocolVersion()).isEqualTo(msg2.protocolVersion());
assertThat(msg1.headers()).isEqualTo(msg2.headers());
if (msg1 instanceof FullHttpRequest && msg2 instanceof FullHttpRequest) {
assertThat(((FullHttpRequest) msg1).content()).isEqualTo(((FullHttpRequest) msg2).content());
}
}
public static void assertHttpResponseEquivalent(FullHttpResponse res1, FullHttpResponse res2) {
assertThat(res1.status()).isEqualTo(res2.status());
assertHttpMessageEquivalent(res1, res2);
}
public static void assertHttpRequestEquivalent(HttpRequest req1, HttpRequest req2) {
assertHttpMessageEquivalent(req1, req2);
}
}

View file

@ -0,0 +1,109 @@
// Copyright 2018 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.TestUtils.assertHttpRequestEquivalent;
import static google.registry.proxy.TestUtils.assertHttpResponseEquivalent;
import static google.registry.proxy.TestUtils.makeHttpGetRequest;
import static google.registry.proxy.TestUtils.makeHttpResponse;
import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
* End-to-end tests for {@link WebWhoisProtocolsModule}.
*
* <p>This protocol defines a connection in which the proxy behaves as a standard http server (sans
* the redirect operation which is excluded in end-to-end testing). Because non user-defined
* handlers are used, the tests here focus on verifying that the request written to the network
* socket by a client is reconstructed faithfully by the server, and vice versa, that the response a
* client decoded from incoming bytes is equivalent to the response sent by the server.
*
* <p>These tests only ensure that the server represented by this protocol is compatible with a
* client implementation provided by Netty itself. They test the self-consistency of various Netty
* handlers that deal with HTTP protocol, but not whether the handlers converts between bytes and
* HTTP messages correctly, which is presumed correct.
*
* <p>Only the HTTP redirect protocol is tested as both protocols share the same handlers except for
* those that are excluded ({@code SslServerInitializer}, {@code WebWhoisRedirectHandler}).
*/
@RunWith(JUnit4.class)
public class WebWhoisProtocolsModuleTest extends ProtocolModuleTest {
private static final String HOST = "test.tld";
private static final String PATH = "/path/to/test";
private final EmbeddedChannel clientChannel =
new EmbeddedChannel(new HttpClientCodec(), new HttpObjectAggregator(512 * 1024));
public WebWhoisProtocolsModuleTest() {
super(TestComponent::httpWhoisHandlers);
}
/**
* Tests that the client converts given {@link FullHttpRequest} to bytes, which is sent to the
* server and reconstructed to a {@link FullHttpRequest} that is equivalent to the original. Then
* test that the server converts given {@link FullHttpResponse} to bytes, which is sent to the
* client and reconstructed to a {@link FullHttpResponse} that is equivalent to the original.
*
* <p>The request and response equivalences are tested in the same method because the client codec
* tries to pair the response it receives with the request it sends. Receiving a response without
* sending a request first will cause the {@link HttpObjectAggregator} to fail to aggregate
* properly.
*/
private void requestAndRespondWithStatus(HttpResponseStatus status) {
ByteBuf buffer;
FullHttpRequest requestSent = makeHttpGetRequest(HOST, PATH);
// Need to send a copy as the content read index will advance after the request is written to
// the outbound of client channel, making comparison with requestReceived fail.
assertThat(clientChannel.writeOutbound(requestSent.copy())).isTrue();
buffer = clientChannel.readOutbound();
assertThat(channel.writeInbound(buffer)).isTrue();
// We only have a DefaultHttpRequest, not a FullHttpRequest because there is no HTTP aggregator
// in the server's pipeline. But it is fine as we are not interested in the content (payload) of
// the request, just its headers, which are contained in the DefaultHttpRequest.
DefaultHttpRequest requestReceived = channel.readInbound();
// Verify that the request received is the same as the request sent.
assertHttpRequestEquivalent(requestSent, requestReceived);
FullHttpResponse responseSent = makeHttpResponse(status);
assertThat(channel.writeOutbound(responseSent.copy())).isTrue();
buffer = channel.readOutbound();
assertThat(clientChannel.writeInbound(buffer)).isTrue();
FullHttpResponse responseReceived = clientChannel.readInbound();
// Verify that the request received is the same as the request sent.
assertHttpResponseEquivalent(responseSent, responseReceived);
}
@Test
public void testSuccess_OkResponse() {
requestAndRespondWithStatus(HttpResponseStatus.OK);
}
@Test
public void testSuccess_NonOkResponse() {
requestAndRespondWithStatus(HttpResponseStatus.BAD_REQUEST);
}
}

View file

@ -0,0 +1,164 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.TestUtils.makeWhoisHttpRequest;
import static google.registry.proxy.TestUtils.makeWhoisHttpResponse;
import static google.registry.testing.JUnitBackports.assertThrows;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.stream.Collectors.joining;
import com.google.common.base.Throwables;
import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.nio.channels.ClosedChannelException;
import java.util.stream.Stream;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** End-to-end tests for {@link WhoisProtocolModule}. */
@RunWith(JUnit4.class)
public class WhoisProtocolModuleTest extends ProtocolModuleTest {
public WhoisProtocolModuleTest() {
super(TestComponent::whoisHandlers);
}
@Test
public void testSuccess_singleFrameInboundMessage() {
String inputString = "test.tld\r\n";
// Inbound message processed and passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(inputString.getBytes(US_ASCII))))
.isTrue();
FullHttpRequest actualRequest = channel.readInbound();
FullHttpRequest expectedRequest =
makeWhoisHttpRequest(
"test.tld",
PROXY_CONFIG.whois.relayHost,
PROXY_CONFIG.whois.relayPath,
TestModule.provideFakeAccessToken().get());
assertThat(actualRequest).isEqualTo(expectedRequest);
assertThat(channel.isActive()).isTrue();
// Nothing more to read.
assertThat((Object) channel.readInbound()).isNull();
}
@Test
public void testSuccess_noNewlineInboundMessage() {
String inputString = "test.tld";
// No newline encountered, no message formed.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(inputString.getBytes(US_ASCII))))
.isFalse();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_multiFrameInboundMessage() {
String frame1 = "test";
String frame2 = "1.tld";
String frame3 = "\r\nte";
String frame4 = "st2.tld\r";
String frame5 = "\ntest3.tld";
// No newline yet.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame1.getBytes(US_ASCII)))).isFalse();
// Still no newline yet.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame2.getBytes(US_ASCII)))).isFalse();
// First newline encountered.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame3.getBytes(US_ASCII)))).isTrue();
FullHttpRequest actualRequest1 = channel.readInbound();
FullHttpRequest expectedRequest1 =
makeWhoisHttpRequest(
"test1.tld",
PROXY_CONFIG.whois.relayHost,
PROXY_CONFIG.whois.relayPath,
TestModule.provideFakeAccessToken().get());
assertThat(actualRequest1).isEqualTo(expectedRequest1);
// No more message at this point.
assertThat((Object) channel.readInbound()).isNull();
// More inbound bytes, but no newline.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame4.getBytes(US_ASCII)))).isFalse();
// Second message read.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame5.getBytes(US_ASCII)))).isTrue();
FullHttpRequest actualRequest2 = channel.readInbound();
FullHttpRequest expectedRequest2 =
makeWhoisHttpRequest(
"test2.tld",
PROXY_CONFIG.whois.relayHost,
PROXY_CONFIG.whois.relayPath,
TestModule.provideFakeAccessToken().get());
assertThat(actualRequest2).isEqualTo(expectedRequest2);
// The third message is not complete yet.
assertThat(channel.isActive()).isTrue();
assertThat((Object) channel.readInbound()).isNull();
}
@Test
public void testSuccess_inboundMessageTooLong() {
String inputString = Stream.generate(() -> "x").limit(513).collect(joining()) + "\r\n";
// Nothing gets propagated further.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(inputString.getBytes(US_ASCII))))
.isFalse();
// Connection is closed due to inbound message overflow.
assertThat(channel.isActive()).isFalse();
}
@Test
public void testSuccess_parseSingleOutboundHttpResponse() {
String outputString = "line1\r\nline2\r\n";
FullHttpResponse response = makeWhoisHttpResponse(outputString, HttpResponseStatus.OK);
// Http response parsed and passed along.
assertThat(channel.writeOutbound(response)).isTrue();
ByteBuf outputBuffer = channel.readOutbound();
assertThat(outputBuffer.toString(US_ASCII)).isEqualTo(outputString);
assertThat(channel.isActive()).isFalse();
// Nothing more to write.
assertThat((Object) channel.readOutbound()).isNull();
}
@Test
public void testFailure_parseOnlyFirstFromMultipleOutboundHttpResponse() {
String outputString1 = "line1\r\nline2\r\n";
String outputString2 = "line3\r\nline4\r\nline5\r\n";
FullHttpResponse response1 = makeWhoisHttpResponse(outputString1, HttpResponseStatus.OK);
FullHttpResponse response2 = makeWhoisHttpResponse(outputString2, HttpResponseStatus.OK);
assertThrows(ClosedChannelException.class, () -> channel.writeOutbound(response1, response2));
// First Http response parsed
ByteBuf outputBuffer1 = channel.readOutbound();
assertThat(outputBuffer1.toString(US_ASCII)).isEqualTo(outputString1);
// Second Http response not parsed because the connection is closed.
assertThat(channel.isActive()).isFalse();
assertThat((Object) channel.readOutbound()).isNull();
}
@Test
public void testFailure_outboundResponseStatusNotOK() {
String outputString = "line1\r\nline2\r\n";
FullHttpResponse response = makeWhoisHttpResponse(outputString, HttpResponseStatus.BAD_REQUEST);
EncoderException thrown =
assertThrows(EncoderException.class, () -> channel.writeOutbound(response));
assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class);
assertThat(thrown).hasMessageThat().contains("400 Bad Request");
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isFalse();
}
}

View file

@ -0,0 +1,233 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.TestUtils.assertHttpRequestEquivalent;
import static google.registry.proxy.TestUtils.assertHttpResponseEquivalent;
import static google.registry.proxy.TestUtils.makeHttpPostRequest;
import static google.registry.proxy.TestUtils.makeHttpResponse;
import static google.registry.proxy.handler.EppServiceHandler.CLIENT_CERTIFICATE_HASH_KEY;
import static google.registry.proxy.handler.RelayHandler.RELAY_CHANNEL_KEY;
import static google.registry.testing.JUnitBackports.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol;
import google.registry.proxy.Protocol.BackendProtocol;
import google.registry.proxy.Protocol.FrontendProtocol;
import google.registry.proxy.metric.BackendMetrics;
import google.registry.testing.FakeClock;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link BackendMetricsHandler}. */
@RunWith(JUnit4.class)
public class BackendMetricsHandlerTest {
private static final String HOST = "host.tld";
private static final String CLIENT_CERT_HASH = "blah12345";
private static final String RELAYED_PROTOCOL_NAME = "frontend protocol";
private final FakeClock fakeClock = new FakeClock();
private final BackendMetrics metrics = mock(BackendMetrics.class);
private final BackendMetricsHandler handler = new BackendMetricsHandler(fakeClock, metrics);
private final BackendProtocol backendProtocol =
Protocol.backendBuilder()
.name("backend protocol")
.host(HOST)
.port(1)
.handlerProviders(ImmutableList.of())
.build();
private final FrontendProtocol frontendProtocol =
Protocol.frontendBuilder()
.name(RELAYED_PROTOCOL_NAME)
.port(2)
.relayProtocol(backendProtocol)
.handlerProviders(ImmutableList.of())
.build();
private EmbeddedChannel channel;
@Before
public void setUp() {
EmbeddedChannel frontendChannel = new EmbeddedChannel();
frontendChannel.attr(PROTOCOL_KEY).set(frontendProtocol);
frontendChannel.attr(CLIENT_CERTIFICATE_HASH_KEY).set(CLIENT_CERT_HASH);
channel =
new EmbeddedChannel(
new ChannelInitializer<EmbeddedChannel>() {
@Override
protected void initChannel(EmbeddedChannel ch) throws Exception {
ch.attr(PROTOCOL_KEY).set(backendProtocol);
ch.attr(RELAY_CHANNEL_KEY).set(frontendChannel);
ch.pipeline().addLast(handler);
}
});
}
@Test
public void testFailure_outbound_wrongType() {
Object request = new Object();
IllegalArgumentException e =
assertThrows(IllegalArgumentException.class, () -> channel.writeOutbound(request));
assertThat(e).hasMessageThat().isEqualTo("Outgoing request must be FullHttpRequest.");
}
@Test
public void testFailure_inbound_wrongType() {
Object response = new Object();
IllegalArgumentException e =
assertThrows(IllegalArgumentException.class, () -> channel.writeInbound(response));
assertThat(e).hasMessageThat().isEqualTo("Incoming response must be FullHttpResponse.");
}
@Test
public void testSuccess_oneRequest() {
FullHttpRequest request = makeHttpPostRequest("some content", HOST, "/");
// outbound message passed to the next handler.
assertThat(channel.writeOutbound(request)).isTrue();
assertHttpRequestEquivalent(request, channel.readOutbound());
verify(metrics)
.requestSent(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, request.content().readableBytes());
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_oneRequest_oneResponse() {
FullHttpRequest request = makeHttpPostRequest("some request", HOST, "/");
FullHttpResponse response = makeHttpResponse("some response", HttpResponseStatus.OK);
// outbound message passed to the next handler.
assertThat(channel.writeOutbound(request)).isTrue();
assertHttpRequestEquivalent(request, channel.readOutbound());
fakeClock.advanceOneMilli();
// inbound message passed to the next handler.
assertThat(channel.writeInbound(response)).isTrue();
assertHttpResponseEquivalent(response, channel.readInbound());
verify(metrics)
.requestSent(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, request.content().readableBytes());
verify(metrics).responseReceived(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, response, 1);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_badResponse() {
FullHttpRequest request = makeHttpPostRequest("some request", HOST, "/");
FullHttpResponse response =
makeHttpResponse("some bad response", HttpResponseStatus.BAD_REQUEST);
// outbound message passed to the next handler.
assertThat(channel.writeOutbound(request)).isTrue();
assertHttpRequestEquivalent(request, channel.readOutbound());
fakeClock.advanceOneMilli();
// inbound message passed to the next handler.
// Even though the response status is not OK, the metrics handler only logs it and pass it
// along to the next handler, which handles it.
assertThat(channel.writeInbound(response)).isTrue();
assertHttpResponseEquivalent(response, channel.readInbound());
verify(metrics)
.requestSent(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, request.content().readableBytes());
verify(metrics).responseReceived(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, response, 1);
verifyNoMoreInteractions(metrics);
}
@Test
public void testFailure_responseBeforeRequest() {
FullHttpResponse response = makeHttpResponse("phantom response", HttpResponseStatus.OK);
IllegalStateException e =
assertThrows(IllegalStateException.class, () -> channel.writeInbound(response));
assertThat(e).hasMessageThat().isEqualTo("Response received before request is sent.");
}
@Test
public void testSuccess_pipelinedResponses() {
FullHttpRequest request1 = makeHttpPostRequest("request 1", HOST, "/");
FullHttpResponse response1 = makeHttpResponse("response 1", HttpResponseStatus.OK);
FullHttpRequest request2 = makeHttpPostRequest("request 22", HOST, "/");
FullHttpResponse response2 = makeHttpResponse("response 22", HttpResponseStatus.OK);
FullHttpRequest request3 = makeHttpPostRequest("request 333", HOST, "/");
FullHttpResponse response3 = makeHttpResponse("response 333", HttpResponseStatus.OK);
// First request, time = 0
assertThat(channel.writeOutbound(request1)).isTrue();
assertHttpRequestEquivalent(request1, channel.readOutbound());
DateTime sentTime1 = fakeClock.nowUtc();
fakeClock.advanceBy(Duration.millis(5));
// Second request, time = 5
assertThat(channel.writeOutbound(request2)).isTrue();
assertHttpRequestEquivalent(request2, channel.readOutbound());
DateTime sentTime2 = fakeClock.nowUtc();
fakeClock.advanceBy(Duration.millis(7));
// First response, time = 12, latency = 12 - 0 = 12
assertThat(channel.writeInbound(response1)).isTrue();
assertHttpResponseEquivalent(response1, channel.readInbound());
DateTime receivedTime1 = fakeClock.nowUtc();
fakeClock.advanceBy(Duration.millis(11));
// Third request, time = 23
assertThat(channel.writeOutbound(request3)).isTrue();
assertHttpRequestEquivalent(request3, channel.readOutbound());
DateTime sentTime3 = fakeClock.nowUtc();
fakeClock.advanceBy(Duration.millis(2));
// Second response, time = 25, latency = 25 - 5 = 20
assertThat(channel.writeInbound(response2)).isTrue();
assertHttpResponseEquivalent(response2, channel.readInbound());
DateTime receivedTime2 = fakeClock.nowUtc();
fakeClock.advanceBy(Duration.millis(4));
// Third response, time = 29, latency = 29 - 23 = 6
assertThat(channel.writeInbound(response3)).isTrue();
assertHttpResponseEquivalent(response3, channel.readInbound());
DateTime receivedTime3 = fakeClock.nowUtc();
long latency1 = new Duration(sentTime1, receivedTime1).getMillis();
long latency2 = new Duration(sentTime2, receivedTime2).getMillis();
long latency3 = new Duration(sentTime3, receivedTime3).getMillis();
verify(metrics)
.requestSent(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, request1.content().readableBytes());
verify(metrics)
.requestSent(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, request2.content().readableBytes());
verify(metrics)
.requestSent(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, request3.content().readableBytes());
verify(metrics).responseReceived(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, response1, latency1);
verify(metrics).responseReceived(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, response2, latency2);
verify(metrics).responseReceived(RELAYED_PROTOCOL_NAME, CLIENT_CERT_HASH, response3, latency3);
verifyNoMoreInteractions(metrics);
}
}

View file

@ -0,0 +1,174 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.handler.EppServiceHandler.CLIENT_CERTIFICATE_HASH_KEY;
import static google.registry.testing.JUnitBackports.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol;
import google.registry.proxy.handler.QuotaHandler.EppQuotaHandler;
import google.registry.proxy.handler.QuotaHandler.OverQuotaException;
import google.registry.proxy.metric.FrontendMetrics;
import google.registry.proxy.quota.QuotaManager;
import google.registry.proxy.quota.QuotaManager.QuotaRebate;
import google.registry.proxy.quota.QuotaManager.QuotaRequest;
import google.registry.proxy.quota.QuotaManager.QuotaResponse;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link EppQuotaHandler} */
@RunWith(JUnit4.class)
public class EppQuotaHandlerTest {
private final QuotaManager quotaManager = mock(QuotaManager.class);
private final FrontendMetrics metrics = mock(FrontendMetrics.class);
private final EppQuotaHandler handler = new EppQuotaHandler(quotaManager, metrics);
private final EmbeddedChannel channel = new EmbeddedChannel(handler);
private final String clientCertHash = "blah/123!";
private final DateTime now = DateTime.now(DateTimeZone.UTC);
private final Object message = new Object();
private void setProtocol(Channel channel) {
channel
.attr(PROTOCOL_KEY)
.set(
Protocol.frontendBuilder()
.name("epp")
.port(12345)
.handlerProviders(ImmutableList.of())
.relayProtocol(
Protocol.backendBuilder()
.name("backend")
.host("host.tld")
.port(1234)
.handlerProviders(ImmutableList.of())
.build())
.build());
}
@Before
public void setUp() {
channel.attr(CLIENT_CERTIFICATE_HASH_KEY).set(clientCertHash);
setProtocol(channel);
}
@Test
public void testSuccess_quotaGrantedAndReturned() {
when(quotaManager.acquireQuota(QuotaRequest.create(clientCertHash)))
.thenReturn(QuotaResponse.create(true, clientCertHash, now));
// First read, acquire quota.
assertThat(channel.writeInbound(message)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(message);
assertThat(channel.isActive()).isTrue();
verify(quotaManager).acquireQuota(QuotaRequest.create(clientCertHash));
// Second read, should not acquire quota again.
Object newMessage = new Object();
assertThat(channel.writeInbound(newMessage)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(newMessage);
verifyNoMoreInteractions(quotaManager);
// Channel closed, release quota.
ChannelFuture unusedFuture = channel.close();
verify(quotaManager)
.releaseQuota(QuotaRebate.create(QuotaResponse.create(true, clientCertHash, now)));
verifyNoMoreInteractions(quotaManager);
}
@Test
public void testFailure_quotaNotGranted() {
when(quotaManager.acquireQuota(QuotaRequest.create(clientCertHash)))
.thenReturn(QuotaResponse.create(false, clientCertHash, now));
OverQuotaException e =
assertThrows(OverQuotaException.class, () -> channel.writeInbound(message));
ChannelFuture unusedFuture = channel.close();
assertThat(e).hasMessageThat().contains(clientCertHash);
verify(quotaManager).acquireQuota(QuotaRequest.create(clientCertHash));
// Make sure that quotaManager.releaseQuota() is not called when the channel closes.
verifyNoMoreInteractions(quotaManager);
verify(metrics).registerQuotaRejection("epp", clientCertHash);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_twoChannels_twoUserIds() {
// Set up another user.
final EppQuotaHandler otherHandler = new EppQuotaHandler(quotaManager, metrics);
final EmbeddedChannel otherChannel = new EmbeddedChannel(otherHandler);
final String otherClientCertHash = "hola@9x";
otherChannel.attr(CLIENT_CERTIFICATE_HASH_KEY).set(otherClientCertHash);
setProtocol(otherChannel);
final DateTime later = now.plus(Duration.standardSeconds(1));
when(quotaManager.acquireQuota(QuotaRequest.create(clientCertHash)))
.thenReturn(QuotaResponse.create(true, clientCertHash, now));
when(quotaManager.acquireQuota(QuotaRequest.create(otherClientCertHash)))
.thenReturn(QuotaResponse.create(false, otherClientCertHash, later));
// Allows the first user.
assertThat(channel.writeInbound(message)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(message);
assertThat(channel.isActive()).isTrue();
// Blocks the second user.
OverQuotaException e =
assertThrows(OverQuotaException.class, () -> otherChannel.writeInbound(message));
assertThat(e).hasMessageThat().contains(otherClientCertHash);
verify(metrics).registerQuotaRejection("epp", otherClientCertHash);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_twoChannels_sameUserIds() {
// Set up another channel for the same user.
final EppQuotaHandler otherHandler = new EppQuotaHandler(quotaManager, metrics);
final EmbeddedChannel otherChannel = new EmbeddedChannel(otherHandler);
otherChannel.attr(CLIENT_CERTIFICATE_HASH_KEY).set(clientCertHash);
setProtocol(otherChannel);
final DateTime later = now.plus(Duration.standardSeconds(1));
when(quotaManager.acquireQuota(QuotaRequest.create(clientCertHash)))
.thenReturn(QuotaResponse.create(true, clientCertHash, now))
.thenReturn(QuotaResponse.create(false, clientCertHash, later));
// Allows the first channel.
assertThat(channel.writeInbound(message)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(message);
assertThat(channel.isActive()).isTrue();
// Blocks the second channel.
OverQuotaException e =
assertThrows(OverQuotaException.class, () -> otherChannel.writeInbound(message));
assertThat(e).hasMessageThat().contains(clientCertHash);
verify(metrics).registerQuotaRejection("epp", clientCertHash);
verifyNoMoreInteractions(metrics);
}
}

View file

@ -0,0 +1,329 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.TestUtils.assertHttpRequestEquivalent;
import static google.registry.proxy.TestUtils.makeEppHttpResponse;
import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY;
import static google.registry.proxy.handler.SslServerInitializer.CLIENT_CERTIFICATE_PROMISE_KEY;
import static google.registry.testing.JUnitBackports.assertThrows;
import static google.registry.util.X509Utils.getCertificateHash;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.base.Throwables;
import google.registry.proxy.TestUtils;
import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException;
import google.registry.proxy.metric.FrontendMetrics;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultChannelId;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.DefaultCookie;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.Promise;
import java.security.cert.X509Certificate;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link EppServiceHandler}. */
@RunWith(JUnit4.class)
public class EppServiceHandlerTest {
private static final String HELLO =
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n"
+ "<epp xmlns=\"urn:ietf:params:xml:ns:epp-1.0\">\n"
+ " <hello/>\n"
+ "</epp>\n";
private static final String RELAY_HOST = "registry.example.tld";
private static final String RELAY_PATH = "/epp";
private static final String ACCESS_TOKEN = "this.access.token";
private static final String CLIENT_ADDRESS = "epp.client.tld";
private static final String PROTOCOL = "epp";
private X509Certificate clientCertificate;
private final FrontendMetrics metrics = mock(FrontendMetrics.class);
private final EppServiceHandler eppServiceHandler =
new EppServiceHandler(
RELAY_HOST,
RELAY_PATH,
() -> ACCESS_TOKEN,
HELLO.getBytes(UTF_8),
metrics);
private EmbeddedChannel channel;
private void setHandshakeSuccess(EmbeddedChannel channel, X509Certificate certificate)
throws Exception {
Promise<X509Certificate> unusedPromise =
channel.attr(CLIENT_CERTIFICATE_PROMISE_KEY).get().setSuccess(certificate);
}
private void setHandshakeSuccess() throws Exception {
setHandshakeSuccess(channel, clientCertificate);
}
private void setHandshakeFailure(EmbeddedChannel channel) throws Exception {
Promise<X509Certificate> unusedPromise =
channel
.attr(CLIENT_CERTIFICATE_PROMISE_KEY)
.get()
.setFailure(new Exception("Handshake Failure"));
}
private void setHandshakeFailure() throws Exception {
setHandshakeFailure(channel);
}
private FullHttpRequest makeEppHttpRequest(String content, Cookie... cookies) {
return TestUtils.makeEppHttpRequest(
content,
RELAY_HOST,
RELAY_PATH,
ACCESS_TOKEN,
getCertificateHash(clientCertificate),
CLIENT_ADDRESS,
cookies);
}
@Before
public void setUp() throws Exception {
clientCertificate = new SelfSignedCertificate().cert();
channel = setUpNewChannel(eppServiceHandler);
}
private EmbeddedChannel setUpNewChannel(EppServiceHandler handler) throws Exception {
return new EmbeddedChannel(
DefaultChannelId.newInstance(),
new ChannelInitializer<EmbeddedChannel>() {
@Override
protected void initChannel(EmbeddedChannel ch) throws Exception {
ch.attr(REMOTE_ADDRESS_KEY).set(CLIENT_ADDRESS);
ch.attr(CLIENT_CERTIFICATE_PROMISE_KEY).set(ch.eventLoop().newPromise());
ch.pipeline().addLast(handler);
}
});
}
@Test
public void testSuccess_connectionMetrics_oneConnection() throws Exception {
setHandshakeSuccess();
String certHash = getCertificateHash(clientCertificate);
assertThat(channel.isActive()).isTrue();
verify(metrics).registerActiveConnection(PROTOCOL, certHash, channel);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_connectionMetrics_twoConnections_sameClient() throws Exception {
setHandshakeSuccess();
String certHash = getCertificateHash(clientCertificate);
assertThat(channel.isActive()).isTrue();
// Setup the second channel.
EppServiceHandler eppServiceHandler2 =
new EppServiceHandler(
RELAY_HOST,
RELAY_PATH,
() -> ACCESS_TOKEN,
HELLO.getBytes(UTF_8),
metrics);
EmbeddedChannel channel2 = setUpNewChannel(eppServiceHandler2);
setHandshakeSuccess(channel2, clientCertificate);
assertThat(channel2.isActive()).isTrue();
verify(metrics).registerActiveConnection(PROTOCOL, certHash, channel);
verify(metrics).registerActiveConnection(PROTOCOL, certHash, channel2);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_connectionMetrics_twoConnections_differentClients() throws Exception {
setHandshakeSuccess();
String certHash = getCertificateHash(clientCertificate);
assertThat(channel.isActive()).isTrue();
// Setup the second channel.
EppServiceHandler eppServiceHandler2 =
new EppServiceHandler(
RELAY_HOST,
RELAY_PATH,
() -> ACCESS_TOKEN,
HELLO.getBytes(UTF_8),
metrics);
EmbeddedChannel channel2 = setUpNewChannel(eppServiceHandler2);
X509Certificate clientCertificate2 = new SelfSignedCertificate().cert();
setHandshakeSuccess(channel2, clientCertificate2);
String certHash2 = getCertificateHash(clientCertificate2);
assertThat(channel2.isActive()).isTrue();
verify(metrics).registerActiveConnection(PROTOCOL, certHash, channel);
verify(metrics).registerActiveConnection(PROTOCOL, certHash2, channel2);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_sendHelloUponHandshakeSuccess() throws Exception {
// Nothing to pass to the next handler.
assertThat((Object) channel.readInbound()).isNull();
setHandshakeSuccess();
// hello bytes should be passed to the next handler.
FullHttpRequest helloRequest = channel.readInbound();
assertThat(helloRequest).isEqualTo(makeEppHttpRequest(HELLO));
// Nothing further to pass to the next handler.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_disconnectUponHandshakeFailure() throws Exception {
// Nothing to pass to the next handler.
assertThat((Object) channel.readInbound()).isNull();
setHandshakeFailure();
assertThat(channel.isActive()).isFalse();
}
@Test
public void testSuccess_sendRequestToNextHandler() throws Exception {
setHandshakeSuccess();
// First inbound message is hello.
channel.readInbound();
String content = "<epp>stuff</epp>";
channel.writeInbound(Unpooled.wrappedBuffer(content.getBytes(UTF_8)));
FullHttpRequest request = channel.readInbound();
assertThat(request).isEqualTo(makeEppHttpRequest(content));
// Nothing further to pass to the next handler.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_sendResponseToNextHandler() throws Exception {
setHandshakeSuccess();
String content = "<epp>stuff</epp>";
channel.writeOutbound(makeEppHttpResponse(content, HttpResponseStatus.OK));
ByteBuf response = channel.readOutbound();
assertThat(response).isEqualTo(Unpooled.wrappedBuffer(content.getBytes(UTF_8)));
// Nothing further to pass to the next handler.
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_sendResponseToNextHandler_andDisconnect() throws Exception {
setHandshakeSuccess();
String content = "<epp>stuff</epp>";
HttpResponse response = makeEppHttpResponse(content, HttpResponseStatus.OK);
response.headers().set("Epp-Session", "close");
channel.writeOutbound(response);
ByteBuf expectedResponse = channel.readOutbound();
assertThat(Unpooled.wrappedBuffer(content.getBytes(UTF_8))).isEqualTo(expectedResponse);
// Nothing further to pass to the next handler.
assertThat((Object) channel.readOutbound()).isNull();
// Channel is disconnected.
assertThat(channel.isActive()).isFalse();
}
@Test
public void testFailure_disconnectOnNonOKResponseStatus() throws Exception {
setHandshakeSuccess();
String content = "<epp>stuff</epp>";
EncoderException thrown =
assertThrows(
EncoderException.class,
() ->
channel.writeOutbound(
makeEppHttpResponse(content, HttpResponseStatus.BAD_REQUEST)));
assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class);
assertThat(thrown).hasMessageThat().contains(HttpResponseStatus.BAD_REQUEST.toString());
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isFalse();
}
@Test
public void testSuccess_setCookies() throws Exception {
setHandshakeSuccess();
// First inbound message is hello.
channel.readInbound();
String responseContent = "<epp>response</epp>";
Cookie cookie1 = new DefaultCookie("name1", "value1");
Cookie cookie2 = new DefaultCookie("name2", "value2");
channel.writeOutbound(
makeEppHttpResponse(responseContent, HttpResponseStatus.OK, cookie1, cookie2));
ByteBuf response = channel.readOutbound();
assertThat(response).isEqualTo(Unpooled.wrappedBuffer(responseContent.getBytes(UTF_8)));
String requestContent = "<epp>request</epp>";
channel.writeInbound(Unpooled.wrappedBuffer(requestContent.getBytes(UTF_8)));
FullHttpRequest request = channel.readInbound();
assertHttpRequestEquivalent(request, makeEppHttpRequest(requestContent, cookie1, cookie2));
// Nothing further to pass to the next handler.
assertThat((Object) channel.readInbound()).isNull();
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_updateCookies() throws Exception {
setHandshakeSuccess();
// First inbound message is hello.
channel.readInbound();
String responseContent1 = "<epp>response1</epp>";
Cookie cookie1 = new DefaultCookie("name1", "value1");
Cookie cookie2 = new DefaultCookie("name2", "value2");
// First response written.
channel.writeOutbound(
makeEppHttpResponse(responseContent1, HttpResponseStatus.OK, cookie1, cookie2));
channel.readOutbound();
String requestContent1 = "<epp>request1</epp>";
// First request written.
channel.writeInbound(Unpooled.wrappedBuffer(requestContent1.getBytes(UTF_8)));
FullHttpRequest request1 = channel.readInbound();
assertHttpRequestEquivalent(request1, makeEppHttpRequest(requestContent1, cookie1, cookie2));
String responseContent2 = "<epp>response2</epp>";
Cookie cookie3 = new DefaultCookie("name3", "value3");
Cookie newCookie2 = new DefaultCookie("name2", "newValue");
// Second response written.
channel.writeOutbound(
makeEppHttpResponse(responseContent2, HttpResponseStatus.OK, cookie3, newCookie2));
channel.readOutbound();
String requestContent2 = "<epp>request2</epp>";
// Second request written.
channel.writeInbound(Unpooled.wrappedBuffer(requestContent2.getBytes(UTF_8)));
FullHttpRequest request2 = channel.readInbound();
// Cookies in second request should be updated.
assertHttpRequestEquivalent(
request2, makeEppHttpRequest(requestContent2, cookie1, newCookie2, cookie3));
// Nothing further to pass to the next handler.
assertThat((Object) channel.readInbound()).isNull();
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
}

View file

@ -0,0 +1,58 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.US_ASCII;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link HealthCheckHandler}. */
@RunWith(JUnit4.class)
public class HealthCheckHandlerTest {
private static final String CHECK_REQ = "REQUEST";
private static final String CHECK_RES = "RESPONSE";
private final HealthCheckHandler healthCheckHandler =
new HealthCheckHandler(CHECK_REQ, CHECK_RES);
private final EmbeddedChannel channel = new EmbeddedChannel(healthCheckHandler);
@Test
public void testSuccess_ResponseSent() {
ByteBuf input = Unpooled.wrappedBuffer(CHECK_REQ.getBytes(US_ASCII));
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(input)).isFalse();
ByteBuf output = channel.readOutbound();
assertThat(channel.isActive()).isTrue();
assertThat(output.toString(US_ASCII)).isEqualTo(CHECK_RES);
}
@Test
public void testSuccess_IgnoreUnrecognizedRequest() {
String unrecognizedInput = "1234567";
ByteBuf input = Unpooled.wrappedBuffer(unrecognizedInput.getBytes(US_ASCII));
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(input)).isFalse();
// No response is sent.
assertThat(channel.isActive()).isTrue();
assertThat((Object) channel.readOutbound()).isNull();
}
}

View file

@ -0,0 +1,223 @@
// Copyright 2018 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.testing.JUnitBackports.assertThrows;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Throwables;
import com.google.common.truth.ThrowableSubject;
import google.registry.proxy.Protocol.BackendProtocol;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.util.ReferenceCountUtil;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.junit.rules.ExternalResource;
/**
* Helper for setting up and testing client / server connection with netty.
*
* <p>Used in {@link SslClientInitializerTest} and {@link SslServerInitializerTest}.
*/
final class NettyRule extends ExternalResource {
// All I/O operations are done inside the single thread within this event loop group, which is
// different from the main test thread. Therefore synchronizations are required to make sure that
// certain I/O activities are finished when assertions are performed.
private final EventLoopGroup eventLoopGroup = new NioEventLoopGroup(1);
// Handler attached to server's channel to record the request received.
private EchoHandler echoHandler;
// Handler attached to client's channel to record the response received.
private DumpHandler dumpHandler;
private Channel channel;
/** Sets up a server channel bound to the given local address. */
void setUpServer(LocalAddress localAddress, ChannelHandler handler) {
checkState(echoHandler == null, "Can't call setUpServer twice");
echoHandler = new EchoHandler();
ChannelInitializer<LocalChannel> serverInitializer =
new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(LocalChannel ch) {
// Add the given handler
ch.pipeline().addLast(handler);
// Add the "echoHandler" last to log the incoming message and send it back
ch.pipeline().addLast(echoHandler);
}
};
ServerBootstrap sb =
new ServerBootstrap()
.group(eventLoopGroup)
.channel(LocalServerChannel.class)
.childHandler(serverInitializer);
ChannelFuture unusedFuture = sb.bind(localAddress).syncUninterruptibly();
}
/** Sets up a client channel connecting to the give local address. */
void setUpClient(
LocalAddress localAddress,
BackendProtocol protocol,
ChannelHandler handler) {
checkState(echoHandler != null, "Must call setUpServer before setUpClient");
checkState(dumpHandler == null, "Can't call setUpClient twice");
dumpHandler = new DumpHandler();
ChannelInitializer<LocalChannel> clientInitializer =
new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(LocalChannel ch) throws Exception {
// Add the given handler
ch.pipeline().addLast(handler);
// Add the "dumpHandler" last to log the incoming message
ch.pipeline().addLast(dumpHandler);
}
};
Bootstrap b =
new Bootstrap()
.group(eventLoopGroup)
.channel(LocalChannel.class)
.handler(clientInitializer)
.attr(PROTOCOL_KEY, protocol);
channel = b.connect(localAddress).syncUninterruptibly().channel();
}
void checkReady() {
checkState(channel != null, "Must call setUpClient to finish NettyRule setup");
}
/**
* Test that a message can go through, both inbound and outbound.
*
* <p>The client writes the message to the server, which echos it back and saves the string in its
* promise. The client receives the echo and saves it in its promise. All these activities happens
* in the I/O thread, and this call itself returns immediately.
*/
void assertThatMessagesWork() throws Exception {
checkReady();
assertThat(channel.isActive()).isTrue();
writeToChannelAndFlush(channel, "Hello, world!");
assertThat(echoHandler.getRequestFuture().get()).isEqualTo("Hello, world!");
assertThat(dumpHandler.getResponseFuture().get()).isEqualTo("Hello, world!");
}
Channel getChannel() {
checkReady();
return channel;
}
ThrowableSubject assertThatServerRootCause() {
checkReady();
return assertThat(
Throwables.getRootCause(
assertThrows(ExecutionException.class, () -> echoHandler.getRequestFuture().get())));
}
ThrowableSubject assertThatClientRootCause() {
checkReady();
return assertThat(
Throwables.getRootCause(
assertThrows(ExecutionException.class, () -> dumpHandler.getResponseFuture().get())));
}
/**
* A handler that echoes back its inbound message. The message is also saved in a promise for
* inspection later.
*/
private static class EchoHandler extends ChannelInboundHandlerAdapter {
private final CompletableFuture<String> requestFuture = new CompletableFuture<>();
Future<String> getRequestFuture() {
return requestFuture;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
// In the test we only send messages of type ByteBuf.
assertThat(msg).isInstanceOf(ByteBuf.class);
String request = ((ByteBuf) msg).toString(UTF_8);
// After the message is written back to the client, fulfill the promise.
ChannelFuture unusedFuture =
ctx.writeAndFlush(msg).addListener(f -> requestFuture.complete(request));
}
/** Saves any inbound error as the cause of the promise failure. */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ChannelFuture unusedFuture =
ctx.channel().closeFuture().addListener(f -> requestFuture.completeExceptionally(cause));
}
}
/** A handler that dumps its inbound message to a promise that can be inspected later. */
private static class DumpHandler extends ChannelInboundHandlerAdapter {
private final CompletableFuture<String> responseFuture = new CompletableFuture<>();
Future<String> getResponseFuture() {
return responseFuture;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
// In the test we only send messages of type ByteBuf.
assertThat(msg).isInstanceOf(ByteBuf.class);
String response = ((ByteBuf) msg).toString(UTF_8);
// There is no more use of this message, we should release its reference count so that it
// can be more effectively garbage collected by Netty.
ReferenceCountUtil.release(msg);
// Save the string in the promise and make it as complete.
responseFuture.complete(response);
}
/** Saves any inbound error into the failure cause of the promise. */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.channel().closeFuture().addListener(f -> responseFuture.completeExceptionally(cause));
}
}
@Override
protected void after() {
Future<?> unusedFuture = eventLoopGroup.shutdownGracefully();
}
private static void writeToChannelAndFlush(Channel channel, String data) {
ChannelFuture unusedFuture =
channel.writeAndFlush(Unpooled.wrappedBuffer(data.getBytes(US_ASCII)));
}
}

View file

@ -0,0 +1,131 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY;
import static java.nio.charset.StandardCharsets.UTF_8;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link ProxyProtocolHandler}. */
@RunWith(JUnit4.class)
public class ProxyProtocolHandlerTest {
private static final String HEADER_TEMPLATE = "PROXY TCP%d %s %s %s %s\r\n";
private final ProxyProtocolHandler handler = new ProxyProtocolHandler();
private final EmbeddedChannel channel = new EmbeddedChannel(handler);
private String header;
@Test
public void testSuccess_proxyHeaderPresent_singleFrame() {
header = String.format(HEADER_TEMPLATE, 4, "172.0.0.1", "255.255.255.255", "234", "123");
String message = "some message";
// Header processed, rest of the message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer((header + message).getBytes(UTF_8))))
.isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(message);
assertThat(channel.attr(REMOTE_ADDRESS_KEY).get()).isEqualTo("172.0.0.1");
assertThat(channel.pipeline().get(ProxyProtocolHandler.class)).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_proxyHeaderUnknownSource_singleFrame() {
header = "PROXY UNKNOWN\r\n";
String message = "some message";
// Header processed, rest of the message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer((header + message).getBytes(UTF_8))))
.isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(message);
assertThat(channel.attr(REMOTE_ADDRESS_KEY).get()).isEqualTo("0.0.0.0");
assertThat(channel.pipeline().get(ProxyProtocolHandler.class)).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_proxyHeaderPresent_multipleFrames() {
header = String.format(HEADER_TEMPLATE, 4, "172.0.0.1", "255.255.255.255", "234", "123");
String frame1 = header.substring(0, 4);
String frame2 = header.substring(4, 7);
String frame3 = header.substring(7, 15);
String frame4 = header.substring(15, header.length() - 1);
String frame5 = header.substring(header.length() - 1) + "some message";
// Have not had enough bytes to determine the presence of a header, no message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame1.getBytes(UTF_8)))).isFalse();
// Have not had enough bytes to determine the end a header, no message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame2.getBytes(UTF_8)))).isFalse();
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame3.getBytes(UTF_8)))).isFalse();
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame4.getBytes(UTF_8)))).isFalse();
// Now there are enough bytes to construct a header.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame5.getBytes(UTF_8)))).isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo("some message");
assertThat(channel.attr(REMOTE_ADDRESS_KEY).get()).isEqualTo("172.0.0.1");
assertThat(channel.pipeline().get(ProxyProtocolHandler.class)).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_proxyHeaderPresent_singleFrame_ipv6() {
header =
String.format(HEADER_TEMPLATE, 6, "2001:db8:0:1:1:1:1:1", "0:0:0:0:0:0:0:1", "234", "123");
String message = "some message";
// Header processed, rest of the message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer((header + message).getBytes(UTF_8))))
.isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(message);
assertThat(channel.attr(REMOTE_ADDRESS_KEY).get()).isEqualTo("2001:db8:0:1:1:1:1:1");
assertThat(channel.pipeline().get(ProxyProtocolHandler.class)).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_proxyHeaderNotPresent_singleFrame() {
String message = "some message";
// No header present, rest of the message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(message.getBytes(UTF_8)))).isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(message);
assertThat(channel.attr(REMOTE_ADDRESS_KEY).get()).isNull();
assertThat(channel.pipeline().get(ProxyProtocolHandler.class)).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_proxyHeaderNotPresent_multipleFrames() {
String frame1 = "som";
String frame2 = "e mess";
String frame3 = "age\nis not";
String frame4 = "meant to be good.\n";
// Have not had enough bytes to determine the presence of a header, no message passed along.
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame1.getBytes(UTF_8)))).isFalse();
// Now we have more than five bytes to determine if it starts with "PROXY"
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame2.getBytes(UTF_8)))).isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(frame1 + frame2);
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame3.getBytes(UTF_8)))).isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(frame3);
assertThat(channel.writeInbound(Unpooled.wrappedBuffer(frame4.getBytes(UTF_8)))).isTrue();
assertThat(((ByteBuf) channel.readInbound()).toString(UTF_8)).isEqualTo(frame4);
assertThat(channel.attr(REMOTE_ADDRESS_KEY).get()).isNull();
assertThat(channel.pipeline().get(ProxyProtocolHandler.class)).isNull();
assertThat(channel.isActive()).isTrue();
}
}

View file

@ -0,0 +1,125 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.handler.RelayHandler.RELAY_BUFFER_KEY;
import static google.registry.proxy.handler.RelayHandler.RELAY_CHANNEL_KEY;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol;
import google.registry.proxy.Protocol.BackendProtocol;
import google.registry.proxy.Protocol.FrontendProtocol;
import io.netty.channel.embedded.EmbeddedChannel;
import java.util.ArrayDeque;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link RelayHandler}. */
@RunWith(JUnit4.class)
public class RelayHandlerTest {
private static final class ExpectedType {}
private static final class OtherType {}
private final RelayHandler<ExpectedType> relayHandler = new RelayHandler<>(ExpectedType.class);
private final EmbeddedChannel inboundChannel = new EmbeddedChannel(relayHandler);
private final EmbeddedChannel outboundChannel = new EmbeddedChannel();
private final FrontendProtocol frontendProtocol =
Protocol.frontendBuilder()
.port(0)
.name("FRONTEND")
.handlerProviders(ImmutableList.of())
.relayProtocol(
Protocol.backendBuilder()
.host("host.invalid")
.port(0)
.name("BACKEND")
.handlerProviders(ImmutableList.of())
.build())
.build();
private final BackendProtocol backendProtocol = frontendProtocol.relayProtocol();
@Before
public void setUp() {
inboundChannel.attr(RELAY_CHANNEL_KEY).set(outboundChannel);
inboundChannel.attr(RELAY_BUFFER_KEY).set(new ArrayDeque<>());
inboundChannel.attr(PROTOCOL_KEY).set(frontendProtocol);
outboundChannel.attr(PROTOCOL_KEY).set(backendProtocol);
}
@Test
public void testSuccess_relayInboundMessageOfExpectedType() {
ExpectedType inboundMessage = new ExpectedType();
// Relay handler intercepted the message, no further inbound message.
assertThat(inboundChannel.writeInbound(inboundMessage)).isFalse();
// Message wrote to outbound channel as-is.
ExpectedType relayedMessage = outboundChannel.readOutbound();
assertThat(relayedMessage).isEqualTo(inboundMessage);
}
@Test
public void testSuccess_ignoreInboundMessageOfOtherType() {
OtherType inboundMessage = new OtherType();
// Relay handler ignores inbound message of other types, the inbound message is passed along.
assertThat(inboundChannel.writeInbound(inboundMessage)).isTrue();
// Nothing is written into the outbound channel.
ExpectedType relayedMessage = outboundChannel.readOutbound();
assertThat(relayedMessage).isNull();
}
@Test
public void testSuccess_frontClosed() {
inboundChannel.attr(RELAY_BUFFER_KEY).set(null);
inboundChannel.attr(PROTOCOL_KEY).set(backendProtocol);
outboundChannel.attr(PROTOCOL_KEY).set(frontendProtocol);
ExpectedType inboundMessage = new ExpectedType();
// Outbound channel (frontend) is closed.
outboundChannel.finish();
assertThat(inboundChannel.writeInbound(inboundMessage)).isFalse();
ExpectedType relayedMessage = outboundChannel.readOutbound();
assertThat(relayedMessage).isNull();
// Inbound channel (backend) should stay open.
assertThat(inboundChannel.isActive()).isTrue();
assertThat(inboundChannel.attr(RELAY_BUFFER_KEY).get()).isNull();
}
@Test
public void testSuccess_backendClosed_enqueueBuffer() {
ExpectedType inboundMessage = new ExpectedType();
// Outbound channel (backend) is closed.
outboundChannel.finish();
assertThat(inboundChannel.writeInbound(inboundMessage)).isFalse();
ExpectedType relayedMessage = outboundChannel.readOutbound();
assertThat(relayedMessage).isNull();
// Inbound channel (frontend) should stay open.
assertThat(inboundChannel.isActive()).isTrue();
assertThat(inboundChannel.attr(RELAY_BUFFER_KEY).get()).containsExactly(inboundMessage);
}
@Test
public void testSuccess_channelRead_relayNotSet() {
ExpectedType inboundMessage = new ExpectedType();
inboundChannel.attr(RELAY_CHANNEL_KEY).set(null);
// Nothing to read.
assertThat(inboundChannel.writeInbound(inboundMessage)).isFalse();
// Inbound channel is closed.
assertThat(inboundChannel.isActive()).isFalse();
}
}

View file

@ -0,0 +1,206 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.handler.SslInitializerTestUtils.getKeyPair;
import static google.registry.proxy.handler.SslInitializerTestUtils.setUpSslChannel;
import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol;
import google.registry.proxy.Protocol.BackendProtocol;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.cert.CertPathBuilderException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLException;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
/**
* Unit tests for {@link SslClientInitializer}.
*
* <p>To validate that the handler accepts & rejects connections as expected, a test server and a
* test client are spun up, and both connect to the {@link LocalAddress} within the JVM. This avoids
* the overhead of routing traffic through the network layer, even if it were to go through
* loopback. It also alleviates the need to pick a free port to use.
*
* <p>The local addresses used in each test method must to be different, otherwise tests run in
* parallel may interfere with each other.
*/
@RunWith(Parameterized.class)
public class SslClientInitializerTest {
/** Fake host to test if the SSL engine gets the correct peer host. */
private static final String SSL_HOST = "www.example.tld";
/** Fake port to test if the SSL engine gets the correct peer port. */
private static final int SSL_PORT = 12345;
@Rule
public NettyRule nettyRule = new NettyRule();
@Parameter(0)
public SslProvider sslProvider;
// We do our best effort to test all available SSL providers.
@Parameters(name = "{0}")
public static SslProvider[] data() {
return OpenSsl.isAvailable()
? new SslProvider[] {SslProvider.JDK, SslProvider.OPENSSL}
: new SslProvider[] {SslProvider.JDK};
}
/** Saves the SNI hostname received by the server, if sent by the client. */
private String sniHostReceived;
/** Fake protocol saved in channel attribute. */
private static final BackendProtocol PROTOCOL =
Protocol.backendBuilder()
.name("ssl")
.host(SSL_HOST)
.port(SSL_PORT)
.handlerProviders(ImmutableList.of())
.build();
private ChannelHandler getServerHandler(PrivateKey privateKey, X509Certificate certificate)
throws Exception {
SslContext sslContext = SslContextBuilder.forServer(privateKey, certificate).build();
return new SniHandler(
hostname -> {
sniHostReceived = hostname;
return sslContext;
});
}
@Test
public void testSuccess_swappedInitializerWithSslHandler() throws Exception {
SslClientInitializer<EmbeddedChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider);
EmbeddedChannel channel = new EmbeddedChannel();
channel.attr(PROTOCOL_KEY).set(PROTOCOL);
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslClientInitializer);
ChannelHandler firstHandler = pipeline.first();
assertThat(firstHandler.getClass()).isEqualTo(SslHandler.class);
SslHandler sslHandler = (SslHandler) firstHandler;
assertThat(sslHandler.engine().getPeerHost()).isEqualTo(SSL_HOST);
assertThat(sslHandler.engine().getPeerPort()).isEqualTo(SSL_PORT);
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_protocolAttributeNotSet() {
SslClientInitializer<EmbeddedChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider);
EmbeddedChannel channel = new EmbeddedChannel();
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslClientInitializer);
// Channel initializer swallows error thrown, and closes the connection.
assertThat(channel.isActive()).isFalse();
}
@Test
public void testFailure_defaultTrustManager_rejectSelfSignedCert() throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress =
new LocalAddress("DEFAULT_TRUST_MANAGER_REJECT_SELF_SIGNED_CERT_" + sslProvider);
nettyRule.setUpServer(localAddress, getServerHandler(ssc.key(), ssc.cert()));
SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider);
nettyRule.setUpClient(localAddress, PROTOCOL, sslClientInitializer);
// The connection is now terminated, both the client side and the server side should get
// exceptions.
nettyRule.assertThatClientRootCause().isInstanceOf(CertPathBuilderException.class);
nettyRule.assertThatServerRootCause().isInstanceOf(SSLException.class);
assertThat(nettyRule.getChannel().isActive()).isFalse();
}
@Test
public void testSuccess_customTrustManager_acceptCertSignedByTrustedCa() throws Exception {
LocalAddress localAddress =
new LocalAddress("CUSTOM_TRUST_MANAGER_ACCEPT_CERT_SIGNED_BY_TRUSTED_CA_" + sslProvider);
// Generate a new key pair.
KeyPair keyPair = getKeyPair();
// Generate a self signed certificate, and use it to sign the key pair.
SelfSignedCertificate ssc = new SelfSignedCertificate();
X509Certificate cert = signKeyPair(ssc, keyPair, SSL_HOST);
// Set up the server to use the signed cert and private key to perform handshake;
PrivateKey privateKey = keyPair.getPrivate();
nettyRule.setUpServer(localAddress, getServerHandler(privateKey, cert));
// Set up the client to trust the self signed cert used to sign the cert that server provides.
SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider, new X509Certificate[] {ssc.cert()});
nettyRule.setUpClient(localAddress, PROTOCOL, sslClientInitializer);
setUpSslChannel(nettyRule.getChannel(), cert);
nettyRule.assertThatMessagesWork();
// Verify that the SNI extension is sent during handshake.
assertThat(sniHostReceived).isEqualTo(SSL_HOST);
}
@Test
public void testFailure_customTrustManager_wrongHostnameInCertificate() throws Exception {
LocalAddress localAddress =
new LocalAddress("CUSTOM_TRUST_MANAGER_WRONG_HOSTNAME_" + sslProvider);
// Generate a new key pair.
KeyPair keyPair = getKeyPair();
// Generate a self signed certificate, and use it to sign the key pair.
SelfSignedCertificate ssc = new SelfSignedCertificate();
X509Certificate cert = signKeyPair(ssc, keyPair, "wrong.com");
// Set up the server to use the signed cert and private key to perform handshake;
PrivateKey privateKey = keyPair.getPrivate();
nettyRule.setUpServer(localAddress, getServerHandler(privateKey, cert));
// Set up the client to trust the self signed cert used to sign the cert that server provides.
SslClientInitializer<LocalChannel> sslClientInitializer =
new SslClientInitializer<>(sslProvider, new X509Certificate[] {ssc.cert()});
nettyRule.setUpClient(localAddress, PROTOCOL, sslClientInitializer);
// When the client rejects the server cert due to wrong hostname, both the client and server
// should throw exceptions.
nettyRule.assertThatClientRootCause().isInstanceOf(CertificateException.class);
nettyRule.assertThatClientRootCause().hasMessageThat().contains(SSL_HOST);
nettyRule.assertThatServerRootCause().isInstanceOf(SSLException.class);
assertThat(nettyRule.getChannel().isActive()).isFalse();
}
}

View file

@ -0,0 +1,95 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import io.netty.channel.Channel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.math.BigInteger;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.SecureRandom;
import java.security.Security;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.Date;
import javax.net.ssl.SSLSession;
import javax.security.auth.x500.X500Principal;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.x509.X509V3CertificateGenerator;
/**
* Utility class that provides methods used by {@link SslClientInitializerTest} and {@link
* SslServerInitializerTest}.
*/
public class SslInitializerTestUtils {
static {
Security.addProvider(new BouncyCastleProvider());
}
public static KeyPair getKeyPair() throws Exception {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA", "BC");
keyPairGenerator.initialize(2048, new SecureRandom());
return keyPairGenerator.generateKeyPair();
}
/**
* Signs the given key pair with the given self signed certificate.
*
* @return signed public key (of the key pair) certificate
*/
public static X509Certificate signKeyPair(
SelfSignedCertificate ssc, KeyPair keyPair, String hostname) throws Exception {
X509V3CertificateGenerator certGen = new X509V3CertificateGenerator();
X500Principal dnName = new X500Principal("CN=" + hostname);
certGen.setSerialNumber(BigInteger.valueOf(System.currentTimeMillis()));
certGen.setSubjectDN(dnName);
certGen.setIssuerDN(ssc.cert().getSubjectX500Principal());
certGen.setNotBefore(Date.from(Instant.now().minus(Duration.ofDays(1))));
certGen.setNotAfter(Date.from(Instant.now().plus(Duration.ofDays(1))));
certGen.setPublicKey(keyPair.getPublic());
certGen.setSignatureAlgorithm("SHA256WithRSAEncryption");
return certGen.generate(ssc.key(), "BC");
}
/**
* Verifies tha the SSL channel is established as expected, and also sends a message to the server
* and verifies if it is echoed back correctly.
*
* @param certs The certificate that the server should provide.
* @return The SSL session in current channel, can be used for further validation.
*/
static SSLSession setUpSslChannel(
Channel channel,
X509Certificate... certs)
throws Exception {
SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
// Wait till the handshake is complete.
sslHandler.handshakeFuture().get();
assertThat(channel.isActive()).isTrue();
assertThat(sslHandler.handshakeFuture().isSuccess()).isTrue();
assertThat(sslHandler.engine().getSession().isValid()).isTrue();
assertThat(sslHandler.engine().getSession().getPeerCertificates())
.asList()
.containsExactlyElementsIn(certs);
// Returns the SSL session for further assertion.
return sslHandler.engine().getSession();
}
}

View file

@ -0,0 +1,270 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.handler.SslInitializerTestUtils.getKeyPair;
import static google.registry.proxy.handler.SslInitializerTestUtils.setUpSslChannel;
import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol;
import google.registry.proxy.Protocol.BackendProtocol;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
/**
* Unit tests for {@link SslServerInitializer}.
*
* <p>To validate that the handler accepts & rejects connections as expected, a test server and a
* test client are spun up, and both connect to the {@link LocalAddress} within the JVM. This avoids
* the overhead of routing traffic through the network layer, even if it were to go through
* loopback. It also alleviates the need to pick a free port to use.
*
* <p>The local addresses used in each test method must to be different, otherwise tests run in
* parallel may interfere with each other.
*/
@RunWith(Parameterized.class)
public class SslServerInitializerTest {
/** Fake host to test if the SSL engine gets the correct peer host. */
private static final String SSL_HOST = "www.example.tld";
/** Fake port to test if the SSL engine gets the correct peer port. */
private static final int SSL_PORT = 12345;
/** Fake protocol saved in channel attribute. */
private static final BackendProtocol PROTOCOL =
Protocol.backendBuilder()
.name("ssl")
.host(SSL_HOST)
.port(SSL_PORT)
.handlerProviders(ImmutableList.of())
.build();
@Rule
public NettyRule nettyRule = new NettyRule();
@Parameter(0)
public SslProvider sslProvider;
// We do our best effort to test all available SSL providers.
@Parameters(name = "{0}")
public static SslProvider[] data() {
return OpenSsl.isAvailable()
? new SslProvider[] {SslProvider.OPENSSL, SslProvider.JDK}
: new SslProvider[] {SslProvider.JDK};
}
private ChannelHandler getServerHandler(
boolean requireClientCert, PrivateKey privateKey, X509Certificate... certificates) {
return new SslServerInitializer<LocalChannel>(
requireClientCert,
sslProvider,
Suppliers.ofInstance(privateKey),
Suppliers.ofInstance(certificates));
}
private ChannelHandler getServerHandler(PrivateKey privateKey, X509Certificate... certificates) {
return getServerHandler(true, privateKey, certificates);
}
private ChannelHandler getClientHandler(
X509Certificate trustedCertificate,
PrivateKey privateKey,
X509Certificate certificate) {
return new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(LocalChannel ch) throws Exception {
SslContextBuilder sslContextBuilder =
SslContextBuilder.forClient().trustManager(trustedCertificate).sslProvider(sslProvider);
if (privateKey != null && certificate != null) {
sslContextBuilder.keyManager(privateKey, certificate);
}
SslHandler sslHandler = sslContextBuilder.build().newHandler(ch.alloc(), SSL_HOST, SSL_PORT);
// Enable hostname verification.
SSLEngine sslEngine = sslHandler.engine();
SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
sslEngine.setSSLParameters(sslParameters);
ch.pipeline().addLast(sslHandler);
}
};
}
@Test
public void testSuccess_swappedInitializerWithSslHandler() throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate(SSL_HOST);
SslServerInitializer<EmbeddedChannel> sslServerInitializer =
new SslServerInitializer<>(
true,
sslProvider,
Suppliers.ofInstance(ssc.key()),
Suppliers.ofInstance(new X509Certificate[] {ssc.cert()}));
EmbeddedChannel channel = new EmbeddedChannel();
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(sslServerInitializer);
ChannelHandler firstHandler = pipeline.first();
assertThat(firstHandler.getClass()).isEqualTo(SslHandler.class);
SslHandler sslHandler = (SslHandler) firstHandler;
assertThat(sslHandler.engine().getNeedClientAuth()).isTrue();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_trustAnyClientCert() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress = new LocalAddress("TRUST_ANY_CLIENT_CERT_" + sslProvider);
nettyRule.setUpServer(localAddress, getServerHandler(serverSsc.key(), serverSsc.cert()));
SelfSignedCertificate clientSsc = new SelfSignedCertificate();
nettyRule.setUpClient(
localAddress,
PROTOCOL,
getClientHandler(serverSsc.cert(), clientSsc.key(), clientSsc.cert()));
SSLSession sslSession = setUpSslChannel(nettyRule.getChannel(), serverSsc.cert());
nettyRule.assertThatMessagesWork();
// Verify that the SSL session gets the client cert. Note that this SslSession is for the client
// channel, therefore its local certificates are the remote certificates of the SslSession for
// the server channel, and vice versa.
assertThat(sslSession.getLocalCertificates()).asList().containsExactly(clientSsc.cert());
assertThat(sslSession.getPeerCertificates()).asList().containsExactly(serverSsc.cert());
}
@Test
public void testSuccess_doesNotRequireClientCert() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress = new LocalAddress("DOES_NOT_REQUIRE_CLIENT_CERT_" + sslProvider);
nettyRule.setUpServer(
localAddress,
getServerHandler(false, serverSsc.key(), serverSsc.cert()));
nettyRule.setUpClient(
localAddress, PROTOCOL, getClientHandler(serverSsc.cert(), null, null));
SSLSession sslSession = setUpSslChannel(nettyRule.getChannel(), serverSsc.cert());
nettyRule.assertThatMessagesWork();
// Verify that the SSL session does not contain any client cert. Note that this SslSession is
// for the client channel, therefore its local certificates are the remote certificates of the
// SslSession for the server channel, and vice versa.
assertThat(sslSession.getLocalCertificates()).isNull();
assertThat(sslSession.getPeerCertificates()).asList().containsExactly(serverSsc.cert());
}
@Test
public void testSuccess_CertSignedByOtherCA() throws Exception {
// The self-signed cert of the CA.
SelfSignedCertificate caSsc = new SelfSignedCertificate();
KeyPair keyPair = getKeyPair();
X509Certificate serverCert = signKeyPair(caSsc, keyPair, SSL_HOST);
LocalAddress localAddress = new LocalAddress("CERT_SIGNED_BY_OTHER_CA_" + sslProvider);
nettyRule.setUpServer(
localAddress,
getServerHandler(
keyPair.getPrivate(),
// Serving both the server cert, and the CA cert
serverCert,
caSsc.cert()));
SelfSignedCertificate clientSsc = new SelfSignedCertificate();
nettyRule.setUpClient(
localAddress,
PROTOCOL,
getClientHandler(
// Client trusts the CA cert
caSsc.cert(), clientSsc.key(), clientSsc.cert()));
SSLSession sslSession = setUpSslChannel(nettyRule.getChannel(), serverCert, caSsc.cert());
nettyRule.assertThatMessagesWork();
assertThat(sslSession.getLocalCertificates()).asList().containsExactly(clientSsc.cert());
assertThat(sslSession.getPeerCertificates())
.asList()
.containsExactly(serverCert, caSsc.cert())
.inOrder();
}
@Test
public void testFailure_requireClientCertificate() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate(SSL_HOST);
LocalAddress localAddress = new LocalAddress("REQUIRE_CLIENT_CERT_" + sslProvider);
nettyRule.setUpServer(localAddress, getServerHandler(serverSsc.key(), serverSsc.cert()));
nettyRule.setUpClient(
localAddress,
PROTOCOL,
getClientHandler(
serverSsc.cert(),
// No client cert/private key used.
null,
null));
// When the server rejects the client during handshake due to lack of client certificate, both
// should throw exceptions.
nettyRule.assertThatServerRootCause().isInstanceOf(SSLHandshakeException.class);
nettyRule.assertThatClientRootCause().isInstanceOf(SSLException.class);
assertThat(nettyRule.getChannel().isActive()).isFalse();
}
@Test
public void testFailure_wrongHostnameInCertificate() throws Exception {
SelfSignedCertificate serverSsc = new SelfSignedCertificate("wrong.com");
LocalAddress localAddress = new LocalAddress("WRONG_HOSTNAME_" + sslProvider);
nettyRule.setUpServer(localAddress, getServerHandler(serverSsc.key(), serverSsc.cert()));
SelfSignedCertificate clientSsc = new SelfSignedCertificate();
nettyRule.setUpClient(
localAddress,
PROTOCOL,
getClientHandler(serverSsc.cert(), clientSsc.key(), clientSsc.cert()));
// When the client rejects the server cert due to wrong hostname, both the server and the client
// throw exceptions.
nettyRule.assertThatClientRootCause().isInstanceOf(CertificateException.class);
nettyRule.assertThatClientRootCause().hasMessageThat().contains(SSL_HOST);
nettyRule.assertThatServerRootCause().isInstanceOf(SSLException.class);
assertThat(nettyRule.getChannel().isActive()).isFalse();
}
}

View file

@ -0,0 +1,233 @@
// Copyright 2018 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.TestUtils.assertHttpResponseEquivalent;
import static google.registry.proxy.TestUtils.makeHttpGetRequest;
import static google.registry.proxy.TestUtils.makeHttpPostRequest;
import static google.registry.proxy.TestUtils.makeHttpResponse;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link WebWhoisRedirectHandler}. */
@RunWith(JUnit4.class)
public class WebWhoisRedirectHandlerTest {
private static final String REDIRECT_HOST = "www.example.com";
private static final String TARGET_HOST = "whois.nic.tld";
private EmbeddedChannel channel;
private FullHttpRequest request;
private FullHttpResponse response;
private void setupChannel(boolean isHttps) {
channel = new EmbeddedChannel(new WebWhoisRedirectHandler(isHttps, REDIRECT_HOST));
}
private static FullHttpResponse makeRedirectResponse(
HttpResponseStatus status, String location, boolean keepAlive, boolean hsts) {
FullHttpResponse response = makeHttpResponse("", status);
response.headers().set("content-type", "text/plain").set("content-length", "0");
if (location != null) {
response.headers().set("location", location);
}
if (keepAlive) {
response.headers().set("connection", "keep-alive");
}
if (hsts) {
response.headers().set("Strict-Transport-Security", "max-age=31536000");
}
return response;
}
// HTTP redirect tests.
@Test
public void testSuccess_http_methodNotAllowed() {
setupChannel(false);
request = makeHttpPostRequest("", TARGET_HOST, "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.METHOD_NOT_ALLOWED, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_http_badHost() {
setupChannel(false);
request = makeHttpGetRequest("", "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.BAD_REQUEST, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_http_noHost() {
setupChannel(false);
request = makeHttpGetRequest("", "/");
request.headers().remove("host");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.BAD_REQUEST, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_http_healthCheck() {
setupChannel(false);
request = makeHttpPostRequest("", TARGET_HOST, "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.METHOD_NOT_ALLOWED, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_http_redirectToHttps() {
setupChannel(false);
request = makeHttpGetRequest(TARGET_HOST, "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response,
makeRedirectResponse(
HttpResponseStatus.MOVED_PERMANENTLY, "https://whois.nic.tld/", true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_http_redirectToHttps_hostAndPort() {
setupChannel(false);
request = makeHttpGetRequest(TARGET_HOST + ":80", "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response,
makeRedirectResponse(
HttpResponseStatus.MOVED_PERMANENTLY, "https://whois.nic.tld/", true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_http_redirectToHttps_noKeepAlive() {
setupChannel(false);
request = makeHttpGetRequest(TARGET_HOST, "/");
request.headers().set("connection", "close");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response,
makeRedirectResponse(
HttpResponseStatus.MOVED_PERMANENTLY, "https://whois.nic.tld/", false, false));
assertThat(channel.isActive()).isFalse();
}
// HTTPS redirect tests.
@Test
public void testSuccess_https_methodNotAllowed() {
setupChannel(true);
request = makeHttpPostRequest("", TARGET_HOST, "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.METHOD_NOT_ALLOWED, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_https_badHost() {
setupChannel(true);
request = makeHttpGetRequest("", "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.BAD_REQUEST, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_https_noHost() {
setupChannel(true);
request = makeHttpGetRequest("", "/");
request.headers().remove("host");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.BAD_REQUEST, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_https_healthCheck() {
setupChannel(true);
request = makeHttpGetRequest("health-check.invalid", "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response, makeRedirectResponse(HttpResponseStatus.FORBIDDEN, null, true, false));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_https_redirectToDestination() {
setupChannel(true);
request = makeHttpGetRequest(TARGET_HOST, "/");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response,
makeRedirectResponse(HttpResponseStatus.FOUND, "https://www.example.com/", true, true));
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_https_redirectToDestination_noKeepAlive() {
setupChannel(true);
request = makeHttpGetRequest(TARGET_HOST, "/");
request.headers().set("connection", "close");
// No inbound message passed to the next handler.
assertThat(channel.writeInbound(request)).isFalse();
response = channel.readOutbound();
assertHttpResponseEquivalent(
response,
makeRedirectResponse(HttpResponseStatus.FOUND, "https://www.example.com/", false, true));
assertThat(channel.isActive()).isFalse();
}
}

View file

@ -0,0 +1,179 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.Protocol.PROTOCOL_KEY;
import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY;
import static google.registry.testing.JUnitBackports.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import google.registry.proxy.Protocol;
import google.registry.proxy.handler.QuotaHandler.OverQuotaException;
import google.registry.proxy.handler.QuotaHandler.WhoisQuotaHandler;
import google.registry.proxy.metric.FrontendMetrics;
import google.registry.proxy.quota.QuotaManager;
import google.registry.proxy.quota.QuotaManager.QuotaRequest;
import google.registry.proxy.quota.QuotaManager.QuotaResponse;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link WhoisQuotaHandler} */
@RunWith(JUnit4.class)
public class WhoisQuotaHandlerTest {
private final QuotaManager quotaManager = mock(QuotaManager.class);
private final FrontendMetrics metrics = mock(FrontendMetrics.class);
private final WhoisQuotaHandler handler = new WhoisQuotaHandler(quotaManager, metrics);
private final EmbeddedChannel channel = new EmbeddedChannel(handler);
private final DateTime now = DateTime.now(DateTimeZone.UTC);
private final String remoteAddress = "127.0.0.1";
private final Object message = new Object();
private void setProtocol(Channel channel) {
channel
.attr(PROTOCOL_KEY)
.set(
Protocol.frontendBuilder()
.name("whois")
.port(12345)
.handlerProviders(ImmutableList.of())
.relayProtocol(
Protocol.backendBuilder()
.name("backend")
.host("host.tld")
.port(1234)
.handlerProviders(ImmutableList.of())
.build())
.build());
}
@Before
public void setUp() {
channel.attr(REMOTE_ADDRESS_KEY).set(remoteAddress);
setProtocol(channel);
}
@Test
public void testSuccess_quotaGranted() {
when(quotaManager.acquireQuota(QuotaRequest.create(remoteAddress)))
.thenReturn(QuotaResponse.create(true, remoteAddress, now));
// First read, acquire quota.
assertThat(channel.writeInbound(message)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(message);
assertThat(channel.isActive()).isTrue();
verify(quotaManager).acquireQuota(QuotaRequest.create(remoteAddress));
// Second read, should not acquire quota again.
assertThat(channel.writeInbound(message)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(message);
// Channel closed, release quota.
ChannelFuture unusedFuture = channel.close();
verifyNoMoreInteractions(quotaManager);
}
@Test
public void testFailure_quotaNotGranted() {
when(quotaManager.acquireQuota(QuotaRequest.create(remoteAddress)))
.thenReturn(QuotaResponse.create(false, remoteAddress, now));
OverQuotaException e =
assertThrows(OverQuotaException.class, () -> channel.writeInbound(message));
assertThat(e).hasMessageThat().contains("none");
verify(metrics).registerQuotaRejection("whois", "none");
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_twoChannels_twoUserIds() {
// Set up another user.
final WhoisQuotaHandler otherHandler = new WhoisQuotaHandler(quotaManager, metrics);
final EmbeddedChannel otherChannel = new EmbeddedChannel(otherHandler);
final String otherRemoteAddress = "192.168.0.1";
otherChannel.attr(REMOTE_ADDRESS_KEY).set(otherRemoteAddress);
setProtocol(otherChannel);
final DateTime later = now.plus(Duration.standardSeconds(1));
when(quotaManager.acquireQuota(QuotaRequest.create(remoteAddress)))
.thenReturn(QuotaResponse.create(true, remoteAddress, now));
when(quotaManager.acquireQuota(QuotaRequest.create(otherRemoteAddress)))
.thenReturn(QuotaResponse.create(false, otherRemoteAddress, later));
// Allows the first user.
assertThat(channel.writeInbound(message)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(message);
assertThat(channel.isActive()).isTrue();
// Blocks the second user.
OverQuotaException e =
assertThrows(OverQuotaException.class, () -> otherChannel.writeInbound(message));
assertThat(e).hasMessageThat().contains("none");
verify(metrics).registerQuotaRejection("whois", "none");
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_oneUser_rateLimited() {
// Set up another channel for the same user.
final WhoisQuotaHandler otherHandler = new WhoisQuotaHandler(quotaManager, metrics);
final EmbeddedChannel otherChannel = new EmbeddedChannel(otherHandler);
otherChannel.attr(REMOTE_ADDRESS_KEY).set(remoteAddress);
setProtocol(otherChannel);
final DateTime later = now.plus(Duration.standardSeconds(1));
// Set up the third channel for the same user
final WhoisQuotaHandler thirdHandler = new WhoisQuotaHandler(quotaManager, metrics);
final EmbeddedChannel thirdChannel = new EmbeddedChannel(thirdHandler);
thirdChannel.attr(REMOTE_ADDRESS_KEY).set(remoteAddress);
final DateTime evenLater = now.plus(Duration.standardSeconds(60));
when(quotaManager.acquireQuota(QuotaRequest.create(remoteAddress)))
.thenReturn(QuotaResponse.create(true, remoteAddress, now))
// Throttles the second connection.
.thenReturn(QuotaResponse.create(false, remoteAddress, later))
// Allows the third connection because token refilled.
.thenReturn(QuotaResponse.create(true, remoteAddress, evenLater));
// Allows the first channel.
assertThat(channel.writeInbound(message)).isTrue();
assertThat((Object) channel.readInbound()).isEqualTo(message);
assertThat(channel.isActive()).isTrue();
// Blocks the second channel.
OverQuotaException e =
assertThrows(OverQuotaException.class, () -> otherChannel.writeInbound(message));
assertThat(e).hasMessageThat().contains("none");
verify(metrics).registerQuotaRejection("whois", "none");
// Allows the third channel.
assertThat(thirdChannel.writeInbound(message)).isTrue();
assertThat((Object) thirdChannel.readInbound()).isEqualTo(message);
assertThat(thirdChannel.isActive()).isTrue();
verifyNoMoreInteractions(metrics);
}
}

View file

@ -0,0 +1,129 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.handler;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.TestUtils.makeWhoisHttpRequest;
import static google.registry.proxy.TestUtils.makeWhoisHttpResponse;
import static google.registry.testing.JUnitBackports.assertThrows;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.base.Throwables;
import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException;
import google.registry.proxy.metric.FrontendMetrics;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.DefaultChannelId;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link WhoisServiceHandler}. */
@RunWith(JUnit4.class)
public class WhoisServiceHandlerTest {
private static final String RELAY_HOST = "www.example.tld";
private static final String RELAY_PATH = "/test";
private static final String QUERY_CONTENT = "test.tld";
private static final String ACCESS_TOKEN = "this.access.token";
private static final String PROTOCOL = "whois";
private static final String CLIENT_HASH = "none";
private final FrontendMetrics metrics = mock(FrontendMetrics.class);
private final WhoisServiceHandler whoisServiceHandler =
new WhoisServiceHandler(RELAY_HOST, RELAY_PATH, () -> ACCESS_TOKEN, metrics);
private EmbeddedChannel channel;
@Before
public void setUp() {
// Need to reset metrics for each test method, since they are static fields on the class and
// shared between each run.
channel = new EmbeddedChannel(whoisServiceHandler);
}
@Test
public void testSuccess_connectionMetrics_oneChannel() {
assertThat(channel.isActive()).isTrue();
verify(metrics).registerActiveConnection(PROTOCOL, CLIENT_HASH, channel);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_ConnectionMetrics_twoConnections() {
assertThat(channel.isActive()).isTrue();
verify(metrics).registerActiveConnection(PROTOCOL, CLIENT_HASH, channel);
// Setup second channel.
WhoisServiceHandler whoisServiceHandler2 =
new WhoisServiceHandler(RELAY_HOST, RELAY_PATH, () -> ACCESS_TOKEN, metrics);
EmbeddedChannel channel2 =
// We need a new channel id so that it has a different hash code.
// This only is needed for EmbeddedChannel because it has a dummy hash code implementation.
new EmbeddedChannel(DefaultChannelId.newInstance(), whoisServiceHandler2);
assertThat(channel2.isActive()).isTrue();
verify(metrics).registerActiveConnection(PROTOCOL, CLIENT_HASH, channel2);
verifyNoMoreInteractions(metrics);
}
@Test
public void testSuccess_fireInboundHttpRequest() {
ByteBuf inputBuffer = Unpooled.wrappedBuffer(QUERY_CONTENT.getBytes(US_ASCII));
FullHttpRequest expectedRequest =
makeWhoisHttpRequest(QUERY_CONTENT, RELAY_HOST, RELAY_PATH, ACCESS_TOKEN);
// Input data passed to next handler
assertThat(channel.writeInbound(inputBuffer)).isTrue();
FullHttpRequest inputRequest = channel.readInbound();
assertThat(inputRequest).isEqualTo(expectedRequest);
// The channel is still open, and nothing else is to be read from it.
assertThat((Object) channel.readInbound()).isNull();
assertThat(channel.isActive()).isTrue();
}
@Test
public void testSuccess_parseOutboundHttpResponse() {
String outputString = "line1\r\nline2\r\n";
FullHttpResponse outputResponse = makeWhoisHttpResponse(outputString, HttpResponseStatus.OK);
// output data passed to next handler
assertThat(channel.writeOutbound(outputResponse)).isTrue();
ByteBuf parsedBuffer = channel.readOutbound();
assertThat(parsedBuffer.toString(US_ASCII)).isEqualTo(outputString);
// The channel is still open, and nothing else is to be written to it.
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isFalse();
}
@Test
public void testFailure_OutboundHttpResponseNotOK() {
String outputString = "line1\r\nline2\r\n";
FullHttpResponse outputResponse =
makeWhoisHttpResponse(outputString, HttpResponseStatus.BAD_REQUEST);
EncoderException thrown =
assertThrows(EncoderException.class, () -> channel.writeOutbound(outputResponse));
assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class);
assertThat(thrown).hasMessageThat().contains("400 Bad Request");
assertThat((Object) channel.readOutbound()).isNull();
assertThat(channel.isActive()).isFalse();
}
}

View file

@ -0,0 +1,173 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.metric;
import static com.google.monitoring.metrics.contrib.DistributionMetricSubject.assertThat;
import static com.google.monitoring.metrics.contrib.LongMetricSubject.assertThat;
import static google.registry.proxy.TestUtils.makeHttpPostRequest;
import static google.registry.proxy.TestUtils.makeHttpResponse;
import com.google.common.collect.ImmutableSet;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link BackendMetrics}. */
@RunWith(JUnit4.class)
public class BackendMetricsTest {
private final String host = "host.tld";
private final String certHash = "blah12345";
private final String protocol = "frontend protocol";
private final BackendMetrics metrics = new BackendMetrics();
@Before
public void setUp() {
metrics.resetMetric();
}
@Test
public void testSuccess_oneRequest() {
String content = "some content";
FullHttpRequest request = makeHttpPostRequest(content, host, "/");
metrics.requestSent(protocol, certHash, request.content().readableBytes());
assertThat(BackendMetrics.requestsCounter)
.hasValueForLabels(1, protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.requestBytes)
.hasDataSetForLabels(ImmutableSet.of(content.length()), protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.responsesCounter).hasNoOtherValues();
assertThat(BackendMetrics.responseBytes).hasNoOtherValues();
assertThat(BackendMetrics.latencyMs).hasNoOtherValues();
}
@Test
public void testSuccess_multipleRequests() {
String content1 = "some content";
String content2 = "some other content";
FullHttpRequest request1 = makeHttpPostRequest(content1, host, "/");
FullHttpRequest request2 = makeHttpPostRequest(content2, host, "/");
metrics.requestSent(protocol, certHash, request1.content().readableBytes());
metrics.requestSent(protocol, certHash, request2.content().readableBytes());
assertThat(BackendMetrics.requestsCounter)
.hasValueForLabels(2, protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.requestBytes)
.hasDataSetForLabels(
ImmutableSet.of(content1.length(), content2.length()), protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.responsesCounter).hasNoOtherValues();
assertThat(BackendMetrics.responseBytes).hasNoOtherValues();
assertThat(BackendMetrics.latencyMs).hasNoOtherValues();
}
@Test
public void testSuccess_oneResponse() {
String content = "some response";
FullHttpResponse response = makeHttpResponse(content, HttpResponseStatus.OK);
metrics.responseReceived(protocol, certHash, response, 5);
assertThat(BackendMetrics.requestsCounter).hasNoOtherValues();
assertThat(BackendMetrics.requestBytes).hasNoOtherValues();
assertThat(BackendMetrics.responsesCounter)
.hasValueForLabels(1, protocol, certHash, "200 OK")
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.responseBytes)
.hasDataSetForLabels(ImmutableSet.of(content.length()), protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.latencyMs)
.hasDataSetForLabels(ImmutableSet.of(5), protocol, certHash)
.and()
.hasNoOtherValues();
}
@Test
public void testSuccess_multipleResponses() {
String content1 = "some response";
String content2 = "other response";
String content3 = "a very bad response";
FullHttpResponse response1 = makeHttpResponse(content1, HttpResponseStatus.OK);
FullHttpResponse response2 = makeHttpResponse(content2, HttpResponseStatus.OK);
FullHttpResponse response3 = makeHttpResponse(content3, HttpResponseStatus.BAD_REQUEST);
metrics.responseReceived(protocol, certHash, response1, 5);
metrics.responseReceived(protocol, certHash, response2, 8);
metrics.responseReceived(protocol, certHash, response3, 2);
assertThat(BackendMetrics.requestsCounter).hasNoOtherValues();
assertThat(BackendMetrics.requestBytes).hasNoOtherValues();
assertThat(BackendMetrics.responsesCounter)
.hasValueForLabels(2, protocol, certHash, "200 OK")
.and()
.hasValueForLabels(1, protocol, certHash, "400 Bad Request")
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.responseBytes)
.hasDataSetForLabels(
ImmutableSet.of(content1.length(), content2.length(), content3.length()),
protocol,
certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.latencyMs)
.hasDataSetForLabels(ImmutableSet.of(5, 8, 2), protocol, certHash)
.and()
.hasNoOtherValues();
}
@Test
public void testSuccess_oneRequest_oneResponse() {
String requestContent = "some request";
String responseContent = "the only response";
FullHttpRequest request = makeHttpPostRequest(requestContent, host, "/");
FullHttpResponse response = makeHttpResponse(responseContent, HttpResponseStatus.OK);
metrics.requestSent(protocol, certHash, request.content().readableBytes());
metrics.responseReceived(protocol, certHash, response, 10);
assertThat(BackendMetrics.requestsCounter)
.hasValueForLabels(1, protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.responsesCounter)
.hasValueForLabels(1, protocol, certHash, "200 OK")
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.requestBytes)
.hasDataSetForLabels(ImmutableSet.of(requestContent.length()), protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.responseBytes)
.hasDataSetForLabels(ImmutableSet.of(responseContent.length()), protocol, certHash)
.and()
.hasNoOtherValues();
assertThat(BackendMetrics.latencyMs)
.hasDataSetForLabels(ImmutableSet.of(10), protocol, certHash)
.and()
.hasNoOtherValues();
}
}

View file

@ -0,0 +1,187 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.metric;
import static com.google.common.truth.Truth.assertThat;
import static com.google.monitoring.metrics.contrib.LongMetricSubject.assertThat;
import io.netty.channel.ChannelFuture;
import io.netty.channel.DefaultChannelId;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link FrontendMetrics}. */
@RunWith(JUnit4.class)
public class FrontendMetricsTest {
private static final String PROTOCOL = "some protocol";
private static final String CERT_HASH = "abc_blah_1134zdf";
private final FrontendMetrics metrics = new FrontendMetrics();
@Before
public void setUp() {
metrics.resetMetrics();
}
@Test
public void testSuccess_oneConnection() {
EmbeddedChannel channel = new EmbeddedChannel();
metrics.registerActiveConnection(PROTOCOL, CERT_HASH, channel);
assertThat(channel.isActive()).isTrue();
assertThat(FrontendMetrics.activeConnectionsGauge)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
ChannelFuture unusedFuture = channel.close();
assertThat(channel.isActive()).isFalse();
assertThat(FrontendMetrics.activeConnectionsGauge).hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
}
@Test
public void testSuccess_twoConnections_sameClient() {
EmbeddedChannel channel1 = new EmbeddedChannel();
EmbeddedChannel channel2 = new EmbeddedChannel(DefaultChannelId.newInstance());
metrics.registerActiveConnection(PROTOCOL, CERT_HASH, channel1);
assertThat(channel1.isActive()).isTrue();
assertThat(FrontendMetrics.activeConnectionsGauge)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
metrics.registerActiveConnection(PROTOCOL, CERT_HASH, channel2);
assertThat(channel2.isActive()).isTrue();
assertThat(FrontendMetrics.activeConnectionsGauge)
.hasValueForLabels(2, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(2, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
@SuppressWarnings("unused")
ChannelFuture unusedFuture1 = channel1.close();
assertThat(channel1.isActive()).isFalse();
assertThat(FrontendMetrics.activeConnectionsGauge)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(2, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
@SuppressWarnings("unused")
ChannelFuture unusedFuture2 = channel2.close();
assertThat(channel2.isActive()).isFalse();
assertThat(FrontendMetrics.activeConnectionsGauge).hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(2, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
}
@Test
public void testSuccess_twoConnections_differentClients() {
EmbeddedChannel channel1 = new EmbeddedChannel();
EmbeddedChannel channel2 = new EmbeddedChannel(DefaultChannelId.newInstance());
String certHash2 = "blahblah_lol_234";
metrics.registerActiveConnection(PROTOCOL, CERT_HASH, channel1);
assertThat(channel1.isActive()).isTrue();
assertThat(FrontendMetrics.activeConnectionsGauge)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasNoOtherValues();
metrics.registerActiveConnection(PROTOCOL, certHash2, channel2);
assertThat(channel2.isActive()).isTrue();
assertThat(FrontendMetrics.activeConnectionsGauge)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasValueForLabels(1, PROTOCOL, certHash2)
.and()
.hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasValueForLabels(1, PROTOCOL, certHash2)
.and()
.hasNoOtherValues();
ChannelFuture unusedFuture = channel1.close();
assertThat(channel1.isActive()).isFalse();
assertThat(FrontendMetrics.activeConnectionsGauge)
.hasValueForLabels(1, PROTOCOL, certHash2)
.and()
.hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasValueForLabels(1, PROTOCOL, certHash2)
.and()
.hasNoOtherValues();
unusedFuture = channel2.close();
assertThat(channel2.isActive()).isFalse();
assertThat(FrontendMetrics.activeConnectionsGauge).hasNoOtherValues();
assertThat(FrontendMetrics.totalConnectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasValueForLabels(1, PROTOCOL, certHash2)
.and()
.hasNoOtherValues();
}
@Test
public void testSuccess_registerQuotaRejections() {
String otherCertHash = "foobar1234X";
String remoteAddress = "127.0.0.1";
String otherProtocol = "other protocol";
metrics.registerQuotaRejection(PROTOCOL, CERT_HASH);
metrics.registerQuotaRejection(PROTOCOL, otherCertHash);
metrics.registerQuotaRejection(PROTOCOL, otherCertHash);
metrics.registerQuotaRejection(otherProtocol, remoteAddress);
assertThat(FrontendMetrics.quotaRejectionsCounter)
.hasValueForLabels(1, PROTOCOL, CERT_HASH)
.and()
.hasValueForLabels(2, PROTOCOL, otherCertHash)
.and()
.hasValueForLabels(1, otherProtocol, remoteAddress)
.and()
.hasNoOtherValues();
}
}

View file

@ -0,0 +1,136 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.metric;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.metric.MetricParameters.CLUSTER_NAME_PATH;
import static google.registry.proxy.metric.MetricParameters.CONTAINER_NAME_ENV;
import static google.registry.proxy.metric.MetricParameters.INSTANCE_ID_PATH;
import static google.registry.proxy.metric.MetricParameters.NAMESPACE_ID_ENV;
import static google.registry.proxy.metric.MetricParameters.POD_ID_ENV;
import static google.registry.proxy.metric.MetricParameters.PROJECT_ID_PATH;
import static google.registry.proxy.metric.MetricParameters.ZONE_PATH;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableMap;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.util.HashMap;
import java.util.Map.Entry;
import java.util.function.Function;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link MetricParameters}. */
@RunWith(JUnit4.class)
public class MetricParametersTest {
private static final HashMap<String, String> RESULTS = new HashMap<>();
private final HttpURLConnection projectIdConnection = mock(HttpURLConnection.class);
private final HttpURLConnection clusterNameConnection = mock(HttpURLConnection.class);
private final HttpURLConnection instanceIdConnection = mock(HttpURLConnection.class);
private final HttpURLConnection zoneConnection = mock(HttpURLConnection.class);
private final ImmutableMap<String, HttpURLConnection> mockConnections =
ImmutableMap.of(
PROJECT_ID_PATH,
projectIdConnection,
CLUSTER_NAME_PATH,
clusterNameConnection,
INSTANCE_ID_PATH,
instanceIdConnection,
ZONE_PATH,
zoneConnection);
private final HashMap<String, String> fakeEnvVarMap = new HashMap<>();
private final Function<String, HttpURLConnection> fakeConnectionFactory = mockConnections::get;
private final MetricParameters metricParameters =
new MetricParameters(fakeEnvVarMap, fakeConnectionFactory);
private static InputStream makeInputStreamFromString(String input) {
return new ByteArrayInputStream(input.getBytes(UTF_8));
}
@Before
public void setUp() throws Exception {
fakeEnvVarMap.put(NAMESPACE_ID_ENV, "some-namespace");
fakeEnvVarMap.put(POD_ID_ENV, "some-pod");
fakeEnvVarMap.put(CONTAINER_NAME_ENV, "some-container");
when(projectIdConnection.getInputStream())
.thenReturn(makeInputStreamFromString("some-project"));
when(clusterNameConnection.getInputStream())
.thenReturn(makeInputStreamFromString("some-cluster"));
when(instanceIdConnection.getInputStream())
.thenReturn(makeInputStreamFromString("some-instance"));
when(zoneConnection.getInputStream())
.thenReturn(makeInputStreamFromString("projects/some-project/zones/some-zone"));
for (Entry<String, HttpURLConnection> entry : mockConnections.entrySet()) {
when(entry.getValue().getResponseCode()).thenReturn(200);
}
RESULTS.put("project_id", "some-project");
RESULTS.put("cluster_name", "some-cluster");
RESULTS.put("namespace_id", "some-namespace");
RESULTS.put("instance_id", "some-instance");
RESULTS.put("pod_id", "some-pod");
RESULTS.put("container_name", "some-container");
RESULTS.put("zone", "some-zone");
}
@Test
public void testSuccess() {
assertThat(metricParameters.makeLabelsMap()).isEqualTo(ImmutableMap.copyOf(RESULTS));
}
@Test
public void testSuccess_missingEnvVar() {
fakeEnvVarMap.remove(POD_ID_ENV);
RESULTS.put("pod_id", "");
assertThat(metricParameters.makeLabelsMap()).isEqualTo(ImmutableMap.copyOf(RESULTS));
}
@Test
public void testSuccess_malformedZone() throws Exception {
when(zoneConnection.getInputStream()).thenReturn(makeInputStreamFromString("some-zone"));
RESULTS.put("zone", "");
assertThat(metricParameters.makeLabelsMap()).isEqualTo(ImmutableMap.copyOf(RESULTS));
}
@Test
public void testSuccess_errorResponseCode() throws Exception {
when(projectIdConnection.getResponseCode()).thenReturn(404);
when(projectIdConnection.getErrorStream())
.thenReturn(makeInputStreamFromString("some error message"));
RESULTS.put("project_id", "");
assertThat(metricParameters.makeLabelsMap()).isEqualTo(ImmutableMap.copyOf(RESULTS));
}
@Test
public void testSuccess_connectionError() throws Exception {
InputStream fakeInputStream = mock(InputStream.class);
when(projectIdConnection.getInputStream()).thenReturn(fakeInputStream);
when(fakeInputStream.read(any(byte[].class), anyInt(), anyInt()))
.thenThrow(new IOException("some exception"));
RESULTS.put("project_id", "");
assertThat(metricParameters.makeLabelsMap()).isEqualTo(ImmutableMap.copyOf(RESULTS));
}
}

View file

@ -0,0 +1,92 @@
// Copyright 2018 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.quota;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.testing.JUnitBackports.assertThrows;
import static google.registry.util.ResourceUtils.readResourceUtf8;
import google.registry.proxy.ProxyConfig.Quota;
import org.joda.time.Duration;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.yaml.snakeyaml.Yaml;
/** Unit Tests for {@link QuotaConfig} */
@RunWith(JUnit4.class)
public class QuotaConfigTest {
private QuotaConfig quotaConfig;
private static QuotaConfig loadQuotaConfig(String filename) {
return new QuotaConfig(
new Yaml()
.loadAs(readResourceUtf8(QuotaConfigTest.class, "testdata/" + filename), Quota.class),
"theProtocol");
}
private void validateQuota(String userId, int tokenAmount, int refillSeconds) {
assertThat(quotaConfig.hasUnlimitedTokens(userId)).isFalse();
assertThat(quotaConfig.getTokenAmount(userId)).isEqualTo(tokenAmount);
assertThat(quotaConfig.getRefillPeriod(userId))
.isEqualTo(Duration.standardSeconds(refillSeconds));
assertThat(quotaConfig.getProtocolName()).isEqualTo("theProtocol");
}
@Test
public void testSuccess_regularConfig() {
quotaConfig = loadQuotaConfig("quota_config_regular.yaml");
assertThat(quotaConfig.getRefreshPeriod()).isEqualTo(Duration.standardHours(1));
validateQuota("abc", 10, 60);
validateQuota("987lol", 500, 10);
validateQuota("no_match", 100, 60);
}
@Test
public void testSuccess_onlyDefault() {
quotaConfig = loadQuotaConfig("quota_config_default.yaml");
assertThat(quotaConfig.getRefreshPeriod()).isEqualTo(Duration.standardHours(1));
validateQuota("abc", 100, 60);
validateQuota("987lol", 100, 60);
validateQuota("no_match", 100, 60);
}
@Test
public void testSuccess_noRefresh_noRefill() {
quotaConfig = loadQuotaConfig("quota_config_no_refresh_no_refill.yaml");
assertThat(quotaConfig.getRefreshPeriod()).isEqualTo(Duration.ZERO);
assertThat(quotaConfig.getRefillPeriod("no_match")).isEqualTo(Duration.ZERO);
}
@Test
public void testFailure_getTokenAmount_throwsOnUnlimitedTokens() {
quotaConfig = loadQuotaConfig("quota_config_unlimited_tokens.yaml");
assertThat(quotaConfig.hasUnlimitedTokens("some_user")).isTrue();
IllegalStateException e =
assertThrows(IllegalStateException.class, () -> quotaConfig.getTokenAmount("some_user"));
assertThat(e)
.hasMessageThat()
.contains("User ID some_user is provisioned with unlimited tokens");
}
@Test
public void testFailure_duplicateUserId() {
IllegalArgumentException e =
assertThrows(
IllegalArgumentException.class, () -> loadQuotaConfig("quota_config_duplicate.yaml"));
assertThat(e).hasMessageThat().contains("Multiple entries with same key");
}
}

View file

@ -0,0 +1,82 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.quota;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.MoreExecutors;
import google.registry.proxy.quota.QuotaManager.QuotaRebate;
import google.registry.proxy.quota.QuotaManager.QuotaRequest;
import google.registry.proxy.quota.QuotaManager.QuotaResponse;
import google.registry.proxy.quota.TokenStore.TimestampedInteger;
import google.registry.testing.FakeClock;
import java.util.concurrent.Future;
import org.joda.time.DateTime;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link QuotaManager}. */
@RunWith(JUnit4.class)
public class QuotaManagerTest {
private static final String USER_ID = "theUser";
private final TokenStore tokenStore = mock(TokenStore.class);
private final FakeClock clock = new FakeClock();
private QuotaManager quotaManager =
new QuotaManager(tokenStore, MoreExecutors.newDirectExecutorService());
private QuotaRequest request;
private QuotaResponse response;
@Test
public void testSuccess_requestApproved() {
when(tokenStore.take(anyString())).thenReturn(TimestampedInteger.create(1, clock.nowUtc()));
request = QuotaRequest.create(USER_ID);
response = quotaManager.acquireQuota(request);
assertThat(response.success()).isTrue();
assertThat(response.userId()).isEqualTo(USER_ID);
assertThat(response.grantedTokenRefillTime()).isEqualTo(clock.nowUtc());
}
@Test
public void testSuccess_requestDenied() {
when(tokenStore.take(anyString())).thenReturn(TimestampedInteger.create(0, clock.nowUtc()));
request = QuotaRequest.create(USER_ID);
response = quotaManager.acquireQuota(request);
assertThat(response.success()).isFalse();
assertThat(response.userId()).isEqualTo(USER_ID);
assertThat(response.grantedTokenRefillTime()).isEqualTo(clock.nowUtc());
}
@Test
public void testSuccess_rebate() throws Exception {
DateTime grantedTokenRefillTime = clock.nowUtc();
response = QuotaResponse.create(true, USER_ID, grantedTokenRefillTime);
QuotaRebate rebate = QuotaRebate.create(response);
Future<?> unusedFuture = quotaManager.releaseQuota(rebate);
verify(tokenStore).scheduleRefresh();
verify(tokenStore).put(USER_ID, grantedTokenRefillTime);
verifyNoMoreInteractions(tokenStore);
}
}

View file

@ -0,0 +1,315 @@
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.proxy.quota;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.proxy.quota.QuotaConfig.SENTINEL_UNLIMITED_TOKENS;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import google.registry.proxy.quota.TokenStore.TimestampedInteger;
import google.registry.testing.FakeClock;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
/** Unit tests for {@link TokenStore}. */
@RunWith(JUnit4.class)
public class TokenStoreTest {
private final QuotaConfig quotaConfig = mock(QuotaConfig.class);
private final FakeClock clock = new FakeClock();
private final ScheduledExecutorService refreshExecutor = mock(ScheduledExecutorService.class);
private final TokenStore tokenStore = spy(new TokenStore(quotaConfig, refreshExecutor, clock));
private final String user = "theUser";
private final String otherUser = "theOtherUser";
private DateTime assertTake(int grantAmount, int amountLeft, DateTime timestamp) {
return assertTake(user, grantAmount, amountLeft, timestamp);
}
private DateTime assertTake(String user, int grantAmount, int amountLeft, DateTime timestamp) {
TimestampedInteger grantedToken = tokenStore.take(user);
assertThat(grantedToken).isEqualTo(TimestampedInteger.create(grantAmount, timestamp));
assertThat(tokenStore.getTokenForTests(user))
.isEqualTo(TimestampedInteger.create(amountLeft, timestamp));
return grantedToken.timestamp();
}
private void assertPut(
DateTime returnedTokenRefillTime, int amountAfterReturn, DateTime refillTime) {
assertPut(user, returnedTokenRefillTime, amountAfterReturn, refillTime);
}
private void assertPut(
String user, DateTime returnedTokenRefillTime, int amountAfterReturn, DateTime refillTime) {
tokenStore.put(user, returnedTokenRefillTime);
assertThat(tokenStore.getTokenForTests(user))
.isEqualTo(TimestampedInteger.create(amountAfterReturn, refillTime));
}
private void submitAndWaitForTasks(ExecutorService executor, Runnable... tasks) {
List<Future<?>> futures = new ArrayList<>();
for (Runnable task : tasks) {
futures.add(executor.submit(task));
}
futures.forEach(
f -> {
try {
f.get();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
}
@Before
public void setUp() {
when(quotaConfig.getRefreshPeriod()).thenReturn(Duration.standardSeconds(60));
when(quotaConfig.getRefillPeriod(user)).thenReturn(Duration.standardSeconds(10));
when(quotaConfig.getTokenAmount(user)).thenReturn(3);
when(quotaConfig.getRefillPeriod(otherUser)).thenReturn(Duration.standardSeconds(15));
when(quotaConfig.getTokenAmount(otherUser)).thenReturn(5);
}
@Test
public void testSuccess_take() {
// Take 3 tokens one by one.
DateTime refillTime = clock.nowUtc();
assertTake(1, 2, refillTime);
assertTake(1, 1, refillTime);
clock.advanceBy(Duration.standardSeconds(2));
assertTake(1, 0, refillTime);
// Take 1 token, not enough tokens left.
clock.advanceBy(Duration.standardSeconds(3));
assertTake(0, 0, refillTime);
// Refill period passed. Take 1 token - success.
clock.advanceBy(Duration.standardSeconds(6));
refillTime = clock.nowUtc();
assertTake(1, 2, refillTime);
}
@Test
public void testSuccess_put_entryDoesNotExist() {
tokenStore.put(user, clock.nowUtc());
assertThat(tokenStore.getTokenForTests(user)).isNull();
}
@Test
public void testSuccess_put() {
DateTime refillTime = clock.nowUtc();
// Initialize the entry.
DateTime grantedTokenRefillTime = assertTake(1, 2, refillTime);
// Put into full bucket.
assertPut(grantedTokenRefillTime, 3, refillTime);
assertPut(grantedTokenRefillTime, 3, refillTime);
clock.advanceBy(Duration.standardSeconds(3));
// Take 1 token out, put 1 back in.
assertTake(1, 2, refillTime);
assertPut(refillTime, 3, refillTime);
// Do not put old token back.
grantedTokenRefillTime = assertTake(1, 2, refillTime);
clock.advanceBy(Duration.standardSeconds(11));
refillTime = clock.nowUtc();
assertPut(grantedTokenRefillTime, 3, refillTime);
}
@Test
public void testSuccess_takeAndPut() {
DateTime refillTime = clock.nowUtc();
// Take 1 token.
DateTime grantedTokenRefillTime1 = assertTake(1, 2, refillTime);
// Take 1 token.
DateTime grantedTokenRefillTime2 = assertTake(1, 1, refillTime);
// Return first token.
clock.advanceBy(Duration.standardSeconds(2));
assertPut(grantedTokenRefillTime1, 2, refillTime);
// Refill time passed, second returned token discarded.
clock.advanceBy(Duration.standardSeconds(10));
refillTime = clock.nowUtc();
assertPut(grantedTokenRefillTime2, 3, refillTime);
}
@Test
public void testSuccess_multipleUsers() {
DateTime refillTime1 = clock.nowUtc();
DateTime refillTime2 = clock.nowUtc();
// Take 1 from first user.
DateTime grantedTokenRefillTime1 = assertTake(user, 1, 2, refillTime1);
// Take 1 from second user.
DateTime grantedTokenRefillTime2 = assertTake(otherUser, 1, 4, refillTime2);
assertTake(otherUser, 1, 3, refillTime2);
assertTake(otherUser, 1, 2, refillTime2);
// first user tokens refilled.
clock.advanceBy(Duration.standardSeconds(10));
refillTime1 = clock.nowUtc();
DateTime grantedTokenRefillTime3 = assertTake(user, 1, 2, refillTime1);
DateTime grantedTokenRefillTime4 = assertTake(otherUser, 1, 1, refillTime2);
assertPut(user, grantedTokenRefillTime1, 2, refillTime1);
assertPut(otherUser, grantedTokenRefillTime2, 2, refillTime2);
// second user tokens refilled.
clock.advanceBy(Duration.standardSeconds(5));
refillTime2 = clock.nowUtc();
assertPut(user, grantedTokenRefillTime3, 3, refillTime1);
assertPut(otherUser, grantedTokenRefillTime4, 5, refillTime2);
}
@Test
public void testSuccess_refresh() {
DateTime refillTime1 = clock.nowUtc();
assertTake(user, 1, 2, refillTime1);
clock.advanceBy(Duration.standardSeconds(5));
DateTime refillTime2 = clock.nowUtc();
assertTake(otherUser, 1, 4, refillTime2);
clock.advanceBy(Duration.standardSeconds(55));
// Entry for user is 60s old, entry for otherUser is 55s old.
tokenStore.refresh();
assertThat(tokenStore.getTokenForTests(user)).isNull();
assertThat(tokenStore.getTokenForTests(otherUser))
.isEqualTo(TimestampedInteger.create(4, refillTime2));
}
@Test
public void testSuccess_unlimitedQuota() {
when(quotaConfig.hasUnlimitedTokens(user)).thenReturn(true);
for (int i = 0; i < 10000; ++i) {
assertTake(1, SENTINEL_UNLIMITED_TOKENS, clock.nowUtc());
}
for (int i = 0; i < 10000; ++i) {
assertPut(clock.nowUtc(), SENTINEL_UNLIMITED_TOKENS, clock.nowUtc());
}
}
@Test
public void testSuccess_noRefill() {
when(quotaConfig.getRefillPeriod(user)).thenReturn(Duration.ZERO);
DateTime refillTime = clock.nowUtc();
assertTake(1, 2, refillTime);
assertTake(1, 1, refillTime);
assertTake(1, 0, refillTime);
clock.advanceBy(Duration.standardDays(365));
assertTake(0, 0, refillTime);
}
@Test
public void testSuccess_noRefresh() {
when(quotaConfig.getRefreshPeriod()).thenReturn(Duration.ZERO);
DateTime refillTime = clock.nowUtc();
assertTake(1, 2, refillTime);
clock.advanceBy(Duration.standardDays(365));
assertThat(tokenStore.getTokenForTests(user))
.isEqualTo(TimestampedInteger.create(2, refillTime));
}
@Test
public void testSuccess_concurrency() throws Exception {
ExecutorService executor = Executors.newWorkStealingPool();
final DateTime time1 = clock.nowUtc();
submitAndWaitForTasks(
executor,
() -> tokenStore.take(user),
() -> tokenStore.take(otherUser),
() -> tokenStore.take(user),
() -> tokenStore.take(otherUser));
assertThat(tokenStore.getTokenForTests(user)).isEqualTo(TimestampedInteger.create(1, time1));
assertThat(tokenStore.getTokenForTests(otherUser))
.isEqualTo(TimestampedInteger.create(3, time1));
// No refill.
clock.advanceBy(Duration.standardSeconds(5));
submitAndWaitForTasks(
executor, () -> tokenStore.take(user), () -> tokenStore.put(otherUser, time1));
assertThat(tokenStore.getTokenForTests(user)).isEqualTo(TimestampedInteger.create(0, time1));
assertThat(tokenStore.getTokenForTests(otherUser))
.isEqualTo(TimestampedInteger.create(4, time1));
// First user refill.
clock.advanceBy(Duration.standardSeconds(5));
final DateTime time2 = clock.nowUtc();
submitAndWaitForTasks(
executor,
() -> {
tokenStore.put(user, time1);
tokenStore.take(user);
},
() -> tokenStore.take(otherUser));
assertThat(tokenStore.getTokenForTests(user)).isEqualTo(TimestampedInteger.create(2, time2));
assertThat(tokenStore.getTokenForTests(otherUser))
.isEqualTo(TimestampedInteger.create(3, time1));
// Second user refill.
clock.advanceBy(Duration.standardSeconds(5));
final DateTime time3 = clock.nowUtc();
submitAndWaitForTasks(
executor,
() -> tokenStore.take(user),
() -> {
tokenStore.put(otherUser, time1);
tokenStore.take(otherUser);
});
assertThat(tokenStore.getTokenForTests(user)).isEqualTo(TimestampedInteger.create(1, time2));
assertThat(tokenStore.getTokenForTests(otherUser))
.isEqualTo(TimestampedInteger.create(4, time3));
}
@Test
public void testSuccess_scheduleRefresh() throws Exception {
when(quotaConfig.getRefreshPeriod()).thenReturn(Duration.standardSeconds(5));
tokenStore.scheduleRefresh();
// Verify that a task is scheduled.
ArgumentCaptor<Runnable> argument = ArgumentCaptor.forClass(Runnable.class);
verify(refreshExecutor)
.scheduleWithFixedDelay(
argument.capture(), eq((long) 5), eq((long) 5), eq(TimeUnit.SECONDS));
// Verify that the scheduled task calls TokenStore.refresh().
argument.getValue().run();
verify(tokenStore).refresh();
}
}

View file

@ -0,0 +1,8 @@
refreshSeconds: 3600
defaultQuota:
userId: []
tokenAmount: 100
refillSeconds: 60
customQuota: []

View file

@ -0,0 +1,14 @@
refreshSeconds: 3600
defaultQuota:
userId: []
tokenAmount: 100
refillSeconds: 60
customQuota:
- userId: ["abc", "def"]
tokenAmount: 10
refillSeconds: 60
- userId: ["xyz123", "def", "luckycat"]
tokenAmount: 500
refillSeconds: 10

View file

@ -0,0 +1,8 @@
refreshSeconds: 0
defaultQuota:
userId: []
tokenAmount: 100
refillSeconds: 0
customQuota: []

View file

@ -0,0 +1,14 @@
refreshSeconds: 3600
defaultQuota:
userId: []
tokenAmount: 100
refillSeconds: 60
customQuota:
- userId: ["abc", "def"]
tokenAmount: 10
refillSeconds: 60
- userId: ["xyz123", "987lol", "luckycat"]
tokenAmount: 500
refillSeconds: 10

Some files were not shown because too many files have changed in this diff Show more