This is an automated email from the ASF dual-hosted git repository.

mmerli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar.git


The following commit(s) were added to refs/heads/master by this push:
     new c39f9f82b42 [fix][ml] Fix race conditions in RangeCache (#22789)
c39f9f82b42 is described below

commit c39f9f82b425c66c899f818583714c9c98d3e213
Author: Lari Hotari <lhot...@users.noreply.github.com>
AuthorDate: Fri May 31 03:25:52 2024 +0300

    [fix][ml] Fix race conditions in RangeCache (#22789)
---
 .../apache/bookkeeper/mledger/impl/EntryImpl.java  |   7 +-
 .../apache/bookkeeper/mledger/util/RangeCache.java | 278 ++++++++++++++++-----
 .../bookkeeper/mledger/util/RangeCacheTest.java    |  63 +++--
 3 files changed, 254 insertions(+), 94 deletions(-)

diff --git a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java
index 80397931357..48a79a4ac52 100644
--- a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java
+++ b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/impl/EntryImpl.java
@@ -27,9 +27,10 @@ import io.netty.util.ReferenceCounted;
 import org.apache.bookkeeper.client.api.LedgerEntry;
 import org.apache.bookkeeper.mledger.Entry;
 import org.apache.bookkeeper.mledger.util.AbstractCASReferenceCounted;
+import org.apache.bookkeeper.mledger.util.RangeCache;
 
 public final class EntryImpl extends AbstractCASReferenceCounted implements 
Entry, Comparable<EntryImpl>,
-        ReferenceCounted {
+        RangeCache.ValueWithKeyValidation<PositionImpl> {
 
     private static final Recycler<EntryImpl> RECYCLER = new 
Recycler<EntryImpl>() {
         @Override
@@ -205,4 +206,8 @@ public final class EntryImpl extends 
AbstractCASReferenceCounted implements Entr
         recyclerHandle.recycle(this);
     }
 
+    /**
+     * Checks that this entry belongs to the given cache key.
+     *
+     * @param key the position this entry is expected to be stored under
+     * @return {@code true} when comparing {@code key} against this entry's ledger id and entry id yields 0
+     */
+    @Override
+    public boolean matchesKey(PositionImpl key) {
+        final int comparison = key.compareTo(ledgerId, entryId);
+        return comparison == 0;
+    }
 }
diff --git a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
index d34857e5e51..46d03bea1b5 100644
--- a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
+++ b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
@@ -19,31 +19,134 @@
 package org.apache.bookkeeper.mledger.util;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import com.google.common.base.Predicate;
+import io.netty.util.IllegalReferenceCountException;
+import io.netty.util.Recycler;
+import io.netty.util.Recycler.Handle;
 import io.netty.util.ReferenceCounted;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.ConcurrentNavigableMap;
 import java.util.concurrent.ConcurrentSkipListMap;
 import java.util.concurrent.atomic.AtomicLong;
+import org.apache.bookkeeper.mledger.util.RangeCache.ValueWithKeyValidation;
 import org.apache.commons.lang3.tuple.Pair;
 
 /**
  * Special type of cache where get() and delete() operations can be done over 
a range of keys.
+ * The implementation avoids locks and synchronization and relies on ConcurrentSkipListMap for storing the entries.
+ * Since there are no locks, there needs to be a way to ensure that a single entry in the cache is removed
+ * exactly once. Removing an entry multiple times would result in the entries of the cache getting released too
+ * early, while they could still be in use.
  *
  * @param <Key>
  *            Cache key. Needs to be Comparable
  * @param <Value>
  *            Cache value
  */
-public class RangeCache<Key extends Comparable<Key>, Value extends 
ReferenceCounted> {
+public class RangeCache<Key extends Comparable<Key>, Value extends 
ValueWithKeyValidation<Key>> {
+    public interface ValueWithKeyValidation<T> extends ReferenceCounted { // ref-counted value that can be checked against a key
+        boolean matchesKey(T key); // put() rejects a value for which this returns false for the insertion key
+    }
+
     // Map from key to nodes inside the linked list
-    private final ConcurrentNavigableMap<Key, Value> entries;
+    private final ConcurrentNavigableMap<Key, IdentityWrapper<Key, Value>> 
entries;
     private AtomicLong size; // Total size of values stored in cache
     private final Weighter<Value> weighter; // Weighter object used to extract 
the size from values
     private final TimestampExtractor<Value> timestampExtractor; // Extract the 
timestamp associated with a value
 
+    /**
+     * Wrapper around the value to store in Map. This is needed to ensure that a specific instance can be removed from
+     * the map by calling the {@link Map#remove(Object, Object)} method. Certain race conditions could result in the
+     * wrong value being removed from the map. The instances of this class are recycled to avoid creating new objects.
+     *
+     * <p>Both the key and the value are cleared when an instance is recycled, so a pooled wrapper does not keep
+     * the previous key or value reachable.
+     */
+    private static class IdentityWrapper<K, V> {
+        private final Handle<IdentityWrapper> recyclerHandle;
+        private static final Recycler<IdentityWrapper> RECYCLER = new Recycler<IdentityWrapper>() {
+            @Override
+            protected IdentityWrapper newObject(Handle<IdentityWrapper> recyclerHandle) {
+                return new IdentityWrapper(recyclerHandle);
+            }
+        };
+        private K key;
+        private V value;
+
+        private IdentityWrapper(Handle<IdentityWrapper> recyclerHandle) {
+            this.recyclerHandle = recyclerHandle;
+        }
+
+        /** Takes an instance from the recycler and points it at the given key and value. */
+        static <K, V> IdentityWrapper<K, V> create(K key, V value) {
+            IdentityWrapper<K, V> identityWrapper = RECYCLER.get();
+            identityWrapper.key = key;
+            identityWrapper.value = value;
+            return identityWrapper;
+        }
+
+        /** Returns the key this wrapper currently holds, or {@code null} once it has been recycled. */
+        K getKey() {
+            return key;
+        }
+
+        /** Returns the value this wrapper currently holds, or {@code null} once it has been recycled. */
+        V getValue() {
+            return value;
+        }
+
+        /** Clears the key and the value and hands this instance back to the recycler. */
+        void recycle() {
+            // the key is cleared as well: callers compare getKey() with the map key to detect a recycled wrapper,
+            // and a leftover key would make a recycled wrapper (value == null) look like a live one
+            key = null;
+            value = null;
+            recyclerHandle.recycle(this);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            // only match exact identity of the value
+            return this == o;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hashCode(key);
+        }
+    }
+
+    /**
+     * Mutable object to store the number of entries and the total size removed from the cache. The instances
+     * are recycled to avoid creating new instances.
+     */
+    private static class RemovalCounters {
+        private final Handle<RemovalCounters> recyclerHandle;
+        private static final Recycler<RemovalCounters> RECYCLER = new Recycler<RemovalCounters>() {
+            @Override
+            protected RemovalCounters newObject(Handle<RemovalCounters> recyclerHandle) {
+                return new RemovalCounters(recyclerHandle);
+            }
+        };
+        int removedEntries;
+        long removedSize;
+
+        private RemovalCounters(Handle<RemovalCounters> recyclerHandle) {
+            this.recyclerHandle = recyclerHandle;
+        }
+
+        /** Obtains an instance from the recycler with both counters set to zero. */
+        static RemovalCounters create() {
+            RemovalCounters results = RECYCLER.get();
+            results.removedEntries = 0;
+            results.removedSize = 0;
+            return results;
+        }
+
+        /** Resets both counters and hands this instance back to the recycler. */
+        void recycle() {
+            removedEntries = 0;
+            removedSize = 0;
+            recyclerHandle.recycle(this);
+        }
+
+        /**
+         * Records the removal of one entry.
+         *
+         * @param size the size of the removed entry, added to the running total
+         */
+        public void entryRemoved(long size) {
+            removedSize += size;
+            removedEntries++;
+        }
+    }
+
     /**
      * Construct a new RangeLruCache with default Weighter.
      */
@@ -68,18 +171,23 @@ public class RangeCache<Key extends Comparable<Key>, Value 
extends ReferenceCoun
      * Insert.
      *
      * @param key
-     * @param value
-     *            ref counted value with at least 1 ref to pass on the cache
+     * @param value ref counted value with at least 1 ref to pass on the cache
      * @return whether the entry was inserted in the cache
      */
     public boolean put(Key key, Value value) {
         // retain value so that it's not released before we put it in the 
cache and calculate the weight
         value.retain();
         try {
-            if (entries.putIfAbsent(key, value) == null) {
+            if (!value.matchesKey(key)) {
+                throw new IllegalArgumentException("Value '" + value + "' does 
not match key '" + key + "'");
+            }
+            IdentityWrapper<Key, Value> newWrapper = 
IdentityWrapper.create(key, value);
+            if (entries.putIfAbsent(key, newWrapper) == null) {
                 size.addAndGet(weighter.getSize(value));
                 return true;
             } else {
+                // recycle the new wrapper as it was not used
+                newWrapper.recycle();
                 return false;
             }
         } finally {
@@ -91,16 +199,37 @@ public class RangeCache<Key extends Comparable<Key>, Value 
extends ReferenceCoun
         return key != null ? entries.containsKey(key) : true;
     }
 
+    /**
+     * Look up the value stored under the key and add a reference to it on behalf of the caller.
+     * The caller owns that reference and must release it.
+     */
     public Value get(Key key) {
-        Value value = entries.get(key);
-        if (value == null) {
+        IdentityWrapper<Key, Value> wrapper = entries.get(key);
+        return getValue(key, wrapper);
+    }
+
+    private  Value getValue(Key key, IdentityWrapper<Key, Value> valueWrapper) 
{
+        if (valueWrapper == null) {
             return null;
         } else {
+            if (valueWrapper.getKey() != key) {
+                // the wrapper has been recycled and contains another key
+                return null;
+            }
+            Value value = valueWrapper.getValue();
             try {
                 value.retain();
+            } catch (IllegalReferenceCountException e) {
+                // Value was already deallocated
+                return null;
+            }
+            // check that the value matches the key and that there's at least 
2 references to it since
+            // the cache should be holding one reference and a new reference 
was just added in this method
+            if (value.refCnt() > 1 && value.matchesKey(key)) {
                 return value;
-            } catch (Throwable t) {
-                // Value was already destroyed between get() and retain()
+            } else {
+                // Value or IdentityWrapper was recycled and already contains 
another value
+                // release the reference added in this method
+                value.release();
                 return null;
             }
         }
@@ -118,12 +247,10 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ReferenceCoun
         List<Value> values = new ArrayList();
 
         // Return the values of the entries found in cache
-        for (Value value : entries.subMap(first, true, last, true).values()) {
-            try {
-                value.retain();
+        for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : 
entries.subMap(first, true, last, true).entrySet()) {
+            Value value = getValue(entry.getKey(), entry.getValue());
+            if (value != null) {
                 values.add(value);
-            } catch (Throwable t) {
-                // Value was already destroyed between get() and retain()
             }
         }
 
@@ -138,25 +265,65 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ReferenceCoun
      * @return an pair of ints, containing the number of removed entries and 
the total size
      */
     public Pair<Integer, Long> removeRange(Key first, Key last, boolean 
lastInclusive) {
-        Map<Key, Value> subMap = entries.subMap(first, true, last, 
lastInclusive);
+        // the counters collect what removeEntry removes and are turned into the result pair at the end
+        RemovalCounters removalCounters = RemovalCounters.create();
+        entries.subMap(first, true, last, lastInclusive).entrySet()
+                .forEach(mapEntry -> removeEntry(mapEntry, removalCounters));
+        return handleRemovalResult(removalCounters);
+    }
 
-        int removedEntries = 0;
-        long removedSize = 0;
+    enum RemoveEntryResult { // outcome reported by removeEntry to the loops that call it
+        ENTRY_REMOVED, // removeEntry got past its skip/stop checks and ran the removal step
+        CONTINUE_LOOP, // the entry was skipped: stale wrapper or value already released
+        BREAK_LOOP; // the removal condition rejected the value
+    }
 
-        for (Key key : subMap.keySet()) {
-            Value value = entries.remove(key);
-            if (value == null) {
-                continue;
-            }
+    /** Attempts to remove the entry without any extra removal condition. */
+    private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters) {
+        final Predicate<Value> noCondition = ignored -> true;
+        return removeEntry(entry, counters, noCondition);
+    }
 
-            removedSize += weighter.getSize(value);
+    /**
+     * Attempts to remove a single map entry and release the reference the cache holds on its value.
+     *
+     * @param entry the map entry to remove
+     * @param counters receives the size of the entry when it is removed
+     * @param removeCondition evaluated against the retained value; a {@code false} result stops the removal
+     * @return {@code CONTINUE_LOOP} when the wrapper is stale, has been recycled or its value was already
+     *         released, {@code BREAK_LOOP} when {@code removeCondition} rejects the value, otherwise
+     *         {@code ENTRY_REMOVED}
+     */
+    private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, Value>> entry, RemovalCounters counters,
+                                          Predicate<Value> removeCondition) {
+        Key key = entry.getKey();
+        IdentityWrapper<Key, Value> identityWrapper = entry.getValue();
+        if (identityWrapper.getKey() != key) {
+            // the wrapper has been recycled and contains another key
+            return RemoveEntryResult.CONTINUE_LOOP;
+        }
+        Value value = identityWrapper.getValue();
+        if (value == null) {
+            // the wrapper was recycled by a concurrent removal after the key check above;
+            // without this guard the retain() below would throw a NullPointerException
+            return RemoveEntryResult.CONTINUE_LOOP;
+        }
+        try {
+            // add extra retain to avoid value being released while we are removing it
+            value.retain();
+        } catch (IllegalReferenceCountException e) {
+            // Value was already released
+            return RemoveEntryResult.CONTINUE_LOOP;
+        }
+        try {
+            if (!removeCondition.test(value)) {
+                return RemoveEntryResult.BREAK_LOOP;
+            }
+            // check that the value hasn't been recycled in between
+            // there should be at least 2 references since this method adds one and the cache should have one
+            // it is valid that the value contains references even after the key has been removed from the cache
+            if (value.refCnt() > 1 && value.matchesKey(key) && entries.remove(key, identityWrapper)) {
+                identityWrapper.recycle();
+                counters.entryRemoved(weighter.getSize(value));
+                // remove the cache reference
+                value.release();
+            }
+        } finally {
+            // remove the extra retain
             value.release();
-            ++removedEntries;
         }
+        return RemoveEntryResult.ENTRY_REMOVED;
+    }
 
-        size.addAndGet(-removedSize);
-
-        return Pair.of(removedEntries, removedSize);
+    /**
+     * Applies the collected removals to the cache size, recycles the counters and reports the totals.
+     *
+     * @param counters the counters filled in while removing entries; must not be used after this call
+     * @return a pair of the number of removed entries and their total size
+     */
+    private Pair<Integer, Long> handleRemovalResult(RemovalCounters counters) {
+        // copy the totals out before the counters object goes back to the recycler
+        final int entryCount = counters.removedEntries;
+        final long totalSize = counters.removedSize;
+        counters.recycle();
+        size.addAndGet(-totalSize);
+        return Pair.of(entryCount, totalSize);
     }
 
     /**
@@ -166,24 +333,15 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ReferenceCoun
      */
     public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
         checkArgument(minSize > 0);
-
-        long removedSize = 0;
-        int removedEntries = 0;
-
-        while (removedSize < minSize) {
-            Map.Entry<Key, Value> entry = entries.pollFirstEntry();
+        RemovalCounters counters = RemovalCounters.create(); // recycled by handleRemovalResult
+        while (counters.removedSize < minSize) { // stop once at least minSize has been removed
+            Map.Entry<Key, IdentityWrapper<Key, Value>> entry = 
entries.firstEntry(); // read only; removeEntry performs the map removal
             if (entry == null) {
                 break;
             }
-
-            Value value = entry.getValue();
-            ++removedEntries;
-            removedSize += weighter.getSize(value);
-            value.release();
+            removeEntry(entry, counters); // NOTE(review): result ignored; an entry that is never removed would be re-read forever - confirm
         }
-
-        size.addAndGet(-removedSize);
-        return Pair.of(removedEntries, removedSize);
+        return handleRemovalResult(counters); // subtracts the removed size from the cache size and builds the result pair
     }
 
     /**
@@ -192,27 +350,18 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ReferenceCoun
     * @return the tota
     */
    public Pair<Integer, Long> evictLEntriesBeforeTimestamp(long maxTimestamp) {
-       long removedSize = 0;
-       int removedCount = 0;
-
+       RemovalCounters counters = RemovalCounters.create(); // recycled by handleRemovalResult
        while (true) {
-           Map.Entry<Key, Value> entry = entries.firstEntry();
-           if (entry == null || 
timestampExtractor.getTimestamp(entry.getValue()) > maxTimestamp) {
+           Map.Entry<Key, IdentityWrapper<Key, Value>> entry = 
entries.firstEntry(); // read only; removeEntry performs the map removal
+           if (entry == null) {
                break;
            }
-           Value value = entry.getValue();
-           boolean removeHits = entries.remove(entry.getKey(), value);
-           if (!removeHits) {
+           if (removeEntry(entry, counters, value -> 
timestampExtractor.getTimestamp(value) <= maxTimestamp)
+                   == RemoveEntryResult.BREAK_LOOP) { // first entry is newer than maxTimestamp: stop evicting
                break;
            }
-
-           removedSize += weighter.getSize(value);
-           removedCount++;
-           value.release();
        }
-
-       size.addAndGet(-removedSize);
-       return Pair.of(removedCount, removedSize);
+       return handleRemovalResult(counters); // subtracts the removed size from the cache size and builds the result pair
    }
 
     /**
@@ -231,23 +380,16 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ReferenceCoun
      *
      * @return size of removed entries
      */
-    public synchronized Pair<Integer, Long> clear() {
-        long removedSize = 0;
-        int removedCount = 0;
-
+    public Pair<Integer, Long> clear() { // no longer synchronized; each removal goes through removeEntry
+        RemovalCounters counters = RemovalCounters.create(); // recycled by handleRemovalResult
         while (true) {
-            Map.Entry<Key, Value> entry = entries.pollFirstEntry();
+            Map.Entry<Key, IdentityWrapper<Key, Value>> entry = 
entries.firstEntry(); // read only; removeEntry performs the map removal
             if (entry == null) {
                 break;
             }
-            Value value = entry.getValue();
-            removedSize += weighter.getSize(value);
-            removedCount++;
-            value.release();
+            removeEntry(entry, counters); // NOTE(review): result ignored; an entry that is never removed would be re-read forever - confirm
         }
-
-        size.getAndAdd(-removedSize);
-        return Pair.of(removedCount, removedSize);
+        return handleRemovalResult(counters); // subtracts the removed size from the cache size and builds the result pair
     }
 
     /**
diff --git a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
index 8ce0db4ac4c..01b3c67bf11 100644
--- a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
+++ b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
@@ -23,25 +23,30 @@ import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertNull;
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
-
 import com.google.common.collect.Lists;
 import io.netty.util.AbstractReferenceCounted;
 import io.netty.util.ReferenceCounted;
-import org.apache.commons.lang3.tuple.Pair;
-import org.testng.annotations.Test;
-import java.util.UUID;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
+import lombok.Cleanup;
+import org.apache.commons.lang3.tuple.Pair;
+import org.testng.annotations.Test;
 
 public class RangeCacheTest {
 
-    class RefString extends AbstractReferenceCounted implements 
ReferenceCounted {
+    class RefString extends AbstractReferenceCounted implements 
RangeCache.ValueWithKeyValidation<Integer> {
         String s;
+        Integer matchingKey;
 
         RefString(String s) {
+            this(s, null);
+        }
+
+        RefString(String s, Integer matchingKey) {
             super();
             this.s = s;
+            this.matchingKey = matchingKey != null ? matchingKey : 
Integer.parseInt(s);
             setRefCnt(1);
         }
 
@@ -65,6 +70,11 @@ public class RangeCacheTest {
 
             return false;
         }
+
+        @Override
+        public boolean matchesKey(Integer key) {
+            return matchingKey.equals(key);
+        }
     }
 
     @Test
@@ -119,8 +129,8 @@ public class RangeCacheTest {
     public void customWeighter() {
         RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
 
-        cache.put(0, new RefString("zero"));
-        cache.put(1, new RefString("one"));
+        cache.put(0, new RefString("zero", 0));
+        cache.put(1, new RefString("one", 1));
 
         assertEquals(cache.getSize(), 7);
         assertEquals(cache.getNumberOfEntries(), 2);
@@ -132,9 +142,9 @@ public class RangeCacheTest {
         RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> x.s.length());
 
         cache.put(1, new RefString("1"));
-        cache.put(2, new RefString("22"));
-        cache.put(3, new RefString("333"));
-        cache.put(4, new RefString("4444"));
+        cache.put(22, new RefString("22"));
+        cache.put(333, new RefString("333"));
+        cache.put(4444, new RefString("4444"));
 
         assertEquals(cache.getSize(), 10);
         assertEquals(cache.getNumberOfEntries(), 4);
@@ -151,12 +161,12 @@ public class RangeCacheTest {
     public void doubleInsert() {
         RangeCache<Integer, RefString> cache = new RangeCache<>();
 
-        RefString s0 = new RefString("zero");
+        RefString s0 = new RefString("zero", 0);
         assertEquals(s0.refCnt(), 1);
         assertTrue(cache.put(0, s0));
         assertEquals(s0.refCnt(), 1);
 
-        cache.put(1, new RefString("one"));
+        cache.put(1, new RefString("one", 1));
 
         assertEquals(cache.getSize(), 2);
         assertEquals(cache.getNumberOfEntries(), 2);
@@ -164,7 +174,7 @@ public class RangeCacheTest {
         assertEquals(s.s, "one");
         assertEquals(s.refCnt(), 2);
 
-        RefString s1 = new RefString("uno");
+        RefString s1 = new RefString("uno", 1);
         assertEquals(s1.refCnt(), 1);
         assertFalse(cache.put(1, s1));
         assertEquals(s1.refCnt(), 1);
@@ -201,10 +211,10 @@ public class RangeCacheTest {
     public void eviction() {
         RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
 
-        cache.put(0, new RefString("zero"));
-        cache.put(1, new RefString("one"));
-        cache.put(2, new RefString("two"));
-        cache.put(3, new RefString("three"));
+        cache.put(0, new RefString("zero", 0));
+        cache.put(1, new RefString("one", 1));
+        cache.put(2, new RefString("two", 2));
+        cache.put(3, new RefString("three", 3));
 
         // This should remove the LRU entries: 0, 1 whose combined size is 7
         assertEquals(cache.evictLeastAccessedEntries(5), Pair.of(2, (long) 7));
@@ -276,20 +286,23 @@ public class RangeCacheTest {
     }
 
     @Test
-    public void testInParallel() {
-        RangeCache<String, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
-        ScheduledExecutorService executor = 
Executors.newSingleThreadScheduledExecutor();
-        executor.scheduleWithFixedDelay(cache::clear, 10, 10, 
TimeUnit.MILLISECONDS);
-        for (int i = 0; i < 1000; i++) {
-            cache.put(UUID.randomUUID().toString(), new RefString("zero"));
+    public void testPutWhileClearIsCalledConcurrently() { // has no assertions; it only runs put() against concurrent clear()
+        RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
+        int numberOfThreads = 4;
+        @Cleanup("shutdownNow") // stops the scheduled clear() tasks when the method exits
+        ScheduledExecutorService executor = 
Executors.newScheduledThreadPool(numberOfThreads);
+        for (int i = 0; i < numberOfThreads; i++) {
+            executor.scheduleWithFixedDelay(cache::clear, 0, 1, 
TimeUnit.MILLISECONDS); // each task starts at once and reruns 1 ms after it finishes
+        }
+        for (int i = 0; i < 100000; i++) {
+            cache.put(i, new RefString(String.valueOf(i))); // the string form of i is parsed back as the matching key
         }
-        executor.shutdown();
     }
 
     @Test
     public void testPutSameObj() {
         RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
-        RefString s0 = new RefString("zero");
+        RefString s0 = new RefString("zero", 0);
         assertEquals(s0.refCnt(), 1);
         assertTrue(cache.put(0, s0));
         assertFalse(cache.put(0, s0));

Reply via email to