This is an automated email from the ASF dual-hosted git repository.

janniklinde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f2e2eab00 [MINOR] OOC Bugfix Cache Reference Management + Return Right BlockKey on Externally Managed Grouped Callbacks
1f2e2eab00 is described below

commit 1f2e2eab00a835f83a71e07b18f6b2b369ba1b97
Author: Jannik Lindemann <[email protected]>
AuthorDate: Fri Mar 27 10:29:30 2026 +0100

    [MINOR] OOC Bugfix Cache Reference Management + Return Right BlockKey on Externally Managed Grouped Callbacks
    
    Closes #2454.
---
 .../runtime/instructions/ooc/CachingStream.java    | 49 +++++++++++++++++++++-
 .../apache/sysds/runtime/ooc/cache/BlockEntry.java | 10 +++++
 .../sysds/runtime/ooc/cache/OOCCacheScheduler.java |  8 ++++
 .../runtime/ooc/cache/OOCLRUCacheScheduler.java    | 39 +++++++++++++++--
 4 files changed, 101 insertions(+), 5 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index b3f5e57aaf..b7c4e2aa64 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -130,6 +130,7 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
                                                        boolean ownsEntry = 
true;
                                                        if(tmp instanceof 
OOCCacheManager.CachedGroupCallback<?> cachedGroup) {
                                                                baseKey = 
cachedGroup.getBlockKey();
+                                                               
ensureReferencedOrRematerialize(baseKey, cachedGroup);
                                                                ownsEntry = 
false;
                                                                if(mSubscribers 
!= null && mSubscribers.length > 0)
                                                                        
mCallback = tmp.keepOpen();
@@ -183,12 +184,14 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
 
                                                        if(tmp instanceof 
OOCCacheManager.CachedQueueCallback<?> cachedQueue) {
                                                                blockKey = 
cachedQueue.getBlockKey();
+                                                               
ensureReferencedOrRematerialize(blockKey, task);
                                                                ownsEntry = 
false;
                                                                if(mSubscribers 
!= null && mSubscribers.length > 0)
                                                                        
mCallback = tmp.keepOpen();
                                                        }
                                                        else if(tmp instanceof 
OOCCacheManager.CachedSubCallback<?> cachedSub) {
                                                                BlockKey parent 
= cachedSub.getParent().getBlockKey();
+                                                               
ensureReferencedOrRematerialize(parent, cachedSub.getParent());
                                                                blockKey = new 
GroupedBlockKey(parent.getStreamId(), (int) parent.getSequenceNumber(),
                                                                        
cachedSub.getGroupIndex());
                                                                ownsEntry = 
false;
@@ -297,6 +300,49 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
                });
        }
 
+
+       private void ensureReferencedOrRematerialize(BlockKey key, 
IndexedMatrixValue value) {
+               try {
+                       OOCCacheManager.getCache().addReference(key);
+               }
+               catch(IllegalArgumentException ex) {
+                       try {
+                               OOCCacheManager.putRaw(key, value, 
((MatrixBlock) value.getValue()).getExactSerializedSize());
+                       }
+                       catch(IllegalStateException putEx) {
+                               // Another downstream stream may have 
re-materialized the same entry first.
+                               OOCCacheManager.getCache().addReference(key);
+                       }
+               }
+       }
+
+       private void ensureReferencedOrRematerialize(BlockKey key, 
OOCCacheManager.CachedGroupCallback<?> group) {
+               try {
+                       OOCCacheManager.getCache().addReference(key);
+               }
+               catch(IllegalArgumentException ex) {
+                       try {
+                               List<IndexedMatrixValue> values = new 
ArrayList<>(group.size());
+                               long totalSize = 0;
+                               for(int gi = 0; gi < group.size(); gi++) {
+                                       @SuppressWarnings("unchecked")
+                                       
OOCStream.QueueCallback<IndexedMatrixValue> sub =
+                                               
(OOCStream.QueueCallback<IndexedMatrixValue>) group.getCallback(gi);
+                                       try(sub) {
+                                               IndexedMatrixValue imv = 
sub.get();
+                                               values.add(imv);
+                                               totalSize += ((MatrixBlock) 
imv.getValue()).getExactSerializedSize();
+                                       }
+                               }
+                               OOCCacheManager.putRaw(key, values, totalSize);
+                       }
+                       catch(IllegalStateException putEx) {
+                               // Another downstream stream may have 
re-materialized the same entry first.
+                               OOCCacheManager.getCache().addReference(key);
+                       }
+               }
+       }
+
        private String getCtxMsg() {
                StackTraceElement[] st = new Exception().getStackTrace();
                // Skip the first few frames (constructor, 
createWritableStream, etc.)
@@ -687,7 +733,7 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
                        if(groupIdx > 0)
                                continue; // only replay grouped blocks once at 
the base index
 
-                       BlockKey replayKey = (groupSize > 1 && groupIdx == 0) ? 
new BlockKey(_streamId, idx) : getBlockKey(i);
+                       BlockKey replayKey = (groupSize > 1 && groupIdx == 0) ? 
getEntryBlockKey(idx) : getBlockKey(i);
                        
OOCCacheManager.requestBlock(replayKey).whenComplete((cb, r) -> {
                                if(r != null) {
                                        
subscriber.accept(OOCStream.eos(DMLRuntimeException.of(r)));
@@ -697,7 +743,6 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
                                        synchronized(CachingStream.this) {
                                                if(_index != null) {
                                                        if(cb instanceof 
OOCStream.GroupQueueCallback<?> && groupSize > 1) {
-                                                               
@SuppressWarnings("unchecked")
                                                                
OOCStream.GroupQueueCallback<IndexedMatrixValue> group =
                                                                        
(OOCStream.GroupQueueCallback<IndexedMatrixValue>) cb;
                                                                for(int gi = 0; 
gi < groupSize; gi++) {
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java 
b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
index fbc7d64223..c0604d017d 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
@@ -30,6 +30,7 @@ public final class BlockEntry {
        private volatile BlockState _state;
        private Object _data;
        private int _retainHintCount;
+       private int _referenceCount; // The number of references from different 
managing instances (e.g. CachingStream)
 
        BlockEntry(BlockKey key, long size, Object data) {
                this._key = key;
@@ -38,6 +39,7 @@ public final class BlockEntry {
                this._state = BlockState.HOT;
                this._data = data;
                this._retainHintCount = 0;
+               this._referenceCount = 1;
        }
 
        public BlockKey getKey() {
@@ -84,6 +86,14 @@ public final class BlockEntry {
                return _pinCount > 0;
        }
 
+       synchronized int addReference() {
+               return ++_referenceCount;
+       }
+
+       synchronized int forget() {
+               return --_referenceCount;
+       }
+
        synchronized void setState(BlockState state) {
                _state = state;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java 
b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
index 9cc108db5e..dbbd73d53a 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
@@ -103,6 +103,14 @@ public interface OOCCacheScheduler {
        BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size,
                OOCIOHandler.SourceBlockDescriptor descriptor);
 
+       /**
+        * Notifies the cache that there is another reference to the same block 
key.
+        * This will prevent forget(key) from removing the block from cache.
+        * A block will only be forgotten after all referencing instances 
called forget(key).
+        * @param key
+        */
+       void addReference(BlockKey key);
+
        /**
         * Forgets a block from the cache.
         * @param key the associated key of the block
diff --git 
a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java 
b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
index a204cd16db..cc7aa7bcd1 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.ooc.cache;
 
+import org.apache.commons.lang3.mutable.MutableObject;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
@@ -291,6 +292,18 @@ public class OOCLRUCacheScheduler implements 
OOCCacheScheduler {
                return put(key, data, size, true, descriptor);
        }
 
+       @Override
+       public void addReference(BlockKey key) {
+               synchronized(this) {
+                       BlockEntry entry = _cache.get(key);
+                       if(entry == null)
+                               entry = _evictionCache.get(key);
+                       if(entry == null)
+                               throw new IllegalArgumentException("Could not 
find requested block with key " + key);
+                       entry.addReference();
+               }
+       }
+
        private BlockEntry put(BlockKey key, Object data, long size, boolean 
pin, OOCIOHandler.SourceBlockDescriptor descriptor) {
                if (!this._running)
                        throw new IllegalStateException();
@@ -324,14 +337,34 @@ public class OOCLRUCacheScheduler implements 
OOCCacheScheduler {
        public void forget(BlockKey key) {
                if (!this._running)
                        return;
+               final MutableObject<BlockEntry> mEntry = new MutableObject<>();
                BlockEntry entry;
                boolean shouldScheduleDeletion = false;
                long cacheSizeDelta = 0;
                synchronized(this) {
-                       entry = _cache.remove(key);
+                       _cache.compute(key, (k, e) -> {
+                               if(e == null)
+                                       return null;
+                               if(e.forget() == 0) {
+                                       mEntry.setValue(e);
+                                       return null;
+                               }
+                               return e;
+                       });
 
-                       if (entry == null)
-                               entry = _evictionCache.remove(key);
+                       if (mEntry.getValue() == null) {
+                               _evictionCache.compute(key, (k, e) -> {
+                                       if(e == null)
+                                               return null;
+                                       if(e.forget() == 0) {
+                                               mEntry.setValue(e);
+                                               return null;
+                                       }
+                                       return e;
+                               });
+                       }
+
+                       entry = mEntry.getValue();
 
                        if (entry != null) {
                                synchronized(entry) {

Reply via email to