Jeroen Frijters wrote:
This is not true. Once an object becomes "finalizer reachable" the finalizer 
can run, but that doesn't mean all references are gone. Any object with a finalizer can 
still have references to objects that have already been finalized. In its finalizer it 
will be able to resurrect these objects. This is a well known attack vector and the 
reason that PhantomReferences were introduced to do post-mortem cleanup.

In normal use it would be very hard to do this, but you're right that in theory someone could try to use a thread local after it has been finalized, from their own finalizer, using the finalizer I wrote. I've changed the patch as you suggested so that localIndex is set to an invalid value after finalization to prevent this attack. It's unfortunate, as the change stops localIndex from being a prime candidate for final field chasing or caching in a register. This still seems better than using a phantom reference and polling a queue, though.


Thanks,
Ian
diff -u java/lang/InheritableThreadLocal.java 
java/lang/InheritableThreadLocal.java
--- java/lang/InheritableThreadLocal.java       2007-08-21 12:01:24.000000000 
+0100
+++ java/lang/InheritableThreadLocal.java       2007-08-21 11:55:49.000000000 
+0100
@@ -37,9 +37,9 @@
 
 package java.lang;
 
-import gnu.java.util.WeakIdentityHashMap;
-
-import java.util.Iterator;
+import java.lang.ref.WeakReference;
+import java.util.HashMap;
+import java.util.Arrays;
 
 /**
  * A ThreadLocal whose value is inherited by child Threads. The value of the
@@ -56,12 +56,22 @@
  * @author Eric Blake ([EMAIL PROTECTED])
  * @author Tom Tromey ([EMAIL PROTECTED])
  * @author Andrew John Hughes ([EMAIL PROTECTED])
+ * @author Ian Rogers ([EMAIL PROTECTED])
  * @see ThreadLocal
  * @since 1.2
  * @status updated to 1.4
  */
 public class InheritableThreadLocal<T> extends ThreadLocal<T>
 {
+  
+
+  /**
+   * Map of owners of local index to inheritable thread local. Weak references
+   * allow thread locals to be collected when no longer in use.
+   */
+  private static final 
HashMap<Integer,WeakReference<InheritableThreadLocal<?>>>
+  owners = new HashMap<Integer,WeakReference<InheritableThreadLocal<?>>>();
+  
 
   /**
    * Creates a new InheritableThreadLocal that has no values associated
@@ -69,6 +79,8 @@
    */
   public InheritableThreadLocal()
   {
+    super();
+    owners.put(localIndex, new WeakReference<InheritableThreadLocal<?>>(this));
   }
 
   /**
@@ -97,24 +109,25 @@
   {
     // The currentThread is the parent of the new thread.
     Thread parentThread = Thread.currentThread();
-    if (parentThread.locals != null)
-      {
-        Iterator keys = parentThread.locals.keySet().iterator();
-        while (keys.hasNext())
-          {
-            Object key = keys.next();
-            if (key instanceof InheritableThreadLocal)
-              {
-               InheritableThreadLocal local = (InheritableThreadLocal)key;
-                Object parentValue = parentThread.locals.get(key);
-               Object childValue = local.childValue(parentValue == sentinel
-                                               ? null : parentValue);
-                if (childThread.locals == null)
-                    childThread.locals = new WeakIdentityHashMap();
-                childThread.locals.put(key, (childValue == null
-                                             ? sentinel : childValue));
-              }
+    if (parentThread.locals != null) {
+      childThread.locals = new Object[parentThread.locals.length];
+      Arrays.fill(childThread.locals, sentinel);
+      for (int i=0; i < parentThread.locals.length; i++) {
+        WeakReference ref = owners.get(i);
+        if (ref != null) {
+          InheritableThreadLocal local = (InheritableThreadLocal)ref.get();
+          if (local != null) {
+            Object parentValue = parentThread.locals[i];
+            Object childValue;
+            if (parentValue == sentinel) {
+              childValue = sentinel;
+            } else {
+              childValue = local.childValue(parentValue);
+            }
+            childThread.locals[i] = childValue; 
           }
+        }
       }
+    }
   }
 }
diff -u java/lang/Thread.java java/lang/Thread.java
--- java/lang/Thread.java       2007-08-21 12:01:24.000000000 +0100
+++ java/lang/Thread.java       2007-08-21 11:55:14.000000000 +0100
@@ -39,14 +39,11 @@
 package java.lang;
 
 import gnu.classpath.VMStackWalker;
-import gnu.java.util.WeakIdentityHashMap;
 
 import java.lang.management.ManagementFactory;
 import java.lang.management.ThreadInfo;
 import java.lang.management.ThreadMXBean;
-
 import java.security.Permission;
-
 import java.util.HashMap;
 import java.util.Map;
 
@@ -90,6 +87,7 @@
  * @author John Keiser
  * @author Eric Blake ([EMAIL PROTECTED])
  * @author Andrew John Hughes ([EMAIL PROTECTED])
+ * @author Ian Rogers ([EMAIL PROTECTED])
  * @see Runnable
  * @see Runtime#exit(int)
  * @see #run()
@@ -156,10 +154,11 @@
   /** The default exception handler.  */
   private static UncaughtExceptionHandler defaultHandler;
 
-  /** Thread local storage. Package accessible for use by
-    * InheritableThreadLocal.
-    */
-  WeakIdentityHashMap locals;
+  /** Default value for a thread with no locals. */
+  private static final Object[] noLocals = new Object[0];
+
+  /** Where thread local storage is hung by ThreadLocal. */
+  Object[] locals = noLocals;
 
   /** The uncaught exception handler.  */
   UncaughtExceptionHandler exceptionHandler;
@@ -1066,20 +1065,6 @@
     locals = null;
   }
 
-  /**
-   * Returns the map used by ThreadLocal to store the thread local values.
-   */
-  static Map getThreadLocals()
-  {
-    Thread thread = currentThread();
-    Map locals = thread.locals;
-    if (locals == null)
-      {
-        locals = thread.locals = new WeakIdentityHashMap();
-      }
-    return locals;
-  }
-
   /** 
    * Assigns the given <code>UncaughtExceptionHandler</code> to this
    * thread.  This will then be called if the thread terminates due
diff -u java/lang/ThreadLocal.java java/lang/ThreadLocal.java
--- java/lang/ThreadLocal.java  2006-12-10 20:25:44.000000000 +0000
+++ java/lang/ThreadLocal.java  2007-08-21 14:42:16.000000000 +0100
@@ -37,8 +37,8 @@
 
 package java.lang;
 
-import java.util.Map;
-
+import java.util.Arrays;
+import java.util.BitSet;
 
 /**
  * ThreadLocal objects have a different state associated with every
@@ -83,6 +83,7 @@
  *
  * @author Mark Wielaard ([EMAIL PROTECTED])
  * @author Eric Blake ([EMAIL PROTECTED])
+ * @author Ian Rogers ([EMAIL PROTECTED])
  * @since 1.2
  * @status updated to 1.5
  */
@@ -94,12 +95,79 @@
    * InheritableThreadLocal
    */
   static final Object sentinel = new Object();
-
+  
+  /**
+   * List of slots owned by thread locals, cleared when the thread local is
+   * finalized
+   */
+  private static final BitSet ownedSlots = new BitSet();
+  
+  /**
+   * Index of thread local within Thread's pool. Package visible for use by
+   * InheritableThreadLocal
+   */
+  volatile int localIndex;
+  
+  /**
+   * Sole reference to object whose sole purpose is to make the ownedSlot
+   * available following collection of a thread local. We don't override
+   * finalize directly as this would mean sub-classes calling this finalizer.
+   */
+  private final Object finalizer = new Object() {
+    /**
+     * Drop references to locals visible from this slot and free slot for
+     * reuse
+     */
+    protected void finalize() {
+      // Create list of threads
+      ThreadGroup group = Thread.currentThread().group;
+      while (group.getParent() != null)
+        group = group.getParent();
+      int arraySize = group.activeCount();
+      Thread[] threadList = new Thread[arraySize];
+      int filled = group.enumerate(threadList);
+      while (filled == arraySize)
+      {
+        arraySize *= 2;
+        threadList = new Thread[arraySize];
+        filled = group.enumerate(threadList);
+      }
+      // Iterate over all threads re-initializing slot
+      for (Thread thread : threadList) {
+        if (thread.locals.length > localIndex) {
+          thread.locals[localIndex] = sentinel;
+        }
+      }
+      // Mark slot as available for reuse
+      synchronized(ownedSlots) {
+        ownedSlots.clear(localIndex);
+      }
+      // Make local index nonsensical to avoid possible finalizer attack
+      localIndex = -1;
+    }
+  };
+  
+  /**
+   * Allocate space within the owned slots for this thread local. Once
+   * allocated a thread local slot will only ever be reused, not reclaimed.
+   * @param tl the thread local we're allocating space for
+   * @return the slot to hold the thread local
+   */
+  private static int allocateThreadLocal(ThreadLocal<?> tl) {
+    synchronized(ownedSlots) {
+      int freeSlot = ownedSlots.nextClearBit(0);
+      ownedSlots.set(freeSlot);
+      return freeSlot;
+    }
+  }
+  
   /**
    * Creates a ThreadLocal object without associating any value to it yet.
    */
   public ThreadLocal()
   {
+    localIndex = allocateThreadLocal(this);
   }
 
   /**
@@ -125,16 +193,24 @@
    */
   public T get()
   {
-    Map<ThreadLocal<T>,T> map = (Map<ThreadLocal<T>,T>) 
Thread.getThreadLocals();
-    // Note that we don't have to synchronize, as only this thread will
-    // ever modify the map.
-    T value = map.get(this);
-    if (value == null)
-      {
-        value = initialValue();
-        map.put(this, (T) (value == null ? sentinel : value));
-      }
-    return value == (T) sentinel ? null : value;
+    Thread thread = Thread.currentThread();
+    T value;
+    try {
+      // Get value
+      value = (T)thread.locals[localIndex];
+    } catch (ArrayIndexOutOfBoundsException e) {
+      // Uncommon case: first access to local in this thread. Grow array.
+      final int oldLength = thread.locals.length;
+      thread.locals = Arrays.copyOf(thread.locals, ownedSlots.length());
+      Arrays.fill(thread.locals, oldLength, thread.locals.length, sentinel);
+      value = (T)sentinel;
+    }
+    // Initialize on first use
+    if (value == sentinel) {
+      value = initialValue();
+      thread.locals[localIndex] = value;
+    }
+    return value;
   }
 
   /**
@@ -147,10 +223,19 @@
    */
   public void set(T value)
   {
-    Map map = Thread.getThreadLocals();
     // Note that we don't have to synchronize, as only this thread will
     // ever modify the map.
-    map.put(this, value == null ? sentinel : value);
+    Thread thread = Thread.currentThread();
+    try {
+      // Set the value
+      thread.locals[localIndex] = value;
+    } catch (ArrayIndexOutOfBoundsException e) {
+      // Uncommon case: first access to local in this thread. Grow array.
+      final int oldLength = thread.locals.length;
+      thread.locals = Arrays.copyOf(thread.locals, ownedSlots.length());
+      Arrays.fill(thread.locals, oldLength, thread.locals.length, sentinel);
+      thread.locals[localIndex] = value;
+    }
   }
 
   /**
@@ -160,7 +245,6 @@
    */
   public void remove()
   {
-    Map map = Thread.getThreadLocals();
-    map.remove(this);
+    set((T)sentinel);
   }
 }

Reply via email to