Hi,

Currently (in the latest JDK 11 code), an instance of TreeMap created with the
no-arg constructor has a null comparator, i.e. it uses no comparator.

As the code shows, a TreeMap created with a null comparator behaves exactly
like a TreeMap created with the comparator returned by
Comparator.naturalOrder(). This is also stated explicitly in the Javadoc.

I propose initializing the default comparator of TreeMap with the Comparator
instance returned by Comparator.naturalOrder() instead of null.
This allows removing the code responsible for handling a null comparator;
e.g. TreeMap::getEntryUsingComparator can be removed entirely in favour of
TreeMap::getEntry.
A similar simplification is available for TreeMap::put, TreeMap::compare and
EntrySpliterator::getComparator.

I've prepared a patch for this.
It contains both the major change described above and some small clean-ups,
e.g. using Objects::requireNonNull where appropriate, and Objects::equals
instead of the hand-written TreeMap::valEquals.

TreeMapTest is green after my changes.

Regards,
Sergey Tsypanov


diff --git a/src/java.base/share/classes/java/util/TreeMap.java b/src/java.base/share/classes/java/util/TreeMap.java
--- a/src/java.base/share/classes/java/util/TreeMap.java
+++ b/src/java.base/share/classes/java/util/TreeMap.java
@@ -133,6 +133,11 @@
     private transient int modCount = 0;
 
     /**
+     * Comparator used by default.
+     */
+    private static final Comparator DEFAULT_COMPARATOR = Comparator.naturalOrder();
+
+    /**
      * Constructs a new, empty tree map, using the natural ordering of its
      * keys.  All keys inserted into the map must implement the {@link
      * Comparable} interface.  Furthermore, all such keys must be
@@ -145,7 +150,7 @@
      * {@code ClassCastException}.
      */
     public TreeMap() {
-        comparator = null;
+        comparator = DEFAULT_COMPARATOR;
     }
 
     /**
@@ -163,7 +168,7 @@
      *        ordering} of the keys will be used.
      */
     public TreeMap(Comparator<? super K> comparator) {
-        this.comparator = comparator;
+        this.comparator = comparator == null ? DEFAULT_COMPARATOR : comparator;
     }
 
     /**
@@ -181,7 +186,7 @@
      * @throws NullPointerException if the specified map is null
      */
     public TreeMap(Map<? extends K, ? extends V> m) {
-        comparator = null;
+        this();
         putAll(m);
     }
 
@@ -195,7 +200,8 @@
      * @throws NullPointerException if the specified map is null
      */
     public TreeMap(SortedMap<K, ? extends V> m) {
-        comparator = m.comparator();
+        Comparator<? super K> comparator = m.comparator();
+        this.comparator = comparator == null ? DEFAULT_COMPARATOR : comparator;
         try {
             buildFromSorted(m.size(), m.entrySet().iterator(), null, null);
         } catch (java.io.IOException | ClassNotFoundException cannotHappen) {
@@ -246,7 +252,7 @@
      */
     public boolean containsValue(Object value) {
         for (Entry<K,V> e = getFirstEntry(); e != null; e = successor(e))
-            if (valEquals(value, e.value))
+            if (Objects.equals(value, e.value))
                 return true;
         return false;
     }
@@ -312,7 +318,7 @@
         int mapSize = map.size();
         if (size==0 && mapSize!=0 && map instanceof SortedMap) {
             Comparator<?> c = ((SortedMap<?,?>)map).comparator();
-            if (c == comparator || (c != null && c.equals(comparator))) {
+            if (Objects.equals(c, comparator)) {
                 ++modCount;
                 try {
                     buildFromSorted(mapSize, map.entrySet().iterator(),
@@ -338,16 +344,14 @@
      *         does not permit null keys
      */
     final Entry<K,V> getEntry(Object key) {
-        // Offload comparator-based version for sake of performance
-        if (comparator != null)
-            return getEntryUsingComparator(key);
-        if (key == null)
+        if (key == null && comparator == DEFAULT_COMPARATOR)
             throw new NullPointerException();
         @SuppressWarnings("unchecked")
-            Comparable<? super K> k = (Comparable<? super K>) key;
+        K k = (K) key;
+        Comparator<? super K> cpr = comparator;
         Entry<K,V> p = root;
         while (p != null) {
-            int cmp = k.compareTo(p.key);
+            int cmp = cpr.compare(k, p.key);
             if (cmp < 0)
                 p = p.left;
             else if (cmp > 0)
@@ -359,31 +363,6 @@
     }
 
     /**
-     * Version of getEntry using comparator. Split off from getEntry
-     * for performance. (This is not worth doing for most methods,
-     * that are less dependent on comparator performance, but is
-     * worthwhile here.)
-     */
-    final Entry<K,V> getEntryUsingComparator(Object key) {
-        @SuppressWarnings("unchecked")
-            K k = (K) key;
-        Comparator<? super K> cpr = comparator;
-        if (cpr != null) {
-            Entry<K,V> p = root;
-            while (p != null) {
-                int cmp = cpr.compare(k, p.key);
-                if (cmp < 0)
-                    p = p.left;
-                else if (cmp > 0)
-                    p = p.right;
-                else
-                    return p;
-            }
-        }
-        return null;
-    }
-
-    /**
      * Gets the entry corresponding to the specified key; if no such entry
      * exists, returns the entry for the least key greater than the specified
      * key; if no such entry exists (i.e., the greatest key in the Tree is less
@@ -540,38 +519,22 @@
             modCount++;
             return null;
         }
+        Comparator<? super K> cpr = comparator;
+        if (key == null && cpr == DEFAULT_COMPARATOR)
+            throw new NullPointerException();
+
         int cmp;
         Entry<K,V> parent;
-        // split comparator and comparable paths
-        Comparator<? super K> cpr = comparator;
-        if (cpr != null) {
-            do {
-                parent = t;
-                cmp = cpr.compare(key, t.key);
-                if (cmp < 0)
-                    t = t.left;
-                else if (cmp > 0)
-                    t = t.right;
-                else
-                    return t.setValue(value);
-            } while (t != null);
-        }
-        else {
-            if (key == null)
-                throw new NullPointerException();
-            @SuppressWarnings("unchecked")
-                Comparable<? super K> k = (Comparable<? super K>) key;
-            do {
-                parent = t;
-                cmp = k.compareTo(t.key);
-                if (cmp < 0)
-                    t = t.left;
-                else if (cmp > 0)
-                    t = t.right;
-                else
-                    return t.setValue(value);
-            } while (t != null);
-        }
+        do {
+            parent = t;
+            cmp = cpr.compare(key, t.key);
+            if (cmp < 0)
+                t = t.left;
+            else if (cmp > 0)
+                t = t.right;
+            else
+                return t.setValue(value);
+        } while (t != null);
         Entry<K,V> e = new Entry<>(key, value, parent);
         if (cmp < 0)
             parent.left = e;
@@ -1038,7 +1001,7 @@
 
         public boolean remove(Object o) {
             for (Entry<K,V> e = getFirstEntry(); e != null; e = successor(e)) {
-                if (valEquals(e.getValue(), o)) {
+                if (Objects.equals(e.getValue(), o)) {
                     deleteEntry(e);
                     return true;
                 }
@@ -1066,7 +1029,7 @@
             Map.Entry<?,?> entry = (Map.Entry<?,?>) o;
             Object value = entry.getValue();
             Entry<K,V> p = getEntry(entry.getKey());
-            return p != null && valEquals(p.getValue(), value);
+            return p != null && Objects.equals(p.getValue(), value);
         }
 
         public boolean remove(Object o) {
@@ -1075,7 +1038,7 @@
             Map.Entry<?,?> entry = (Map.Entry<?,?>) o;
             Object value = entry.getValue();
             Entry<K,V> p = getEntry(entry.getKey());
-            if (p != null && valEquals(p.getValue(), value)) {
+            if (p != null && Objects.equals(p.getValue(), value)) {
                 deleteEntry(p);
                 return true;
             }
@@ -1288,16 +1251,7 @@
      */
     @SuppressWarnings("unchecked")
     final int compare(Object k1, Object k2) {
-        return comparator==null ? ((Comparable<? super K>)k1).compareTo((K)k2)
-            : comparator.compare((K)k1, (K)k2);
-    }
-
-    /**
-     * Test two values for equality.  Differs from o1.equals(o2) only in
-     * that it copes with {@code null} o1 properly.
-     */
-    static final boolean valEquals(Object o1, Object o2) {
-        return (o1==null ? o2==null : o1.equals(o2));
+        return comparator.compare((K)k1, (K)k2);
     }
 
     /**
@@ -1651,7 +1605,7 @@
                     return false;
                 TreeMap.Entry<?,?> node = m.getEntry(key);
                 return node != null &&
-                    valEquals(node.getValue(), entry.getValue());
+                    Objects.equals(node.getValue(), entry.getValue());
             }
 
             public boolean remove(Object o) {
@@ -1662,7 +1616,7 @@
                 if (!inRange(key))
                     return false;
                 TreeMap.Entry<K,V> node = m.getEntry(key);
-                if (node!=null && valEquals(node.getValue(),
+                if (node!=null && Objects.equals(node.getValue(),
                                             entry.getValue())) {
                     m.deleteEntry(node);
                     return true;
@@ -2101,7 +2055,7 @@
                 return false;
             Map.Entry<?,?> e = (Map.Entry<?,?>)o;
 
-            return valEquals(key,e.getKey()) && valEquals(value,e.getValue());
+            return Objects.equals(key, e.getKey()) && Objects.equals(value, e.getValue());
         }
 
         public int hashCode() {
@@ -2728,8 +2682,7 @@
         }
 
         public void forEachRemaining(Consumer<? super K> action) {
-            if (action == null)
-                throw new NullPointerException();
+            Objects.requireNonNull(action);
             if (est < 0)
                 getEstimate(); // force initialization
             TreeMap.Entry<K,V> f = fence, e, p, pl;
@@ -2752,12 +2705,11 @@
         }
 
         public boolean tryAdvance(Consumer<? super K> action) {
-            TreeMap.Entry<K,V> e;
-            if (action == null)
-                throw new NullPointerException();
+            Objects.requireNonNull(action);
             if (est < 0)
                 getEstimate(); // force initialization
-            if ((e = current) == null || e == fence)
+            TreeMap.Entry<K,V> e = current;
+            if (e == null || e == fence)
                 return false;
             current = successor(e);
             action.accept(e.key);
@@ -2806,8 +2758,7 @@
         }
 
         public void forEachRemaining(Consumer<? super K> action) {
-            if (action == null)
-                throw new NullPointerException();
+            Objects.requireNonNull(action);
             if (est < 0)
                 getEstimate(); // force initialization
             TreeMap.Entry<K,V> f = fence, e, p, pr;
@@ -2830,12 +2781,11 @@
         }
 
         public boolean tryAdvance(Consumer<? super K> action) {
-            TreeMap.Entry<K,V> e;
-            if (action == null)
-                throw new NullPointerException();
+            Objects.requireNonNull(action);
             if (est < 0)
                 getEstimate(); // force initialization
-            if ((e = current) == null || e == fence)
+            TreeMap.Entry<K,V> e = current;
+            if (e == null || e == fence)
                 return false;
             current = predecessor(e);
             action.accept(e.key);
@@ -2879,8 +2829,7 @@
         }
 
         public void forEachRemaining(Consumer<? super V> action) {
-            if (action == null)
-                throw new NullPointerException();
+            Objects.requireNonNull(action);
             if (est < 0)
                 getEstimate(); // force initialization
             TreeMap.Entry<K,V> f = fence, e, p, pl;
@@ -2903,12 +2852,12 @@
         }
 
         public boolean tryAdvance(Consumer<? super V> action) {
-            TreeMap.Entry<K,V> e;
-            if (action == null)
-                throw new NullPointerException();
-            if (est < 0)
+            Objects.requireNonNull(action);
+            if (est < 0) {
                 getEstimate(); // force initialization
-            if ((e = current) == null || e == fence)
+            }
+            TreeMap.Entry<K,V> e = current;
+            if (e == null || e == fence)
                 return false;
             current = successor(e);
             action.accept(e.value);
@@ -2951,8 +2900,7 @@
         }
 
         public void forEachRemaining(Consumer<? super Map.Entry<K, V>> action) {
-            if (action == null)
-                throw new NullPointerException();
+            Objects.requireNonNull(action);
             if (est < 0)
                 getEstimate(); // force initialization
             TreeMap.Entry<K,V> f = fence, e, p, pl;
@@ -2975,12 +2923,11 @@
         }
 
         public boolean tryAdvance(Consumer<? super Map.Entry<K,V>> action) {
-            TreeMap.Entry<K,V> e;
-            if (action == null)
-                throw new NullPointerException();
+            Objects.requireNonNull(action);
             if (est < 0)
                 getEstimate(); // force initialization
-            if ((e = current) == null || e == fence)
+            TreeMap.Entry<K,V> e = current;
+            if (e == null || e == fence)
                 return false;
             current = successor(e);
             action.accept(e);
@@ -2996,17 +2943,7 @@
 
         @Override
         public Comparator<Map.Entry<K, V>> getComparator() {
-            // Adapt or create a key-based comparator
-            if (tree.comparator != null) {
-                return Map.Entry.comparingByKey(tree.comparator);
-            }
-            else {
-                return (Comparator<Map.Entry<K, V>> & Serializable) (e1, e2) -> {
-                    @SuppressWarnings("unchecked")
-                    Comparable<? super K> k1 = (Comparable<? super K>) e1.getKey();
-                    return k1.compareTo(e2.getKey());
-                };
-            }
+            return Map.Entry.comparingByKey(tree.comparator);
         }
     }
 }

Reply via email to