This is an automated email from the ASF dual-hosted git repository.

lhotari pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/pulsar.git

commit 99aed3a74087ad03f5570f04b02e7f007e4635f3
Author: Lari Hotari <lhot...@users.noreply.github.com>
AuthorDate: Fri May 31 23:14:50 2024 +0300

    [improve][ml] RangeCache refactoring: test race conditions and prevent endless loops (#22814)
    
    (cherry picked from commit e731674f61a973e9b12eab9394f82731c8fc2384)
---
 .../apache/bookkeeper/mledger/util/RangeCache.java | 172 +++++++++++++--------
 .../mledger/impl/EntryCacheManagerTest.java        |   2 +-
 .../bookkeeper/mledger/util/RangeCacheTest.java    |  35 ++++-
 3 files changed, 143 insertions(+), 66 deletions(-)

diff --git 
a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
 
b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
index 46d03bea1b5..45295d71906 100644
--- 
a/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
+++ 
b/managed-ledger/src/main/java/org/apache/bookkeeper/mledger/util/RangeCache.java
@@ -28,32 +28,36 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.concurrent.ConcurrentNavigableMap;
 import java.util.concurrent.ConcurrentSkipListMap;
 import java.util.concurrent.atomic.AtomicLong;
+import lombok.extern.slf4j.Slf4j;
 import org.apache.bookkeeper.mledger.util.RangeCache.ValueWithKeyValidation;
 import org.apache.commons.lang3.tuple.Pair;
 
 /**
  * Special type of cache where get() and delete() operations can be done over 
a range of keys.
- * The implementation avoids locks and synchronization and relies on 
ConcurrentSkipListMap for storing the entries.
- * Since there is no locks, there is a need to have a way to ensure that a 
single entry in the cache is removed
- * exactly once. Removing an entry multiple times would result in the entries 
of the cache getting released too
- * while they could still be in use.
+ * The implementation avoids locks and synchronization by relying on 
ConcurrentSkipListMap for storing the entries.
+ * Since there are no locks, it's necessary to ensure that a single entry in 
the cache is removed exactly once.
+ * Removing an entry multiple times could result in the entries of the cache 
being released multiple times,
+ * even while they are still in use. This is prevented by using a custom 
wrapper around the value to store in the map
+ * that ensures that the value is removed from the map only if the exact same 
instance is present in the map.
+ * There's also a check that ensures that the value matches the key. This is 
used to detect races without impacting
+ * consistency.
  *
  * @param <Key>
  *            Cache key. Needs to be Comparable
  * @param <Value>
  *            Cache value
  */
+@Slf4j
 public class RangeCache<Key extends Comparable<Key>, Value extends 
ValueWithKeyValidation<Key>> {
     public interface ValueWithKeyValidation<T> extends ReferenceCounted {
         boolean matchesKey(T key);
     }
 
     // Map from key to nodes inside the linked list
-    private final ConcurrentNavigableMap<Key, IdentityWrapper<Key, Value>> 
entries;
+    private final ConcurrentNavigableMap<Key, EntryWrapper<Key, Value>> 
entries;
     private AtomicLong size; // Total size of values stored in cache
     private final Weighter<Value> weighter; // Weighter object used to extract 
the size from values
     private final TimestampExtractor<Value> timestampExtractor; // Extract the 
timestamp associated with a value
@@ -63,51 +67,53 @@ public class RangeCache<Key extends Comparable<Key>, Value 
extends ValueWithKeyV
      * the map by calling the {@link Map#remove(Object, Object)} method. 
Certain race conditions could result in the
      * wrong value being removed from the map. The instances of this class are 
recycled to avoid creating new objects.
      */
-    private static class IdentityWrapper<K, V> {
-        private final Handle<IdentityWrapper> recyclerHandle;
-        private static final Recycler<IdentityWrapper> RECYCLER = new 
Recycler<IdentityWrapper>() {
+    private static class EntryWrapper<K, V> {
+        private final Handle<EntryWrapper> recyclerHandle;
+        private static final Recycler<EntryWrapper> RECYCLER = new 
Recycler<EntryWrapper>() {
             @Override
-            protected IdentityWrapper newObject(Handle<IdentityWrapper> 
recyclerHandle) {
-                return new IdentityWrapper(recyclerHandle);
+            protected EntryWrapper newObject(Handle<EntryWrapper> 
recyclerHandle) {
+                return new EntryWrapper(recyclerHandle);
             }
         };
         private K key;
         private V value;
+        long size;
 
-        private IdentityWrapper(Handle<IdentityWrapper> recyclerHandle) {
+        private EntryWrapper(Handle<EntryWrapper> recyclerHandle) {
             this.recyclerHandle = recyclerHandle;
         }
 
-        static <K, V> IdentityWrapper<K, V> create(K key, V value) {
-            IdentityWrapper<K, V> identityWrapper = RECYCLER.get();
-            identityWrapper.key = key;
-            identityWrapper.value = value;
-            return identityWrapper;
+        static <K, V> EntryWrapper<K, V> create(K key, V value, long size) {
+            EntryWrapper<K, V> entryWrapper = RECYCLER.get();
+            synchronized (entryWrapper) {
+                entryWrapper.key = key;
+                entryWrapper.value = value;
+                entryWrapper.size = size;
+            }
+            return entryWrapper;
         }
 
-        K getKey() {
+        synchronized K getKey() {
             return key;
         }
 
-        V getValue() {
+        synchronized V getValue(K key) {
+            if (this.key != key) {
+                return null;
+            }
             return value;
         }
 
+        synchronized long getSize() {
+            return size;
+        }
+
         void recycle() {
+            key = null;
             value = null;
+            size = 0;
             recyclerHandle.recycle(this);
         }
-
-        @Override
-        public boolean equals(Object o) {
-            // only match exact identity of the value
-            return this == o;
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hashCode(key);
-        }
     }
 
     /**
@@ -181,9 +187,10 @@ public class RangeCache<Key extends Comparable<Key>, Value 
extends ValueWithKeyV
             if (!value.matchesKey(key)) {
                 throw new IllegalArgumentException("Value '" + value + "' does 
not match key '" + key + "'");
             }
-            IdentityWrapper<Key, Value> newWrapper = 
IdentityWrapper.create(key, value);
+            long entrySize = weighter.getSize(value);
+            EntryWrapper<Key, Value> newWrapper = EntryWrapper.create(key, 
value, entrySize);
             if (entries.putIfAbsent(key, newWrapper) == null) {
-                size.addAndGet(weighter.getSize(value));
+                this.size.addAndGet(entrySize);
                 return true;
             } else {
                 // recycle the new wrapper as it was not used
@@ -207,15 +214,15 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ValueWithKeyV
         return getValue(key, entries.get(key));
     }
 
-    private  Value getValue(Key key, IdentityWrapper<Key, Value> valueWrapper) 
{
+    private  Value getValue(Key key, EntryWrapper<Key, Value> valueWrapper) {
         if (valueWrapper == null) {
             return null;
         } else {
-            if (valueWrapper.getKey() != key) {
+            Value value = valueWrapper.getValue(key);
+            if (value == null) {
                 // the wrapper has been recycled and contains another key
                 return null;
             }
-            Value value = valueWrapper.getValue();
             try {
                 value.retain();
             } catch (IllegalReferenceCountException e) {
@@ -247,7 +254,7 @@ public class RangeCache<Key extends Comparable<Key>, Value 
extends ValueWithKeyV
         List<Value> values = new ArrayList();
 
         // Return the values of the entries found in cache
-        for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : 
entries.subMap(first, true, last, true).entrySet()) {
+        for (Map.Entry<Key, EntryWrapper<Key, Value>> entry : 
entries.subMap(first, true, last, true).entrySet()) {
             Value value = getValue(entry.getKey(), entry.getValue());
             if (value != null) {
                 values.add(value);
@@ -266,9 +273,9 @@ public class RangeCache<Key extends Comparable<Key>, Value 
extends ValueWithKeyV
      */
     public Pair<Integer, Long> removeRange(Key first, Key last, boolean 
lastInclusive) {
         RemovalCounters counters = RemovalCounters.create();
-        Map<Key, IdentityWrapper<Key, Value>> subMap = entries.subMap(first, 
true, last, lastInclusive);
-        for (Map.Entry<Key, IdentityWrapper<Key, Value>> entry : 
subMap.entrySet()) {
-            removeEntry(entry, counters);
+        Map<Key, EntryWrapper<Key, Value>> subMap = entries.subMap(first, 
true, last, lastInclusive);
+        for (Map.Entry<Key, EntryWrapper<Key, Value>> entry : 
subMap.entrySet()) {
+            removeEntry(entry, counters, true);
         }
         return handleRemovalResult(counters);
     }
@@ -279,36 +286,76 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ValueWithKeyV
         BREAK_LOOP;
     }
 
-    private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, 
Value>> entry, RemovalCounters counters) {
-        return removeEntry(entry, counters, (x) -> true);
+    private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, 
Value>> entry, RemovalCounters counters,
+                                          boolean skipInvalid) {
+        return removeEntry(entry, counters, skipInvalid, x -> true);
     }
 
-    private RemoveEntryResult removeEntry(Map.Entry<Key, IdentityWrapper<Key, 
Value>> entry, RemovalCounters counters,
-                                          Predicate<Value> removeCondition) {
+    private RemoveEntryResult removeEntry(Map.Entry<Key, EntryWrapper<Key, 
Value>> entry, RemovalCounters counters,
+                                          boolean skipInvalid, 
Predicate<Value> removeCondition) {
         Key key = entry.getKey();
-        IdentityWrapper<Key, Value> identityWrapper = entry.getValue();
-        if (identityWrapper.getKey() != key) {
-            // the wrapper has been recycled and contains another key
+        EntryWrapper<Key, Value> entryWrapper = entry.getValue();
+        Value value = entryWrapper.getValue(key);
+        if (value == null) {
+            // the wrapper has already been recycled and contains another key
+            if (!skipInvalid) {
+                EntryWrapper<Key, Value> removed = entries.remove(key);
+                if (removed != null) {
+                    // log and remove the entry without releasing the value
+                    log.info("Key {} does not match the entry's value 
wrapper's key {}, removed entry by key without "
+                            + "releasing the value", key, 
entryWrapper.getKey());
+                    counters.entryRemoved(removed.getSize());
+                    return RemoveEntryResult.ENTRY_REMOVED;
+                }
+            }
             return RemoveEntryResult.CONTINUE_LOOP;
         }
-        Value value = identityWrapper.getValue();
         try {
             // add extra retain to avoid value being released while we are 
removing it
             value.retain();
         } catch (IllegalReferenceCountException e) {
             // Value was already released
+            if (!skipInvalid) {
+                // remove the specific entry without releasing the value
+                if (entries.remove(key, entryWrapper)) {
+                    log.info("Value was already released for key {}, removed 
entry without releasing the value", key);
+                    counters.entryRemoved(entryWrapper.getSize());
+                    return RemoveEntryResult.ENTRY_REMOVED;
+                }
+            }
             return RemoveEntryResult.CONTINUE_LOOP;
         }
+        if (!value.matchesKey(key)) {
+            // this is unexpected since the EntryWrapper.getValue(key) 
already checked that the value matches the key
+            log.warn("Unexpected race condition. Value {} does not match the 
key {}. Removing entry.", value, key);
+        }
         try {
             if (!removeCondition.test(value)) {
                 return RemoveEntryResult.BREAK_LOOP;
             }
-            // check that the value hasn't been recycled in between
-            // there should be at least 2 references since this method adds 
one and the cache should have one
-            // it is valid that the value contains references even after the 
key has been removed from the cache
-            if (value.refCnt() > 1 && value.matchesKey(key) && 
entries.remove(key, identityWrapper)) {
-                identityWrapper.recycle();
-                counters.entryRemoved(weighter.getSize(value));
+            if (!skipInvalid) {
+                // remove the specific entry
+                boolean entryRemoved = entries.remove(key, entryWrapper);
+                if (entryRemoved) {
+                    counters.entryRemoved(entryWrapper.getSize());
+                    // check that the value hasn't been recycled in between
+                    // there should be at least 2 references since this method 
adds one and the cache should have
+                    // one reference. it is valid that the value contains 
references even after the key has been
+                    // removed from the cache
+                    if (value.refCnt() > 1) {
+                        entryWrapper.recycle();
+                        // remove the cache reference
+                        value.release();
+                    } else {
+                        log.info("Unexpected refCnt {} for key {}, removed 
entry without releasing the value",
+                                value.refCnt(), key);
+                    }
+                }
+            } else if (skipInvalid && value.refCnt() > 1 && 
entries.remove(key, entryWrapper)) {
+                // when skipInvalid is true, we don't remove the entry if it 
doesn't match the key
+                // or the refCnt is invalid
+                counters.entryRemoved(entryWrapper.getSize());
+                entryWrapper.recycle();
                 // remove the cache reference
                 value.release();
             }
@@ -334,12 +381,12 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ValueWithKeyV
     public Pair<Integer, Long> evictLeastAccessedEntries(long minSize) {
         checkArgument(minSize > 0);
         RemovalCounters counters = RemovalCounters.create();
-        while (counters.removedSize < minSize) {
-            Map.Entry<Key, IdentityWrapper<Key, Value>> entry = 
entries.firstEntry();
+        while (counters.removedSize < minSize && 
!Thread.currentThread().isInterrupted()) {
+            Map.Entry<Key, EntryWrapper<Key, Value>> entry = 
entries.firstEntry();
             if (entry == null) {
                 break;
             }
-            removeEntry(entry, counters);
+            removeEntry(entry, counters, false);
         }
         return handleRemovalResult(counters);
     }
@@ -351,12 +398,12 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ValueWithKeyV
     */
    public Pair<Integer, Long> evictLEntriesBeforeTimestamp(long maxTimestamp) {
        RemovalCounters counters = RemovalCounters.create();
-       while (true) {
-           Map.Entry<Key, IdentityWrapper<Key, Value>> entry = 
entries.firstEntry();
+       while (!Thread.currentThread().isInterrupted()) {
+           Map.Entry<Key, EntryWrapper<Key, Value>> entry = 
entries.firstEntry();
            if (entry == null) {
                break;
            }
-           if (removeEntry(entry, counters, value -> 
timestampExtractor.getTimestamp(value) <= maxTimestamp)
+           if (removeEntry(entry, counters, false, value -> 
timestampExtractor.getTimestamp(value) <= maxTimestamp)
                    == RemoveEntryResult.BREAK_LOOP) {
                break;
            }
@@ -382,12 +429,12 @@ public class RangeCache<Key extends Comparable<Key>, 
Value extends ValueWithKeyV
      */
     public Pair<Integer, Long> clear() {
         RemovalCounters counters = RemovalCounters.create();
-        while (true) {
-            Map.Entry<Key, IdentityWrapper<Key, Value>> entry = 
entries.firstEntry();
+        while (!Thread.currentThread().isInterrupted()) {
+            Map.Entry<Key, EntryWrapper<Key, Value>> entry = 
entries.firstEntry();
             if (entry == null) {
                 break;
             }
-            removeEntry(entry, counters);
+            removeEntry(entry, counters, false);
         }
         return handleRemovalResult(counters);
     }
@@ -421,5 +468,4 @@ public class RangeCache<Key extends Comparable<Key>, Value 
extends ValueWithKeyV
             return 1;
         }
     }
-
 }
diff --git 
a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/impl/EntryCacheManagerTest.java
 
b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/impl/EntryCacheManagerTest.java
index 1b02cd674c5..1ab3198498a 100644
--- 
a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/impl/EntryCacheManagerTest.java
+++ 
b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/impl/EntryCacheManagerTest.java
@@ -193,9 +193,9 @@ public class EntryCacheManagerTest extends 
MockedBookKeeperTestCase {
         }
 
         cacheManager.removeEntryCache(ml1.getName());
-        assertTrue(cacheManager.getSize() > 0);
         assertEquals(factory2.getMbean().getCacheInsertedEntriesCount(), 20);
         assertEquals(factory2.getMbean().getCacheEntriesCount(), 0);
+        assertEquals(0, cacheManager.getSize());
         assertEquals(factory2.getMbean().getCacheEvictedEntriesCount(), 20);
     }
 
diff --git 
a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
 
b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
index 01b3c67bf11..4bcf2cc6c4e 100644
--- 
a/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
+++ 
b/managed-ledger/src/test/java/org/apache/bookkeeper/mledger/util/RangeCacheTest.java
@@ -30,11 +30,14 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import lombok.Cleanup;
+import lombok.Data;
 import org.apache.commons.lang3.tuple.Pair;
+import org.awaitility.Awaitility;
 import org.testng.annotations.Test;
 
 public class RangeCacheTest {
 
+    @Data
     class RefString extends AbstractReferenceCounted implements 
RangeCache.ValueWithKeyValidation<Integer> {
         String s;
         Integer matchingKey;
@@ -288,15 +291,21 @@ public class RangeCacheTest {
     @Test
     public void testPutWhileClearIsCalledConcurrently() {
         RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
-        int numberOfThreads = 4;
+        int numberOfThreads = 8;
         @Cleanup("shutdownNow")
         ScheduledExecutorService executor = 
Executors.newScheduledThreadPool(numberOfThreads);
         for (int i = 0; i < numberOfThreads; i++) {
             executor.scheduleWithFixedDelay(cache::clear, 0, 1, 
TimeUnit.MILLISECONDS);
         }
-        for (int i = 0; i < 100000; i++) {
+        for (int i = 0; i < 200000; i++) {
             cache.put(i, new RefString(String.valueOf(i)));
         }
+        executor.shutdown();
+        // ensure that no clear operation got into endless loop
+        Awaitility.await().untilAsserted(() -> 
assertTrue(executor.isTerminated()));
+        // ensure that clear can be called and all entries are removed
+        cache.clear();
+        assertEquals(cache.getNumberOfEntries(), 0);
     }
 
     @Test
@@ -307,4 +316,26 @@ public class RangeCacheTest {
         assertTrue(cache.put(0, s0));
         assertFalse(cache.put(0, s0));
     }
+
+    @Test
+    public void testRemoveEntryWithInvalidRefCount() {
+        RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
+        RefString value = new RefString("1");
+        cache.put(1, value);
+        // release the value to make the reference count invalid
+        value.release();
+        cache.clear();
+        assertEquals(cache.getNumberOfEntries(), 0);
+    }
+
+    @Test
+    public void testRemoveEntryWithInvalidMatchingKey() {
+        RangeCache<Integer, RefString> cache = new RangeCache<>(value -> 
value.s.length(), x -> 0);
+        RefString value = new RefString("1");
+        cache.put(1, value);
+        // change the matching key to make it invalid
+        value.setMatchingKey(123);
+        cache.clear();
+        assertEquals(cache.getNumberOfEntries(), 0);
+    }
 }

Reply via email to