diff --git a/java/google/registry/proxy/ProxyServer.java b/java/google/registry/proxy/ProxyServer.java index 81cda3920..417255b77 100644 --- a/java/google/registry/proxy/ProxyServer.java +++ b/java/google/registry/proxy/ProxyServer.java @@ -129,6 +129,8 @@ public class ProxyServer implements Runnable { .closeFuture() .addListener( (future) -> { + logger.atInfo().log( + "Connection terminated: %s %s", inboundProtocol.name(), inboundChannel); // Check if there's a relay connection. In case that the outbound connection // is not successful, this attribute is not set. Channel outboundChannel = inboundChannel.attr(RELAY_CHANNEL_KEY).get(); @@ -177,13 +179,15 @@ public class ProxyServer implements Runnable { Object[] messages = relayBuffer.toArray(); relayBuffer.clear(); for (Object msg : messages) { - writeToRelayChannel(inboundChannel, outboundChannel, msg); + // TODO (jianglai): do not log the message once retry behavior is confirmed. logger.atInfo().log( - "Relay retried: %s <-> %s\nFRONTEND: %s\nBACKEND: %s", + "Relay retried: %s <-> %s\nFRONTEND: %s\nBACKEND: %s\nMESSAGE: %s", inboundProtocol.name(), outboundProtocol.name(), inboundChannel, - outboundChannel); + outboundChannel, + msg); + writeToRelayChannel(inboundChannel, outboundChannel, msg, true); } // When this outbound connection is closed, try reconnecting if the inbound connection // is still active. diff --git a/java/google/registry/proxy/handler/HttpsRelayServiceHandler.java b/java/google/registry/proxy/handler/HttpsRelayServiceHandler.java index 917934294..92f848f42 100644 --- a/java/google/registry/proxy/handler/HttpsRelayServiceHandler.java +++ b/java/google/registry/proxy/handler/HttpsRelayServiceHandler.java @@ -14,13 +14,14 @@ package google.registry.proxy.handler; -import static com.google.common.base.Preconditions.checkArgument; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableSet; import com.google.common.flogger.FluentLogger; import google.registry.proxy.metric.FrontendMetrics; import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; @@ -58,10 +59,16 @@ import javax.net.ssl.SSLHandshakeException; *

This handler is session aware and will store all the session cookies that the are contained in * the HTTP response headers, which are added back to headers of subsequent HTTP requests. */ -abstract class HttpsRelayServiceHandler extends ByteToMessageCodec { +public abstract class HttpsRelayServiceHandler extends ByteToMessageCodec { private static final FluentLogger logger = FluentLogger.forEnclosingClass(); + protected static final ImmutableSet> NON_FATAL_INBOUND_EXCEPTIONS = + ImmutableSet.of(ReadTimeoutException.class, SSLHandshakeException.class); + + protected static final ImmutableSet> NON_FATAL_OUTBOUND_EXCEPTIONS = + ImmutableSet.of(NonOkHttpResponseException.class); + private final Map cookieStore = new LinkedHashMap<>(); private final String relayHost; private final String relayPath; @@ -153,12 +160,9 @@ abstract class HttpsRelayServiceHandler extends ByteToMessageCodec { if (!channelFuture.isSuccess()) { Throwable cause = channelFuture.cause(); - // If the failure is caused by IllegalArgumentException, we know that it is because we - // got a non 200 response. This is an expected error from the backend and should not be - // logged at severe. - if (Throwables.getRootCause(cause) instanceof IllegalArgumentException) { + if (NON_FATAL_OUTBOUND_EXCEPTIONS.contains(Throwables.getRootCause(cause).getClass())) { logger.atWarning().withCause(channelFuture.cause()).log( "Outbound exception caught for channel %s", channelFuture.channel()); } else { @@ -202,4 +200,14 @@ abstract class HttpsRelayServiceHandler extends ByteToMessageCodec extends SimpleChannelInboundHandler { logger.atSevere().log("Relay channel not specified for channel: %s", channel); ChannelFuture unusedFuture = channel.close(); } else { - writeToRelayChannel(channel, relayChannel, msg); + writeToRelayChannel(channel, relayChannel, msg, false); } } - public static void writeToRelayChannel(Channel channel, Channel relayChannel, Object msg) { + public static void writeToRelayChannel( + Channel channel, Channel relayChannel, Object msg, boolean retry) { ChannelFuture unusedFuture = relayChannel .writeAndFlush(msg) .addListener( future -> { if (!future.isSuccess()) { - logger.atWarning().log( - "Relay failed: %s --> %s\nINBOUND: %s\nOUTBOUND: %s", + // TODO (jianglai): do not log the message once retry behavior is confirmed. + logger.atWarning().withCause(future.cause()).log( + "Relay failed: %s --> %s\nINBOUND: %s\nOUTBOUND: %s\nMESSAGE: %s", channel.attr(PROTOCOL_KEY).get().name(), relayChannel.attr(PROTOCOL_KEY).get().name(), channel, - relayChannel); + relayChannel, + msg); // If we cannot write to the relay channel and the originating channel has // a relay buffer (i. e. we tried to relay the frontend to the backend), store - // the message in the buffer for retry later. Otherwise, we are relaying from - // the backend to the frontend, and this relay failure cannot be recovered - // from, we should just kill the relay (frontend) channel, which in turn will - // kill the backend channel. We should not kill any backend channel while the - // the frontend channel is open, because that will just trigger a reconnect. - // It is fine to just save the message object itself, not a clone of it, - // because if the relay is not successful, its content is not read, therefore - // its buffer is not cleared. + // the message in the buffer for retry later. The relay channel (backend) should + // be killed (if it is not already dead, usually the relay is unsuccessful + // because the connection is closed), and a new backend channel will re-connect + // as long as the frontend channel is open. Otherwise, we are relaying from the + // backend to the frontend, and this relay failure cannot be recovered from: we + // should just kill the relay (frontend) channel, which in turn will kill the + // backend channel. It is fine to just save the message object itself, not a + // clone of it, because if the relay is not successful, its content is not read, + // therefore its buffer is not cleared. Queue relayBuffer = channel.attr(RELAY_BUFFER_KEY).get(); if (relayBuffer != null) { channel.attr(RELAY_BUFFER_KEY).get().add(msg); - } else { - ChannelFuture unusedFuture2 = relayChannel.close(); } + ChannelFuture unusedFuture2 = relayChannel.close(); + } else if (retry) { + // TODO (jianglai): do not log the message once retry behavior is confirmed. + logger.atInfo().log( + "Relay retry succeeded: %s --> %s\nINBOUND: %s\nOUTBOUND: %s\nsMESSAGE: %s", + channel.attr(PROTOCOL_KEY).get().name(), + relayChannel.attr(PROTOCOL_KEY).get().name(), + channel, + relayChannel, + msg); } }); } diff --git a/javatests/google/registry/proxy/CertificateModuleTest.java b/javatests/google/registry/proxy/CertificateModuleTest.java index 841c0c265..98616f57c 100644 --- a/javatests/google/registry/proxy/CertificateModuleTest.java +++ b/javatests/google/registry/proxy/CertificateModuleTest.java @@ -17,8 +17,8 @@ package google.registry.proxy; import static com.google.common.truth.Truth.assertThat; import static google.registry.proxy.handler.SslInitializerTestUtils.getKeyPair; import static google.registry.proxy.handler.SslInitializerTestUtils.signKeyPair; +import static google.registry.testing.JUnitBackports.assertThrows; import static java.nio.charset.StandardCharsets.UTF_8; -import static org.junit.Assert.fail; import dagger.Component; import dagger.Module; @@ -94,48 +94,36 @@ public class CertificateModuleTest { public void testFailure_noPrivateKey() throws Exception { byte[] pemBytes = getPemBytes(cert, ssc.cert()); component = createComponent(pemBytes); - try { - component.privateKey(); - fail("Expect IllegalStateException."); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("0 keys are found"); - } + IllegalStateException thrown = + assertThrows(IllegalStateException.class, () -> component.privateKey()); + assertThat(thrown).hasMessageThat().contains("0 keys are found"); } @Test public void testFailure_twoPrivateKeys() throws Exception { byte[] pemBytes = getPemBytes(cert, ssc.cert(), key, ssc.key()); component = createComponent(pemBytes); - try { - component.privateKey(); - fail("Expect IllegalStateException."); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("2 keys are found"); - } + IllegalStateException thrown = + assertThrows(IllegalStateException.class, () -> component.privateKey()); + assertThat(thrown).hasMessageThat().contains("2 keys are found"); } @Test public void testFailure_certificatesOutOfOrder() throws Exception { byte[] pemBytes = getPemBytes(ssc.cert(), cert, key); component = createComponent(pemBytes); - try { - component.certificates(); - fail("Expect IllegalStateException."); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("is not signed by"); - } + IllegalStateException thrown = + assertThrows(IllegalStateException.class, () -> component.certificates()); + assertThat(thrown).hasMessageThat().contains("is not signed by"); } @Test public void testFailure_noCertificates() throws Exception { byte[] pemBytes = getPemBytes(key); component = createComponent(pemBytes); - try { - component.certificates(); - fail("Expect IllegalStateException."); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("No certificates"); - } + IllegalStateException thrown = + assertThrows(IllegalStateException.class, () -> component.certificates()); + assertThat(thrown).hasMessageThat().contains("No certificates"); } @Module diff --git a/javatests/google/registry/proxy/EppProtocolModuleTest.java b/javatests/google/registry/proxy/EppProtocolModuleTest.java index 0a53cf0ed..87fec54dc 100644 --- a/javatests/google/registry/proxy/EppProtocolModuleTest.java +++ b/javatests/google/registry/proxy/EppProtocolModuleTest.java @@ -17,14 +17,18 @@ package google.registry.proxy; import static com.google.common.truth.Truth.assertThat; import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY; import static google.registry.proxy.handler.SslServerInitializer.CLIENT_CERTIFICATE_PROMISE_KEY; +import static google.registry.testing.JUnitBackports.assertThrows; import static google.registry.util.ResourceUtils.readResourceBytes; import static google.registry.util.X509Utils.getCertificateHash; import static java.nio.charset.StandardCharsets.UTF_8; +import com.google.common.base.Throwables; +import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException; import google.registry.testing.FakeClock; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; @@ -108,8 +112,12 @@ public class EppProtocolModuleTest extends ProtocolModuleTest { } private FullHttpResponse makeEppHttpResponse(byte[] content, Cookie... cookies) { - return TestUtils.makeEppHttpResponse( - new String(content, UTF_8), HttpResponseStatus.OK, cookies); + return makeEppHttpResponse(content, HttpResponseStatus.OK, cookies); + } + + private FullHttpResponse makeEppHttpResponse( + byte[] content, HttpResponseStatus status, Cookie... cookies) { + return TestUtils.makeEppHttpResponse(new String(content, UTF_8), status, cookies); } @Override @@ -209,6 +217,28 @@ public class EppProtocolModuleTest extends ProtocolModuleTest { assertThat(channel.isActive()).isTrue(); } + @Test + public void testFailure_nonOkOutboundMessage() throws Exception { + // First inbound message is hello. + channel.readInbound(); + + byte[] outputBytes = readResourceBytes(getClass(), "testdata/login_response.xml").read(); + + // Verify outbound message is not written to the peer as the response is not OK. + EncoderException thrown = + assertThrows( + EncoderException.class, + () -> + channel.writeOutbound( + makeEppHttpResponse(outputBytes, HttpResponseStatus.UNAUTHORIZED))); + assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class); + assertThat(thrown).hasMessageThat().contains("401 Unauthorized"); + assertThat((Object) channel.readOutbound()).isNull(); + + // Channel is closed. + assertThat(channel.isActive()).isFalse(); + } + @Test public void testSuccess_setAndReadCookies() throws Exception { // First inbound message is hello. diff --git a/javatests/google/registry/proxy/ProxyModuleTest.java b/javatests/google/registry/proxy/ProxyModuleTest.java index 40916f455..6f88fcb15 100644 --- a/javatests/google/registry/proxy/ProxyModuleTest.java +++ b/javatests/google/registry/proxy/ProxyModuleTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static google.registry.proxy.ProxyConfig.Environment.LOCAL; import static google.registry.proxy.ProxyConfig.getProxyConfig; import static google.registry.testing.JUnitBackports.assertThrows; -import static org.junit.Assert.fail; import com.beust.jcommander.ParameterException; import google.registry.proxy.ProxyConfig.Environment; @@ -66,12 +65,9 @@ public class ProxyModuleTest { @Test public void testFailure_parseArgs_wrongArguments() { String[] args = {"--wrong_flag", "some_value"}; - try { - proxyModule.parse(args); - fail("Expected ParameterException."); - } catch (ParameterException e) { - assertThat(e).hasMessageThat().contains("--wrong_flag"); - } + ParameterException thrown = + assertThrows(ParameterException.class, () -> proxyModule.parse(args)); + assertThat(thrown).hasMessageThat().contains("--wrong_flag"); } @Test diff --git a/javatests/google/registry/proxy/WhoisProtocolModuleTest.java b/javatests/google/registry/proxy/WhoisProtocolModuleTest.java index b90bdbb07..999ad1dfe 100644 --- a/javatests/google/registry/proxy/WhoisProtocolModuleTest.java +++ b/javatests/google/registry/proxy/WhoisProtocolModuleTest.java @@ -20,10 +20,12 @@ import static google.registry.proxy.TestUtils.makeWhoisHttpResponse; import static google.registry.testing.JUnitBackports.assertThrows; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.util.stream.Collectors.joining; -import static org.junit.Assert.fail; +import com.google.common.base.Throwables; +import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; @@ -152,13 +154,11 @@ public class WhoisProtocolModuleTest extends ProtocolModuleTest { public void testFailure_outboundResponseStatusNotOK() { String outputString = "line1\r\nline2\r\n"; FullHttpResponse response = makeWhoisHttpResponse(outputString, HttpResponseStatus.BAD_REQUEST); - try { - channel.writeOutbound(response); - fail("Expected failure due to non-OK HTTP response status"); - } catch (Exception e) { - assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); - assertThat(e).hasMessageThat().contains("400 Bad Request"); - } + EncoderException thrown = + assertThrows(EncoderException.class, () -> channel.writeOutbound(response)); + assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class); + assertThat(thrown).hasMessageThat().contains("400 Bad Request"); + assertThat((Object) channel.readOutbound()).isNull(); assertThat(channel.isActive()).isFalse(); } } diff --git a/javatests/google/registry/proxy/handler/EppServiceHandlerTest.java b/javatests/google/registry/proxy/handler/EppServiceHandlerTest.java index f7ae274b5..39a75edc0 100644 --- a/javatests/google/registry/proxy/handler/EppServiceHandlerTest.java +++ b/javatests/google/registry/proxy/handler/EppServiceHandlerTest.java @@ -19,14 +19,16 @@ import static google.registry.proxy.TestUtils.assertHttpRequestEquivalent; import static google.registry.proxy.TestUtils.makeEppHttpResponse; import static google.registry.proxy.handler.ProxyProtocolHandler.REMOTE_ADDRESS_KEY; import static google.registry.proxy.handler.SslServerInitializer.CLIENT_CERTIFICATE_PROMISE_KEY; +import static google.registry.testing.JUnitBackports.assertThrows; import static google.registry.util.X509Utils.getCertificateHash; import static java.nio.charset.StandardCharsets.UTF_8; -import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.common.base.Throwables; import google.registry.proxy.TestUtils; +import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException; import google.registry.proxy.metric.FrontendMetrics; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -248,7 +250,7 @@ public class EppServiceHandlerTest { response.headers().set("Epp-Session", "close"); channel.writeOutbound(response); ByteBuf expectedResponse = channel.readOutbound(); - assertThat(expectedResponse).isEqualTo(Unpooled.wrappedBuffer(content.getBytes(UTF_8))); + assertThat(Unpooled.wrappedBuffer(content.getBytes(UTF_8))).isEqualTo(expectedResponse); // Nothing further to pass to the next handler. assertThat((Object) channel.readOutbound()).isNull(); // Channel is disconnected. @@ -259,14 +261,16 @@ public class EppServiceHandlerTest { public void testFailure_disconnectOnNonOKResponseStatus() throws Exception { setHandshakeSuccess(); String content = "stuff"; - try { - channel.writeOutbound(makeEppHttpResponse(content, HttpResponseStatus.BAD_REQUEST)); - fail("Expected EncoderException"); - } catch (EncoderException e) { - assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); - assertThat(e).hasMessageThat().contains(HttpResponseStatus.BAD_REQUEST.toString()); - assertThat(channel.isActive()).isFalse(); - } + EncoderException thrown = + assertThrows( + EncoderException.class, + () -> + channel.writeOutbound( + makeEppHttpResponse(content, HttpResponseStatus.BAD_REQUEST))); + assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class); + assertThat(thrown).hasMessageThat().contains(HttpResponseStatus.BAD_REQUEST.toString()); + assertThat((Object) channel.readOutbound()).isNull(); + assertThat(channel.isActive()).isFalse(); } @Test diff --git a/javatests/google/registry/proxy/handler/RelayHandlerTest.java b/javatests/google/registry/proxy/handler/RelayHandlerTest.java index a59420bcd..e1735c2c3 100644 --- a/javatests/google/registry/proxy/handler/RelayHandlerTest.java +++ b/javatests/google/registry/proxy/handler/RelayHandlerTest.java @@ -85,7 +85,23 @@ public class RelayHandlerTest { } @Test - public void testSuccess_outboundClosed_enqueueBuffer() { + public void testSuccess_frontClosed() { + inboundChannel.attr(RELAY_BUFFER_KEY).set(null); + inboundChannel.attr(PROTOCOL_KEY).set(backendProtocol); + outboundChannel.attr(PROTOCOL_KEY).set(frontendProtocol); + ExpectedType inboundMessage = new ExpectedType(); + // Outbound channel (frontend) is closed. + outboundChannel.finish(); + assertThat(inboundChannel.writeInbound(inboundMessage)).isFalse(); + ExpectedType relayedMessage = outboundChannel.readOutbound(); + assertThat(relayedMessage).isNull(); + // Inbound channel (backend) should stay open. + assertThat(inboundChannel.isActive()).isTrue(); + assertThat(inboundChannel.attr(RELAY_BUFFER_KEY).get()).isNull(); + } + + @Test + public void testSuccess_backendClosed_enqueueBuffer() { ExpectedType inboundMessage = new ExpectedType(); // Outbound channel (backend) is closed. outboundChannel.finish(); diff --git a/javatests/google/registry/proxy/handler/WhoisServiceHandlerTest.java b/javatests/google/registry/proxy/handler/WhoisServiceHandlerTest.java index 191dbbcdf..cbfcf8386 100644 --- a/javatests/google/registry/proxy/handler/WhoisServiceHandlerTest.java +++ b/javatests/google/registry/proxy/handler/WhoisServiceHandlerTest.java @@ -17,17 +17,20 @@ package google.registry.proxy.handler; import static com.google.common.truth.Truth.assertThat; import static google.registry.proxy.TestUtils.makeWhoisHttpRequest; import static google.registry.proxy.TestUtils.makeWhoisHttpResponse; +import static google.registry.testing.JUnitBackports.assertThrows; import static java.nio.charset.StandardCharsets.US_ASCII; -import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import com.google.common.base.Throwables; +import google.registry.proxy.handler.HttpsRelayServiceHandler.NonOkHttpResponseException; import google.registry.proxy.metric.FrontendMetrics; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.DefaultChannelId; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; @@ -116,13 +119,11 @@ public class WhoisServiceHandlerTest { String outputString = "line1\r\nline2\r\n"; FullHttpResponse outputResponse = makeWhoisHttpResponse(outputString, HttpResponseStatus.BAD_REQUEST); - try { - channel.writeOutbound(outputResponse); - fail("Expected failure due to non-OK HTTP response status."); - } catch (Exception e) { - assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); - assertThat(e).hasMessageThat().contains("400 Bad Request"); - } + EncoderException thrown = + assertThrows(EncoderException.class, () -> channel.writeOutbound(outputResponse)); + assertThat(Throwables.getRootCause(thrown)).isInstanceOf(NonOkHttpResponseException.class); + assertThat(thrown).hasMessageThat().contains("400 Bad Request"); + assertThat((Object) channel.readOutbound()).isNull(); assertThat(channel.isActive()).isFalse(); } }