Revision: 1585
Author:   guice.mirror...@gmail.com
Date:     Tue Sep 27 08:36:19 2011
Log:
Replace the request/response Context after each ServletModule-registered
Filter. This fixes problems where request/response wrappers installed by a filter were not seen (e.g. via injected HttpServletRequest/HttpServletResponse) by subsequent filters or servlets in the chain.

Revision created by MOE tool push_codebase.
MOE_MIGRATION=3340

http://code.google.com/p/google-guice/source/detail?r=1585

Modified:
 /trunk/extensions/servlet/src/com/google/inject/servlet/FilterChainInvocation.java
 /trunk/extensions/servlet/src/com/google/inject/servlet/GuiceFilter.java
 /trunk/extensions/servlet/src/com/google/inject/servlet/ManagedFilterPipeline.java
 /trunk/extensions/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java
 /trunk/extensions/servlet/src/com/google/inject/servlet/ServletDefinition.java
 /trunk/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java
 /trunk/extensions/servlet/test/com/google/inject/servlet/ServletTest.java

=======================================
--- /trunk/extensions/servlet/src/com/google/inject/servlet/FilterChainInvocation.java Thu Jul 7 17:34:16 2011 +++ /trunk/extensions/servlet/src/com/google/inject/servlet/FilterChainInvocation.java Tue Sep 27 08:36:19 2011
@@ -21,6 +21,8 @@
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;

 /**
* A Filter chain impl which basically passes itself to the "current" filter and iterates the chain
@@ -54,18 +56,27 @@
       throws IOException, ServletException {
     index++;

-    //dispatch down the chain while there are more filters
-    if (index < filterDefinitions.length) {
- filterDefinitions[index].doFilter(servletRequest, servletResponse, this);
-    } else {
-
- //we've reached the end of the filterchain, let's try to dispatch to a servlet - final boolean serviced = servletPipeline.service(servletRequest, servletResponse);
-
- //dispatch to the normal filter chain only if one of our servlets did not match
-      if (!serviced) {
-        proceedingChain.doFilter(servletRequest, servletResponse);
-      }
+    GuiceFilter.Context previous = GuiceFilter.localContext.get();
+    HttpServletRequest request = (HttpServletRequest) servletRequest;
+    HttpServletResponse response = (HttpServletResponse) servletResponse;
+    HttpServletRequest originalRequest
+        = (previous != null) ? previous.getOriginalRequest() : request;
+ GuiceFilter.localContext.set(new GuiceFilter.Context(originalRequest, request, response));
+    try {
+      //dispatch down the chain while there are more filters
+      if (index < filterDefinitions.length) {
+ filterDefinitions[index].doFilter(servletRequest, servletResponse, this);
+      } else {
+ //we've reached the end of the filterchain, let's try to dispatch to a servlet + final boolean serviced = servletPipeline.service(servletRequest, servletResponse);
+
+ //dispatch to the normal filter chain only if one of our servlets did not match
+        if (!serviced) {
+          proceedingChain.doFilter(servletRequest, servletResponse);
+        }
+      }
+    } finally {
+      GuiceFilter.localContext.set(previous);
     }
   }
 }
=======================================
--- /trunk/extensions/servlet/src/com/google/inject/servlet/GuiceFilter.java Thu Jul 7 17:34:16 2011 +++ /trunk/extensions/servlet/src/com/google/inject/servlet/GuiceFilter.java Tue Sep 27 08:36:19 2011
@@ -96,28 +96,33 @@
   //VisibleForTesting
   static void reset() {
     pipeline = new DefaultFilterPipeline();
+    localContext.remove();
   }

   public void doFilter(ServletRequest servletRequest,
       ServletResponse servletResponse, FilterChain filterChain)
       throws IOException, ServletException {

-    Context previous = localContext.get();
-
// Prefer the injected pipeline, but fall back on the static one for web.xml users. FilterPipeline filterPipeline = null != injectedPipeline ? injectedPipeline : pipeline;

+    Context previous = GuiceFilter.localContext.get();
+    HttpServletRequest request = (HttpServletRequest) servletRequest;
+    HttpServletResponse response = (HttpServletResponse) servletResponse;
+    HttpServletRequest originalRequest
+        = (previous != null) ? previous.getOriginalRequest() : request;
+    localContext.set(new Context(originalRequest, request, response));
     try {
-      localContext.set(new Context((HttpServletRequest) servletRequest,
-          (HttpServletResponse) servletResponse));
-
//dispatch across the servlet pipeline, ensuring web.xml's filterchain is honored filterPipeline.dispatch(servletRequest, servletResponse, filterChain);
-
     } finally {
       localContext.set(previous);
     }
   }
+
+  static HttpServletRequest getOriginalRequest() {
+    return getContext().getOriginalRequest();
+  }

   static HttpServletRequest getRequest() {
     return getContext().getRequest();
@@ -131,7 +136,7 @@
     return servletContext.get();
   }

-  static Context getContext() {
+  private static Context getContext() {
     Context context = localContext.get();
     if (context == null) {
throw new OutOfScopeException("Cannot access scoped object. Either we"
@@ -143,14 +148,20 @@
   }

   static class Context {
-
+    final HttpServletRequest originalRequest;
     final HttpServletRequest request;
     final HttpServletResponse response;

-    Context(HttpServletRequest request, HttpServletResponse response) {
+    Context(HttpServletRequest originalRequest, HttpServletRequest request,
+        HttpServletResponse response) {
+      this.originalRequest = originalRequest;
       this.request = request;
       this.response = response;
     }
+
+    HttpServletRequest getOriginalRequest() {
+      return originalRequest;
+    }

     HttpServletRequest getRequest() {
       return request;
=======================================
--- /trunk/extensions/servlet/src/com/google/inject/servlet/ManagedFilterPipeline.java Thu Jul 7 17:34:16 2011 +++ /trunk/extensions/servlet/src/com/google/inject/servlet/ManagedFilterPipeline.java Tue Sep 27 08:36:19 2011
@@ -137,8 +137,6 @@
   private ServletRequest withDispatcher(ServletRequest servletRequest,
       final ManagedServletPipeline servletPipeline) {

-    HttpServletRequest request = (HttpServletRequest) servletRequest;
-
// don't wrap the request if there are no servlets mapped. This prevents us from inserting our // wrapper unless it's actually going to be used. This is necessary for compatibility for apps // that downcast their HttpServletRequests to a concrete implementation.
@@ -146,6 +144,7 @@
       return servletRequest;
     }

+    HttpServletRequest request = (HttpServletRequest) servletRequest;
     //noinspection OverlyComplexAnonymousInnerClass
     return new HttpServletRequestWrapper(request) {

=======================================
--- /trunk/extensions/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java Thu Jul 7 17:34:16 2011 +++ /trunk/extensions/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java Tue Sep 27 08:36:19 2011
@@ -138,22 +138,21 @@
               // legacy (and internal) code.
               requestToProcess = servletRequest;
             }
-
- servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);

             // now dispatch to the servlet
-            try {
- servletDefinition.doService(requestToProcess, servletResponse);
-            } finally {
-              servletRequest.removeAttribute(REQUEST_DISPATCHER_REQUEST);
-            }
+ doServiceImpl(servletDefinition, requestToProcess, servletResponse);
           }

public void include(ServletRequest servletRequest, ServletResponse servletResponse)
               throws ServletException, IOException {
+            // route to the target servlet
+ doServiceImpl(servletDefinition, servletRequest, servletResponse);
+          }
+
+ private void doServiceImpl(ServletDefinition servletDefinition, ServletRequest servletRequest, + ServletResponse servletResponse) throws ServletException, IOException { servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);

-            // route to the target servlet
             try {
               servletDefinition.doService(servletRequest, servletResponse);
             } finally {
=======================================
--- /trunk/extensions/servlet/src/com/google/inject/servlet/ServletDefinition.java Thu Jul 7 17:34:16 2011 +++ /trunk/extensions/servlet/src/com/google/inject/servlet/ServletDefinition.java Tue Sep 27 08:36:19 2011
@@ -41,6 +41,7 @@
 import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
+import javax.servlet.http.HttpServletResponse;

 /**
* An internal representation of a servlet definition mapped to a particular URI pattern. Also
@@ -195,19 +196,20 @@

     HttpServletRequest request = new HttpServletRequestWrapper(
         (HttpServletRequest) servletRequest) {
+      private boolean pathComputed;
       private String path;
-      private boolean pathComputed = false;
- //must use a boolean on the memo field, because null is a legal value (TODO no, it's not)
-
-      private boolean pathInfoComputed = false;
+
+      private boolean pathInfoComputed;
       private String pathInfo;

       @Override
       public String getPathInfo() {
         if (!isPathInfoComputed()) {
           int servletPathLength = getServletPath().length();
- pathInfo = getRequestURI().substring(getContextPath().length()).replaceAll("[/]{2,}", "/"); - pathInfo = pathInfo.length() > servletPathLength ? pathInfo.substring(servletPathLength) : null;
+          pathInfo = getRequestURI().substring(getContextPath().length())
+              .replaceAll("[/]{2,}", "/");
+          pathInfo = pathInfo.length() > servletPathLength
+              ? pathInfo.substring(servletPathLength) : null;

// Corner case: when servlet path and request path match exactly (without trailing '/'),
           // then pathinfo is null
@@ -221,8 +223,10 @@
         return pathInfo;
       }

- // NOTE(dhanji): These two are a bit of a hack to help ensure that request dipatcher-sent + // NOTE(dhanji): These two are a bit of a hack to help ensure that request dispatcher-sent // requests don't use the same path info that was memoized for the original request. + // NOTE(iqshum): I don't think this is possible, since the dispatcher-sent request would
+      // perform its own wrapping.
       private boolean isPathInfoComputed() {
         return pathInfoComputed
&& !(null != servletRequest.getAttribute(REQUEST_DISPATCHER_REQUEST));
@@ -261,7 +265,20 @@
       }
     };

-    httpServlet.get().service(request, servletResponse);
+    doServiceImpl(request, (HttpServletResponse) servletResponse);
+  }
+
+ private void doServiceImpl(HttpServletRequest request, HttpServletResponse response)
+      throws ServletException, IOException {
+    GuiceFilter.Context previous = GuiceFilter.localContext.get();
+    HttpServletRequest originalRequest
+        = (previous != null) ? previous.getOriginalRequest() : request;
+ GuiceFilter.localContext.set(new GuiceFilter.Context(originalRequest, request, response));
+    try {
+      httpServlet.get().service(request, response);
+    } finally {
+      GuiceFilter.localContext.set(previous);
+    }
   }

   String getKey() {
=======================================
--- /trunk/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java Thu Jul 7 17:34:16 2011 +++ /trunk/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java Tue Sep 27 08:36:19 2011
@@ -17,6 +17,7 @@
 package com.google.inject.servlet;

 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Maps;
 import com.google.inject.Key;
 import com.google.inject.OutOfScopeException;
@@ -27,6 +28,7 @@
 import java.util.concurrent.Callable;

 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpSession;

 /**
@@ -37,6 +39,12 @@
 public class ServletScopes {

   private ServletScopes() {}
+
+ /** Keys bound in request-scope which are handled directly by GuiceFilter. */ + private static final ImmutableSet<Key<?>> REQUEST_CONTEXT_KEYS = ImmutableSet.of(
+      Key.get(HttpServletRequest.class),
+      Key.get(HttpServletResponse.class),
+      new Key<Map<String, String[]>>(RequestParameters.class) {});

   /** A sentinel attribute value representing null. */
   enum NullObject { INSTANCE }
@@ -45,7 +53,7 @@
    * HTTP servlet request scope.
    */
   public static final Scope REQUEST = new Scope() {
-    public <T> Provider<T> scope(Key<T> key, final Provider<T> creator) {
+ public <T> Provider<T> scope(final Key<T> key, final Provider<T> creator) {
       final String name = key.toString();
       return new Provider<T>() {
         public T get() {
@@ -77,8 +85,17 @@
               // exception is thrown.
           }

-          HttpServletRequest request = GuiceFilter.getRequest();
-
+ // Always synchronize and get/set attributes on the underlying request + // object since Filters may wrap the request and change the value of
+          // {@code GuiceFilter.getRequest()}.
+          //
+          // This _correctly_ throws up if the thread is out of scope.
+          HttpServletRequest request = GuiceFilter.getOriginalRequest();
+          if (REQUEST_CONTEXT_KEYS.contains(key)) {
+ // Don't store these keys as attributes, since they are handled by
+            // GuiceFilter itself.
+            return creator.get();
+          }
           synchronized (request) {
             Object obj = request.getAttribute(name);
             if (NullObject.INSTANCE == obj) {
@@ -182,7 +199,7 @@
     }

     return new Callable<T>() {
-      private HttpServletRequest request = continuingRequest;
+      private final HttpServletRequest request = continuingRequest;

       public T call() throws Exception {
         GuiceFilter.Context context = GuiceFilter.localContext.get();
@@ -191,7 +208,7 @@

         // Only set up the request continuation if we're running in a
         // new vanilla thread.
- GuiceFilter.localContext.set(new GuiceFilter.Context(request, null)); + GuiceFilter.localContext.set(new GuiceFilter.Context(request, request, null));
         try {
           return callable.call();
         } finally {
=======================================
--- /trunk/extensions/servlet/test/com/google/inject/servlet/ServletTest.java Fri Sep 9 14:23:42 2011 +++ /trunk/extensions/servlet/test/com/google/inject/servlet/ServletTest.java Tue Sep 27 08:36:19 2011
@@ -24,13 +24,17 @@
 import static java.lang.annotation.RetentionPolicy.RUNTIME;

 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.inject.AbstractModule;
 import com.google.inject.BindingAnnotation;
 import com.google.inject.CreationException;
 import com.google.inject.Guice;
+import com.google.inject.Inject;
 import com.google.inject.Injector;
 import com.google.inject.Key;
+import com.google.inject.Module;
+import com.google.inject.Provider;
 import com.google.inject.util.Providers;

 import junit.framework.TestCase;
@@ -44,13 +48,17 @@
 import java.lang.reflect.Proxy;
 import java.util.Map;

+import javax.servlet.Filter;
 import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpServletResponseWrapper;
 import javax.servlet.http.HttpSession;

 /**
@@ -77,7 +85,6 @@
     final Injector injector = createInjector();
     final HttpServletRequest request = newFakeHttpServletRequest();
     final HttpServletResponse response = newFakeHttpServletResponse();
-    final Map<String, String[]> params = Maps.newHashMap();

     final boolean[] invoked = new boolean[1];
     GuiceFilter filter = new GuiceFilter();
@@ -100,6 +107,148 @@

     assertTrue(invoked[0]);
   }
+
+ public void testRequestAndResponseBindings_wrappingFilter() throws Exception {
+    final HttpServletRequest request = newFakeHttpServletRequest();
+    final ImmutableMap<String, String[]> wrappedParamMap
+        = ImmutableMap.of("wrap", new String[]{"a", "b"});
+ final HttpServletRequestWrapper requestWrapper = new HttpServletRequestWrapper(request) {
+      @Override public Map getParameterMap() {
+        return wrappedParamMap;
+      }
+
+      @Override public Object getAttribute(String attr) {
+ // Ensure that attributes are stored on the original request object.
+        throw new UnsupportedOperationException();
+      }
+    };
+    final HttpServletResponse response = newFakeHttpServletResponse();
+ final HttpServletResponseWrapper responseWrapper = new HttpServletResponseWrapper(response);
+
+    final boolean[] filterInvoked = new boolean[1];
+    final Injector injector = createInjector(new ServletModule() {
+      @Override protected void configureServlets() {
+        filter("/*").through(new Filter() {
+          @Inject Provider<ServletRequest> servletReqProvider;
+          @Inject Provider<HttpServletRequest> reqProvider;
+          @Inject Provider<ServletResponse> servletRespProvider;
+          @Inject Provider<HttpServletResponse> respProvider;
+
+          public void init(FilterConfig filterConfig) {}
+
+ public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
+              throws IOException, ServletException {
+            filterInvoked[0] = true;
+            assertSame(req, servletReqProvider.get());
+            assertSame(req, reqProvider.get());
+
+            assertSame(resp, servletRespProvider.get());
+            assertSame(resp, respProvider.get());
+
+            chain.doFilter(requestWrapper, responseWrapper);
+
+            assertSame(req, reqProvider.get());
+            assertSame(resp, respProvider.get());
+          }
+
+          public void destroy() {}
+        });
+      }
+    });
+
+    GuiceFilter filter = new GuiceFilter();
+    final boolean[] chainInvoked = new boolean[1];
+    FilterChain filterChain = new FilterChain() {
+      public void doFilter(ServletRequest servletRequest,
+          ServletResponse servletResponse) {
+        chainInvoked[0] = true;
+        assertSame(requestWrapper, servletRequest);
+ assertSame(requestWrapper, injector.getInstance(ServletRequest.class));
+        assertSame(requestWrapper, injector.getInstance(HTTP_REQ_KEY));
+
+        assertSame(responseWrapper, servletResponse);
+ assertSame(responseWrapper, injector.getInstance(ServletResponse.class));
+        assertSame(responseWrapper, injector.getInstance(HTTP_RESP_KEY));
+
+ assertSame(servletRequest.getParameterMap(), injector.getInstance(REQ_PARAMS_KEY));
+
+        InRequest inRequest = injector.getInstance(InRequest.class);
+        assertSame(inRequest, injector.getInstance(InRequest.class));
+      }
+    };
+    filter.doFilter(request, response, filterChain);
+
+    assertTrue(chainInvoked[0]);
+    assertTrue(filterInvoked[0]);
+  }
+
+ public void testRequestAndResponseBindings_matchesPassedParameters() throws Exception {
+    final int[] filterInvoked = new int[1];
+    final boolean[] servletInvoked = new boolean[1];
+    final Injector injector = createInjector(new ServletModule() {
+      @Override protected void configureServlets() {
+        final HttpServletRequest[] previousReq = new HttpServletRequest[1];
+ final HttpServletResponse[] previousResp = new HttpServletResponse[1];
+
+ final Provider<ServletRequest> servletReqProvider = getProvider(ServletRequest.class); + final Provider<HttpServletRequest> reqProvider = getProvider(HttpServletRequest.class); + final Provider<ServletResponse> servletRespProvider = getProvider(ServletResponse.class); + final Provider<HttpServletResponse> respProvider = getProvider(HttpServletResponse.class);
+
+        Filter filter = new Filter() {
+          public void init(FilterConfig filterConfig) {}
+
+ public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
+              throws IOException, ServletException {
+            filterInvoked[0]++;
+            assertSame(req, servletReqProvider.get());
+            assertSame(req, reqProvider.get());
+            if (previousReq[0] != null) {
+              assertEquals(req, previousReq[0]);
+            }
+
+            assertSame(resp, servletRespProvider.get());
+            assertSame(resp, respProvider.get());
+            if (previousResp[0] != null) {
+              assertEquals(resp, previousResp[0]);
+            }
+
+            chain.doFilter(
+ previousReq[0] = new HttpServletRequestWrapper((HttpServletRequest) req), + previousResp[0] = new HttpServletResponseWrapper((HttpServletResponse) resp));
+
+            assertSame(req, reqProvider.get());
+            assertSame(resp, respProvider.get());
+          }
+
+          public void destroy() {}
+        };
+
+        filter("/*").through(filter);
+ filter("/*").through(filter); // filter twice to test wrapping in filters
+        serve("/*").with(new HttpServlet() {
+ @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
+            servletInvoked[0] = true;
+            assertSame(req, servletReqProvider.get());
+            assertSame(req, reqProvider.get());
+
+            assertSame(resp, servletRespProvider.get());
+            assertSame(resp, respProvider.get());
+          }
+        });
+      }
+    });
+
+    GuiceFilter filter = new GuiceFilter();
+ filter.doFilter(newFakeHttpServletRequest(), newFakeHttpServletResponse(), new FilterChain() { + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
+        throw new IllegalStateException("Shouldn't get here");
+      }
+    });
+
+    assertEquals(2, filterInvoked[0]);
+    assertTrue(servletInvoked[0]);
+  }

   public void testNewRequestObject()
       throws CreationException, IOException, ServletException {
@@ -240,6 +389,10 @@
       final Map<String, Object> attributes = Maps.newHashMap();
       final HttpSession session = newFakeHttpSession();

+      @Override public String getMethod() {
+        return "GET";
+      }
+
       @Override public Object getAttribute(String name) {
         return attributes.get(name);
       }
@@ -300,8 +453,8 @@
         new Class[] { HttpSession.class }, new FakeHttpSessionHandler());
   }

-  private Injector createInjector() throws CreationException {
-    return Guice.createInjector(new AbstractModule() {
+ private Injector createInjector(Module... modules) throws CreationException {
+    return Guice.createInjector(Lists.<Module>asList(new AbstractModule() {
       @Override
       protected void configure() {
         install(new ServletModule());
@@ -310,7 +463,7 @@
         bind(InRequest.class);
bind(IN_REQUEST_NULL_KEY).toProvider(Providers.<InRequest>of(null)).in(RequestScoped.class);
       }
-    });
+    }, modules));
   }

   @SessionScoped

--
You received this message because you are subscribed to the Google Groups 
"google-guice-dev" group.
To post to this group, send email to google-guice-dev@googlegroups.com.
To unsubscribe from this group, send email to 
google-guice-dev+unsubscr...@googlegroups.com.
For more options, visit this group at 
http://groups.google.com/group/google-guice-dev?hl=en.

Reply via email to