diff --git a/javatests/google/registry/server/ServletWrapperDelegatorServlet.java b/javatests/google/registry/server/ServletWrapperDelegatorServlet.java
index 77c5c9aef..77ca0bad6 100644
--- a/javatests/google/registry/server/ServletWrapperDelegatorServlet.java
+++ b/javatests/google/registry/server/ServletWrapperDelegatorServlet.java
@@ -15,18 +15,23 @@
package google.registry.server;
import static com.google.common.base.Preconditions.checkNotNull;
-import static com.google.common.base.Suppliers.memoize;
+import static google.registry.util.TypeUtils.instantiate;
-import com.google.common.base.Supplier;
import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.Uninterruptibles;
import java.io.IOException;
+import java.util.Iterator;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import javax.annotation.Nullable;
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@@ -34,17 +39,23 @@ import javax.servlet.http.HttpServletResponse;
/**
* Servlet that wraps a servlet and delegates request execution to a queue.
*
+ *
The actual invocation of the delegate does not happen within this servlet's lifecycle.
+ * Therefore, the task on the queue must manually invoke filters within the queue task.
+ *
* @see TestServer
*/
public final class ServletWrapperDelegatorServlet extends HttpServlet {
private final Queue> requestQueue;
- private final Supplier servlet;
+ private final Class extends HttpServlet> servletClass;
+ private final ImmutableList> filterClasses;
ServletWrapperDelegatorServlet(
Class extends HttpServlet> servletClass,
+ ImmutableList> filterClasses,
Queue> requestQueue) {
- this.servlet = lazilyInstantiate(checkNotNull(servletClass, "servletClass"));
+ this.servletClass = servletClass;
+ this.filterClasses = filterClasses;
this.requestQueue = checkNotNull(requestQueue, "requestQueue");
}
@@ -55,7 +66,20 @@ public final class ServletWrapperDelegatorServlet extends HttpServlet {
@Nullable
@Override
public Void call() throws ServletException, IOException {
- servlet.get().service(req, rsp);
+ // Simulate the full filter chain with the servlet at the end.
+ final Iterator> filtersIter = filterClasses.iterator();
+ FilterChain filterChain =
+ new FilterChain() {
+ @Override
+ public void doFilter(ServletRequest request, ServletResponse response)
+ throws IOException, ServletException {
+ if (filtersIter.hasNext()) {
+ instantiate(filtersIter.next()).doFilter(request, response, this);
+ } else {
+ instantiate(servletClass).service(request, response);
+ }
+ }};
+ filterChain.doFilter(req, rsp);
return null;
}});
requestQueue.add(task);
@@ -67,16 +91,4 @@ public final class ServletWrapperDelegatorServlet extends HttpServlet {
throw Throwables.propagate(e.getCause());
}
}
-
- private static Supplier lazilyInstantiate(final Class extends T> clazz) {
- return memoize(new Supplier() {
- @Override
- public T get() {
- try {
- return clazz.newInstance();
- } catch (InstantiationException | IllegalAccessException e) {
- throw new RuntimeException(e);
- }
- }});
- }
}
diff --git a/javatests/google/registry/server/TestServer.java b/javatests/google/registry/server/TestServer.java
index f5f8c6784..be62eaa58 100644
--- a/javatests/google/registry/server/TestServer.java
+++ b/javatests/google/registry/server/TestServer.java
@@ -15,11 +15,12 @@
package google.registry.server;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.util.concurrent.Runnables.doNothing;
import static google.registry.util.NetworkUtils.getCanonicalHostName;
import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableList;
import com.google.common.net.HostAndPort;
-import com.google.common.util.concurrent.Callables;
import com.google.common.util.concurrent.SimpleTimeLimiter;
import java.net.MalformedURLException;
import java.net.URL;
@@ -34,7 +35,6 @@ import javax.annotation.Nullable;
import javax.servlet.Filter;
import javax.servlet.http.HttpServlet;
import org.mortbay.jetty.Connector;
-import org.mortbay.jetty.Handler;
import org.mortbay.jetty.Server;
import org.mortbay.jetty.bio.SocketConnector;
import org.mortbay.jetty.servlet.Context;
@@ -85,8 +85,8 @@ public final class TestServer {
public TestServer(
HostAndPort address,
Map runfiles,
- Iterable routes,
- Iterable> filters) {
+ ImmutableList routes,
+ ImmutableList> filters) {
urlAddress = createUrlAddress(address);
server.addConnector(createConnector(address));
server.addHandler(createHandler(runfiles, routes, filters));
@@ -120,7 +120,7 @@ public final class TestServer {
* main event loop, for post-request processing.
*/
public void ping() {
- requestQueue.add(new FutureTask<>(Callables.returning(null)));
+ requestQueue.add(new FutureTask(doNothing(), null));
}
/** Stops the HTTP server. */
@@ -151,8 +151,8 @@ public final class TestServer {
private Context createHandler(
Map runfiles,
- Iterable routes,
- Iterable> filters) {
+ ImmutableList routes,
+ ImmutableList> filters) {
Context context = new Context(server, CONTEXT_PATH, Context.SESSIONS);
context.addServlet(new ServletHolder(HealthzServlet.class), "/healthz");
for (Map.Entry runfile : runfiles.entrySet()) {
@@ -161,10 +161,8 @@ public final class TestServer {
runfile.getKey());
}
for (Route route : routes) {
- context.addServlet(new ServletHolder(wrapServlet(route.servletClass())), route.path());
- }
- for (Class extends Filter> filter : filters) {
- context.addFilter(filter, "/*", Handler.REQUEST);
+ context.addServlet(
+ new ServletHolder(wrapServlet(route.servletClass(), filters)), route.path());
}
ServletHolder holder = new ServletHolder(DefaultServlet.class);
holder.setInitParameter("aliases", "1");
@@ -172,8 +170,9 @@ public final class TestServer {
return context;
}
- private HttpServlet wrapServlet(Class extends HttpServlet> servletClass) {
- return new ServletWrapperDelegatorServlet(servletClass, requestQueue);
+ private HttpServlet wrapServlet(
+ Class extends HttpServlet> servletClass, ImmutableList> filters) {
+ return new ServletWrapperDelegatorServlet(servletClass, filters, requestQueue);
}
private static Connector createConnector(HostAndPort address) {