diff --git a/javatests/google/registry/server/ServletWrapperDelegatorServlet.java b/javatests/google/registry/server/ServletWrapperDelegatorServlet.java index 77c5c9aef..77ca0bad6 100644 --- a/javatests/google/registry/server/ServletWrapperDelegatorServlet.java +++ b/javatests/google/registry/server/ServletWrapperDelegatorServlet.java @@ -15,18 +15,23 @@ package google.registry.server; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Suppliers.memoize; +import static google.registry.util.TypeUtils.instantiate; -import com.google.common.base.Supplier; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Uninterruptibles; import java.io.IOException; +import java.util.Iterator; import java.util.Queue; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.FutureTask; import javax.annotation.Nullable; +import javax.servlet.Filter; +import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -34,17 +39,23 @@ import javax.servlet.http.HttpServletResponse; /** * Servlet that wraps a servlet and delegates request execution to a queue. * + *

The actual invocation of the delegate does not happen within this servlet's lifecycle. + * Therefore, the task on the queue must manually invoke filters within the queue task. + * * @see TestServer */ public final class ServletWrapperDelegatorServlet extends HttpServlet { private final Queue> requestQueue; - private final Supplier servlet; + private final Class servletClass; + private final ImmutableList> filterClasses; ServletWrapperDelegatorServlet( Class servletClass, + ImmutableList> filterClasses, Queue> requestQueue) { - this.servlet = lazilyInstantiate(checkNotNull(servletClass, "servletClass")); + this.servletClass = servletClass; + this.filterClasses = filterClasses; this.requestQueue = checkNotNull(requestQueue, "requestQueue"); } @@ -55,7 +66,20 @@ public final class ServletWrapperDelegatorServlet extends HttpServlet { @Nullable @Override public Void call() throws ServletException, IOException { - servlet.get().service(req, rsp); + // Simulate the full filter chain with the servlet at the end. + final Iterator> filtersIter = filterClasses.iterator(); + FilterChain filterChain = + new FilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + if (filtersIter.hasNext()) { + instantiate(filtersIter.next()).doFilter(request, response, this); + } else { + instantiate(servletClass).service(request, response); + } + }}; + filterChain.doFilter(req, rsp); return null; }}); requestQueue.add(task); @@ -67,16 +91,4 @@ public final class ServletWrapperDelegatorServlet extends HttpServlet { throw Throwables.propagate(e.getCause()); } } - - private static Supplier lazilyInstantiate(final Class clazz) { - return memoize(new Supplier() { - @Override - public T get() { - try { - return clazz.newInstance(); - } catch (InstantiationException | IllegalAccessException e) { - throw new RuntimeException(e); - } - }}); - } } diff --git a/javatests/google/registry/server/TestServer.java b/javatests/google/registry/server/TestServer.java index f5f8c6784..be62eaa58 100644 --- a/javatests/google/registry/server/TestServer.java +++ b/javatests/google/registry/server/TestServer.java @@ -15,11 +15,12 @@ package google.registry.server; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.util.concurrent.Runnables.doNothing; import static google.registry.util.NetworkUtils.getCanonicalHostName; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; import com.google.common.net.HostAndPort; -import com.google.common.util.concurrent.Callables; import com.google.common.util.concurrent.SimpleTimeLimiter; import java.net.MalformedURLException; import java.net.URL; @@ -34,7 +35,6 @@ import javax.annotation.Nullable; import javax.servlet.Filter; import javax.servlet.http.HttpServlet; import org.mortbay.jetty.Connector; -import org.mortbay.jetty.Handler; import org.mortbay.jetty.Server; import org.mortbay.jetty.bio.SocketConnector; import org.mortbay.jetty.servlet.Context; @@ -85,8 +85,8 @@ public final class TestServer { public TestServer( HostAndPort address, Map runfiles, - Iterable routes, - Iterable> filters) { + ImmutableList routes, + ImmutableList> filters) { urlAddress = createUrlAddress(address); server.addConnector(createConnector(address)); server.addHandler(createHandler(runfiles, routes, filters)); @@ -120,7 +120,7 @@ public final class TestServer { * main event loop, for post-request processing. */ public void ping() { - requestQueue.add(new FutureTask<>(Callables.returning(null))); + requestQueue.add(new FutureTask(doNothing(), null)); } /** Stops the HTTP server. */ @@ -151,8 +151,8 @@ public final class TestServer { private Context createHandler( Map runfiles, - Iterable routes, - Iterable> filters) { + ImmutableList routes, + ImmutableList> filters) { Context context = new Context(server, CONTEXT_PATH, Context.SESSIONS); context.addServlet(new ServletHolder(HealthzServlet.class), "/healthz"); for (Map.Entry runfile : runfiles.entrySet()) { @@ -161,10 +161,8 @@ public final class TestServer { runfile.getKey()); } for (Route route : routes) { - context.addServlet(new ServletHolder(wrapServlet(route.servletClass())), route.path()); - } - for (Class filter : filters) { - context.addFilter(filter, "/*", Handler.REQUEST); + context.addServlet( + new ServletHolder(wrapServlet(route.servletClass(), filters)), route.path()); } ServletHolder holder = new ServletHolder(DefaultServlet.class); holder.setInitParameter("aliases", "1"); @@ -172,8 +170,9 @@ public final class TestServer { return context; } - private HttpServlet wrapServlet(Class servletClass) { - return new ServletWrapperDelegatorServlet(servletClass, requestQueue); + private HttpServlet wrapServlet( + Class servletClass, ImmutableList> filters) { + return new ServletWrapperDelegatorServlet(servletClass, filters, requestQueue); } private static Connector createConnector(HostAndPort address) {