From 0859cde79040e2a5840794cc700f26aa68b4d695 Mon Sep 17 00:00:00 2001 From: cgoldfeder Date: Tue, 10 Jan 2017 13:49:35 -0800 Subject: [PATCH] Fix the TestServer filter support added in [] In the previous CL I added filter support for the test server, but even though I could verify that filters were being run when debugging, in practice the side effects of the filters (notably, ObjectifyFilter clearing the session cache) were somehow not present in tests (and therefore causing new as-yet unsubmitted tests that rely on proper session caching to break). Investigating further, the way TestServer works is that it creates a wrapper Servlet for each route, and in that wrapper just pushes a future onto a queue and waits on it. The actual target servlet is run within the queue, not within the wrapper servlet's context. I had added filters to the *wrapper* servlet, which meant that even though they were invoked before adding the task to the queue, they were not invoked in the process of actually running the task. In this CL I pushed the filters into the task itself, just like the target servlet. ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=144123637 --- .../ServletWrapperDelegatorServlet.java | 46 ++++++++++++------- .../google/registry/server/TestServer.java | 25 +++++----- 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/javatests/google/registry/server/ServletWrapperDelegatorServlet.java b/javatests/google/registry/server/ServletWrapperDelegatorServlet.java index 77c5c9aef..77ca0bad6 100644 --- a/javatests/google/registry/server/ServletWrapperDelegatorServlet.java +++ b/javatests/google/registry/server/ServletWrapperDelegatorServlet.java @@ -15,18 +15,23 @@ package google.registry.server; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Suppliers.memoize; +import static google.registry.util.TypeUtils.instantiate; -import com.google.common.base.Supplier; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Uninterruptibles; import java.io.IOException; +import java.util.Iterator; import java.util.Queue; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.FutureTask; import javax.annotation.Nullable; +import javax.servlet.Filter; +import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -34,17 +39,23 @@ import javax.servlet.http.HttpServletResponse; /** * Servlet that wraps a servlet and delegates request execution to a queue. * + *

The actual invocation of the delegate does not happen within this servlet's lifecycle. + * Therefore, the task on the queue must manually invoke filters within the queue task. + * * @see TestServer */ public final class ServletWrapperDelegatorServlet extends HttpServlet { private final Queue> requestQueue; - private final Supplier servlet; + private final Class servletClass; + private final ImmutableList> filterClasses; ServletWrapperDelegatorServlet( Class servletClass, + ImmutableList> filterClasses, Queue> requestQueue) { - this.servlet = lazilyInstantiate(checkNotNull(servletClass, "servletClass")); + this.servletClass = servletClass; + this.filterClasses = filterClasses; this.requestQueue = checkNotNull(requestQueue, "requestQueue"); } @@ -55,7 +66,20 @@ public final class ServletWrapperDelegatorServlet extends HttpServlet { @Nullable @Override public Void call() throws ServletException, IOException { - servlet.get().service(req, rsp); + // Simulate the full filter chain with the servlet at the end. + final Iterator> filtersIter = filterClasses.iterator(); + FilterChain filterChain = + new FilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + if (filtersIter.hasNext()) { + instantiate(filtersIter.next()).doFilter(request, response, this); + } else { + instantiate(servletClass).service(request, response); + } + }}; + filterChain.doFilter(req, rsp); return null; }}); requestQueue.add(task); @@ -67,16 +91,4 @@ public final class ServletWrapperDelegatorServlet extends HttpServlet { throw Throwables.propagate(e.getCause()); } } - - private static Supplier lazilyInstantiate(final Class clazz) { - return memoize(new Supplier() { - @Override - public T get() { - try { - return clazz.newInstance(); - } catch (InstantiationException | IllegalAccessException e) { - throw new RuntimeException(e); - } - }}); - } } diff --git a/javatests/google/registry/server/TestServer.java b/javatests/google/registry/server/TestServer.java index f5f8c6784..be62eaa58 100644 --- a/javatests/google/registry/server/TestServer.java +++ b/javatests/google/registry/server/TestServer.java @@ -15,11 +15,12 @@ package google.registry.server; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.util.concurrent.Runnables.doNothing; import static google.registry.util.NetworkUtils.getCanonicalHostName; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; import com.google.common.net.HostAndPort; -import com.google.common.util.concurrent.Callables; import com.google.common.util.concurrent.SimpleTimeLimiter; import java.net.MalformedURLException; import java.net.URL; @@ -34,7 +35,6 @@ import javax.annotation.Nullable; import javax.servlet.Filter; import javax.servlet.http.HttpServlet; import org.mortbay.jetty.Connector; -import org.mortbay.jetty.Handler; import org.mortbay.jetty.Server; import org.mortbay.jetty.bio.SocketConnector; import org.mortbay.jetty.servlet.Context; @@ -85,8 +85,8 @@ public final class TestServer { public TestServer( HostAndPort address, Map runfiles, - Iterable routes, - Iterable> filters) { + ImmutableList routes, + ImmutableList> filters) { urlAddress = createUrlAddress(address); server.addConnector(createConnector(address)); server.addHandler(createHandler(runfiles, routes, filters)); @@ -120,7 +120,7 @@ public final class TestServer { * main event loop, for post-request processing. */ public void ping() { - requestQueue.add(new FutureTask<>(Callables.returning(null))); + requestQueue.add(new FutureTask(doNothing(), null)); } /** Stops the HTTP server. */ @@ -151,8 +151,8 @@ public final class TestServer { private Context createHandler( Map runfiles, - Iterable routes, - Iterable> filters) { + ImmutableList routes, + ImmutableList> filters) { Context context = new Context(server, CONTEXT_PATH, Context.SESSIONS); context.addServlet(new ServletHolder(HealthzServlet.class), "/healthz"); for (Map.Entry runfile : runfiles.entrySet()) { @@ -161,10 +161,8 @@ public final class TestServer { runfile.getKey()); } for (Route route : routes) { - context.addServlet(new ServletHolder(wrapServlet(route.servletClass())), route.path()); - } - for (Class filter : filters) { - context.addFilter(filter, "/*", Handler.REQUEST); + context.addServlet( + new ServletHolder(wrapServlet(route.servletClass(), filters)), route.path()); } ServletHolder holder = new ServletHolder(DefaultServlet.class); holder.setInitParameter("aliases", "1"); @@ -172,8 +170,9 @@ public final class TestServer { return context; } - private HttpServlet wrapServlet(Class servletClass) { - return new ServletWrapperDelegatorServlet(servletClass, requestQueue); + private HttpServlet wrapServlet( + Class servletClass, ImmutableList> filters) { + return new ServletWrapperDelegatorServlet(servletClass, filters, requestQueue); } private static Connector createConnector(HostAndPort address) {