This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 02603876c5 [SYSTEMDS-3891] OOC Source Streams and Perf Improvements
02603876c5 is described below
commit 02603876c549f0bd5bed1c8c9fe9b96064ffb83e
Author: Jannik Lindemann <[email protected]>
AuthorDate: Mon Jan 5 12:35:27 2026 +0100
[SYSTEMDS-3891] OOC Source Streams and Perf Improvements
Closes #2393.
---
.../runtime/instructions/ooc/CachingStream.java | 22 +-
.../instructions/ooc/ReblockOOCInstruction.java | 56 ++---
.../sysds/runtime/ooc/cache/OOCCacheManager.java | 26 ++
.../sysds/runtime/ooc/cache/OOCCacheScheduler.java | 22 ++
.../sysds/runtime/ooc/cache/OOCIOHandler.java | 82 ++++++
.../runtime/ooc/cache/OOCLRUCacheScheduler.java | 28 ++-
.../runtime/ooc/cache/OOCMatrixIOHandler.java | 276 ++++++++++++++++++++-
.../sysds/runtime/ooc/stream/OOCSourceStream.java | 52 ++++
.../ooc/cache/SourceBackedCacheSchedulerTest.java | 106 ++++++++
.../cache/SourceBackedReadOOCIOHandlerTest.java | 100 ++++++++
.../functions/ooc/SourceReadOOCIOHandlerTest.java | 143 +++++++++++
11 files changed, 855 insertions(+), 58 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index 9cda04e0c7..f9869b20f9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -25,7 +25,9 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.ooc.cache.BlockKey;
+import org.apache.sysds.runtime.ooc.cache.OOCIOHandler;
import org.apache.sysds.runtime.ooc.cache.OOCCacheManager;
+import org.apache.sysds.runtime.ooc.stream.OOCSourceStream;
import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList;
import java.util.HashMap;
@@ -89,10 +91,22 @@ public class CachingStream implements OOCStreamable<IndexedMatrixValue> {
if(task !=
LocalTaskQueue.NO_MORE_TASKS) {
if (!_cacheInProgress)
throw new
DMLRuntimeException("Stream is closed");
- if (mSubscribers == null ||
mSubscribers.length == 0)
-
OOCCacheManager.put(_streamId, _numBlocks, task);
- else
- mCallback =
OOCCacheManager.putAndPin(_streamId, _numBlocks, task);
+
OOCIOHandler.SourceBlockDescriptor descriptor = null;
+ if (_source instanceof
OOCSourceStream src) {
+ descriptor =
src.getDescriptor(task.getIndexes());
+ }
+ if (descriptor == null) {
+ if (mSubscribers ==
null || mSubscribers.length == 0)
+
OOCCacheManager.put(_streamId, _numBlocks, task);
+ else
+ mCallback =
OOCCacheManager.putAndPin(_streamId, _numBlocks, task);
+ }
+ else {
+ if (mSubscribers ==
null || mSubscribers.length == 0)
+
OOCCacheManager.putSourceBacked(_streamId, _numBlocks, task, descriptor);
+ else
+ mCallback =
OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, descriptor);
+ }
if (_index != null)
_index.put(task.getIndexes(), _numBlocks);
blk = _numBlocks;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
index 74b15c9fb0..f744b97506 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
@@ -19,24 +19,19 @@
package org.apache.sysds.runtime.instructions.ooc;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.mapred.JobConf;
import org.apache.sysds.common.Opcodes;
-import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
-import org.apache.sysds.runtime.io.IOUtilFunctions;
-import org.apache.sysds.runtime.io.MatrixReader;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.ooc.cache.OOCCacheManager;
+import org.apache.sysds.runtime.ooc.cache.OOCIOHandler;
+import org.apache.sysds.runtime.ooc.stream.OOCSourceStream;
public class ReblockOOCInstruction extends ComputationOOCInstruction {
private int blen;
@@ -74,40 +69,19 @@ public class ReblockOOCInstruction extends ComputationOOCInstruction {
//TODO support other formats than binary
//create queue, spawn thread for asynchronous reading, and
return
- OOCStream<IndexedMatrixValue> q = createWritableStream();
- submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q);
+ OOCStream<IndexedMatrixValue> q = new OOCSourceStream();
+ OOCIOHandler io = OOCCacheManager.getIOHandler();
+ OOCIOHandler.SourceReadRequest req = new
OOCIOHandler.SourceReadRequest(
+ min.getFileName(), Types.FileFormat.BINARY,
mc.getRows(), mc.getCols(), blen, mc.getNonZeros(),
+ Long.MAX_VALUE, true, q);
+ io.scheduleSourceRead(req).whenComplete((res, err) -> {
+ if (err != null) {
+ Exception ex = err instanceof Exception ?
(Exception) err : new Exception(err);
+ q.propagateFailure(new DMLRuntimeException(ex));
+ }
+ });
MatrixObject mout = ec.getMatrixObject(output);
mout.setStreamHandle(q);
}
-
- @SuppressWarnings("resource")
- private void readBinaryBlock(OOCStream<IndexedMatrixValue> q, String
fname) {
- try {
- //prepare file access
- JobConf job = new
JobConf(ConfigurationManager.getCachedJobConf());
- Path path = new Path( fname );
- FileSystem fs = IOUtilFunctions.getFileSystem(path,
job);
-
- //check existence and non-empty file
- MatrixReader.checkValidInputFile(fs, path);
-
- //core reading
- for( Path lpath :
IOUtilFunctions.getSequenceFilePaths(fs, path) ) { //1..N files
- //directly read from sequence files (individual
partfiles)
- try( SequenceFile.Reader reader = new
SequenceFile
- .Reader(job,
SequenceFile.Reader.file(lpath)) )
- {
- MatrixIndexes key = new MatrixIndexes();
- MatrixBlock value = new MatrixBlock();
- while( reader.next(key, value) )
- q.enqueue(new
IndexedMatrixValue(key, new MatrixBlock(value)));
- }
- }
- q.closeInput();
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
- }
- }
}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java
index 50b5cf7821..bbf4cfb314 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java
@@ -100,6 +100,15 @@ public class OOCCacheManager {
}
}
+ public static OOCIOHandler getIOHandler() {
+ OOCIOHandler io = _ioHandler.get();
+ if(io != null)
+ return io;
+ // Ensure initialization happens
+ getCache();
+ return _ioHandler.get();
+ }
+
/**
* Removes a block from the cache without setting its data to null.
*/
@@ -116,11 +125,28 @@ public class OOCCacheManager {
getCache().put(key, value,
((MatrixBlock)value.getValue()).getExactSerializedSize());
}
+ /**
+ * Store a source-backed block in the OOC cache and register its source
location.
+ */
+ public static void putSourceBacked(long streamId, int blockId,
IndexedMatrixValue value,
+ OOCIOHandler.SourceBlockDescriptor descriptor) {
+ BlockKey key = new BlockKey(streamId, blockId);
+ getCache().putSourceBacked(key, value, ((MatrixBlock)
value.getValue()).getExactSerializedSize(), descriptor);
+ }
+
public static OOCStream.QueueCallback<IndexedMatrixValue>
putAndPin(long streamId, int blockId, IndexedMatrixValue value) {
BlockKey key = new BlockKey(streamId, blockId);
return new CachedQueueCallback<>(getCache().putAndPin(key,
value, ((MatrixBlock)value.getValue()).getExactSerializedSize()), null);
}
+ public static OOCStream.QueueCallback<IndexedMatrixValue>
putAndPinSourceBacked(long streamId, int blockId,
+ IndexedMatrixValue value, OOCIOHandler.SourceBlockDescriptor
descriptor) {
+ BlockKey key = new BlockKey(streamId, blockId);
+ return new CachedQueueCallback<>(
+ getCache().putAndPinSourceBacked(key, value,
((MatrixBlock) value.getValue()).getExactSerializedSize(),
+ descriptor), null);
+ }
+
public static
CompletableFuture<OOCStream.QueueCallback<IndexedMatrixValue>>
requestBlock(long streamId, long blockId) {
BlockKey key = new BlockKey(streamId, blockId);
return getCache().request(key).thenApply(e -> new
CachedQueueCallback<>(e, null));
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
index 5346b819cf..cd04f9879a 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
@@ -56,6 +56,28 @@ public interface OOCCacheScheduler {
*/
BlockEntry putAndPin(BlockKey key, Object data, long size);
+ /**
+ * Places a new source-backed block in the cache and registers the
location with the IO handler. The entry is
+ * treated as backed by disk, so eviction does not schedule spill
writes.
+ *
+ * @param key the associated key of the block
+ * @param data the block data
+ * @param size the size of the data
+ * @param descriptor the source location descriptor
+ */
+ void putSourceBacked(BlockKey key, Object data, long size,
OOCIOHandler.SourceBlockDescriptor descriptor);
+
+ /**
+ * Places a new source-backed block in the cache and returns a pinned
handle.
+ *
+ * @param key the associated key of the block
+ * @param data the block data
+ * @param size the size of the data
+ * @param descriptor the source location descriptor
+ */
+ BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size,
+ OOCIOHandler.SourceBlockDescriptor descriptor);
+
/**
* Forgets a block from the cache.
* @param key the associated key of the block
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java
index dbfda4e56d..b4d14646e0 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.ooc.cache;
import java.util.concurrent.CompletableFuture;
+import java.util.List;
public interface OOCIOHandler {
void shutdown();
@@ -29,4 +30,85 @@ public interface OOCIOHandler {
CompletableFuture<BlockEntry> scheduleRead(BlockEntry block);
CompletableFuture<Boolean> scheduleDeletion(BlockEntry block);
+
+ /**
+ * Registers the source location of a block for future direct reads.
+ */
+ void registerSourceLocation(BlockKey key, SourceBlockDescriptor
descriptor);
+
+ /**
+ * Schedule an asynchronous read from an external source into the
provided target stream.
+ * The returned future completes when either EOF is reached or the
requested byte budget
+ * is exhausted. When the budget is reached and keepOpenOnLimit is
true, the target stream
+ * is kept open and a continuation token is provided so the caller can
resume.
+ */
+ CompletableFuture<SourceReadResult>
scheduleSourceRead(SourceReadRequest request);
+
+ /**
+ * Continue a previously throttled source read using the provided
continuation token.
+ */
+ CompletableFuture<SourceReadResult>
continueSourceRead(SourceReadContinuation continuation, long maxBytesInFlight);
+
+ interface SourceReadContinuation {}
+
+ class SourceReadRequest {
+ public final String path;
+ public final org.apache.sysds.common.Types.FileFormat format;
+ public final long rows;
+ public final long cols;
+ public final int blen;
+ public final long estNnz;
+ public final long maxBytesInFlight;
+ public final boolean keepOpenOnLimit;
+ public final
org.apache.sysds.runtime.instructions.ooc.OOCStream<org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue>
target;
+
+ public SourceReadRequest(String path,
org.apache.sysds.common.Types.FileFormat format, long rows, long cols,
+ int blen, long estNnz, long maxBytesInFlight, boolean
keepOpenOnLimit,
+
org.apache.sysds.runtime.instructions.ooc.OOCStream<org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue>
target) {
+ this.path = path;
+ this.format = format;
+ this.rows = rows;
+ this.cols = cols;
+ this.blen = blen;
+ this.estNnz = estNnz;
+ this.maxBytesInFlight = maxBytesInFlight;
+ this.keepOpenOnLimit = keepOpenOnLimit;
+ this.target = target;
+ }
+ }
+
+ class SourceReadResult {
+ public final long bytesRead;
+ public final boolean eof;
+ public final SourceReadContinuation continuation;
+ public final List<SourceBlockDescriptor> blocks;
+
+ public SourceReadResult(long bytesRead, boolean eof,
SourceReadContinuation continuation,
+ List<SourceBlockDescriptor> blocks) {
+ this.bytesRead = bytesRead;
+ this.eof = eof;
+ this.continuation = continuation;
+ this.blocks = blocks;
+ }
+ }
+
+ class SourceBlockDescriptor {
+ public final String path;
+ public final org.apache.sysds.common.Types.FileFormat format;
+ public final org.apache.sysds.runtime.matrix.data.MatrixIndexes
indexes;
+ public final long offset;
+ public final int recordLength;
+ public final long serializedSize;
+
+ public SourceBlockDescriptor(String path,
org.apache.sysds.common.Types.FileFormat format,
+ org.apache.sysds.runtime.matrix.data.MatrixIndexes
indexes, long offset, int recordLength,
+ long serializedSize) {
+ this.path = path;
+ this.format = format;
+ this.indexes = indexes;
+ this.offset = offset;
+ this.recordLength = recordLength;
+ this.serializedSize = serializedSize;
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
index 1dbba2e3d8..0f30914770 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
@@ -169,22 +169,36 @@ public class OOCLRUCacheScheduler implements OOCCacheScheduler {
@Override
public void put(BlockKey key, Object data, long size) {
- put(key, data, size, false);
+ put(key, data, size, false, null);
}
@Override
public BlockEntry putAndPin(BlockKey key, Object data, long size) {
- return put(key, data, size, true);
+ return put(key, data, size, true, null);
}
- private BlockEntry put(BlockKey key, Object data, long size, boolean
pin) {
+ @Override
+ public void putSourceBacked(BlockKey key, Object data, long size,
OOCIOHandler.SourceBlockDescriptor descriptor) {
+ put(key, data, size, false, descriptor);
+ }
+
+ @Override
+ public BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long
size, OOCIOHandler.SourceBlockDescriptor descriptor) {
+ return put(key, data, size, true, descriptor);
+ }
+
+ private BlockEntry put(BlockKey key, Object data, long size, boolean
pin, OOCIOHandler.SourceBlockDescriptor descriptor) {
if (!this._running)
throw new IllegalStateException();
if (data == null)
throw new IllegalArgumentException();
+ if (descriptor != null)
+ _ioHandler.registerSourceLocation(key, descriptor);
Statistics.incrementOOCEvictionPut();
BlockEntry entry = new BlockEntry(key, size, data);
+ if (descriptor != null)
+ entry.setState(BlockState.WARM);
if (pin)
entry.pin();
synchronized(this) {
@@ -301,15 +315,15 @@ public class OOCLRUCacheScheduler implements OOCCacheScheduler {
}
private synchronized void sanityCheck() {
- if (_cacheSize > _hardLimit) {
+ if (_cacheSize > _hardLimit * 1.1) {
if (!_warnThrottling) {
_warnThrottling = true;
- System.out.println("[INFO] Throttling: " +
_cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB > " +
_hardLimit/1000 + "KB");
+ System.out.println("[WARN] Cache hard limit
exceeded by over 10%: " + String.format("%.2f", _cacheSize/1000000.0) + "MB (-"
+ String.format("%.2f", _bytesUpForEviction/1000000.0) + "MB) > " +
String.format("%.2f", _hardLimit/1000000.0) + "MB");
}
}
- else if (_warnThrottling) {
+ else if (_warnThrottling && _cacheSize < _hardLimit) {
_warnThrottling = false;
- System.out.println("[INFO] No more throttling: " +
_cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB <= " +
_hardLimit/1000 + "KB");
+ System.out.println("[INFO] Cache within limit: " +
String.format("%.2f", _cacheSize/1000000.0) + "MB (-" + String.format("%.2f",
_bytesUpForEviction/1000000.0) + "MB) <= " + String.format("%.2f",
_hardLimit/1000000.0) + "MB");
}
if (!SANITY_CHECKS)
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java
index 3cd16272d2..a9da3ccd29 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java
@@ -20,12 +20,20 @@
package org.apache.sysds.runtime.ooc.cache;
import org.apache.sysds.api.DMLScript;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.io.MatrixReader;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
+import org.apache.sysds.runtime.ooc.stream.OOCSourceStream;
import org.apache.sysds.runtime.util.FastBufferedDataInputStream;
import org.apache.sysds.runtime.util.FastBufferedDataOutputStream;
import org.apache.sysds.runtime.util.LocalFileUtils;
@@ -40,6 +48,9 @@ import java.io.OutputStream;
import java.io.RandomAccessFile;
import java.nio.channels.Channels;
import java.nio.channels.ClosedByInterruptException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
@@ -50,9 +61,13 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicIntegerArray;
+import java.util.concurrent.atomic.AtomicLongArray;
+import java.util.concurrent.atomic.AtomicReference;
public class OOCMatrixIOHandler implements OOCIOHandler {
- private static final int WRITER_SIZE = 2;
+ private static final int WRITER_SIZE = 4;
+ private static final int READER_SIZE = 10;
private static final long OVERFLOW = 8192 * 1024;
private static final long MAX_PARTITION_SIZE = 8192 * 8192;
@@ -63,6 +78,7 @@ public class OOCMatrixIOHandler implements OOCIOHandler {
// Spill related structures
private final ConcurrentHashMap<String, SpillLocation> _spillLocations
= new ConcurrentHashMap<>();
private final ConcurrentHashMap<Integer, PartitionFile> _partitions =
new ConcurrentHashMap<>();
+ private final ConcurrentHashMap<BlockKey, SourceBlockDescriptor>
_sourceLocations = new ConcurrentHashMap<>();
private final AtomicInteger _partitionCounter = new AtomicInteger(0);
private final CloseableQueue<Tuple2<BlockEntry,
CompletableFuture<Void>>>[] _q;
private final AtomicLong _wCtr;
@@ -70,6 +86,7 @@ public class OOCMatrixIOHandler implements OOCIOHandler {
private final int _evictCallerId = OOCEventLog.registerCaller("write");
private final int _readCallerId = OOCEventLog.registerCaller("read");
+ private final int _srcReadCallerId =
OOCEventLog.registerCaller("read_src");
@SuppressWarnings("unchecked")
public OOCMatrixIOHandler() {
@@ -81,8 +98,8 @@ public class OOCMatrixIOHandler implements OOCIOHandler {
TimeUnit.MILLISECONDS,
new ArrayBlockingQueue<>(100000));
_readExec = new ThreadPoolExecutor(
- 5,
- 5,
+ READER_SIZE,
+ READER_SIZE,
0L,
TimeUnit.MILLISECONDS,
new ArrayBlockingQueue<>(100000));
@@ -161,14 +178,225 @@ public class OOCMatrixIOHandler implements OOCIOHandler {
@Override
public CompletableFuture<Boolean> scheduleDeletion(BlockEntry block) {
- // TODO
+ _sourceLocations.remove(block.getKey());
return CompletableFuture.completedFuture(true);
}
+ @Override
+ public void registerSourceLocation(BlockKey key, SourceBlockDescriptor
descriptor) {
+ _sourceLocations.put(key, descriptor);
+ }
+
+ @Override
+ public CompletableFuture<SourceReadResult>
scheduleSourceRead(SourceReadRequest request) {
+ return submitSourceRead(request, null,
request.maxBytesInFlight);
+ }
+
+ @Override
+ public CompletableFuture<SourceReadResult>
continueSourceRead(SourceReadContinuation continuation, long maxBytesInFlight) {
+ if (!(continuation instanceof SourceReadState state)) {
+ CompletableFuture<SourceReadResult> failed = new
CompletableFuture<>();
+ failed.completeExceptionally(new
DMLRuntimeException("Unsupported continuation type: " + continuation));
+ return failed;
+ }
+ return submitSourceRead(state.request, state, maxBytesInFlight);
+ }
+
+ private CompletableFuture<SourceReadResult>
submitSourceRead(SourceReadRequest request, SourceReadState state,
+ long maxBytesInFlight) {
+ if(request.format != Types.FileFormat.BINARY)
+ return CompletableFuture.failedFuture(
+ new DMLRuntimeException("Unsupported format for
source read: " + request.format));
+ return readBinarySourceParallel(request, state,
maxBytesInFlight);
+ }
+
+ private CompletableFuture<SourceReadResult>
readBinarySourceParallel(SourceReadRequest request,
+ SourceReadState state, long maxBytesInFlight) {
+ final long byteLimit = maxBytesInFlight > 0 ? maxBytesInFlight
: Long.MAX_VALUE;
+ final AtomicLong bytesRead = new AtomicLong(0);
+ final AtomicBoolean stop = new AtomicBoolean(false);
+ final AtomicBoolean budgetHit = new AtomicBoolean(false);
+ final AtomicReference<Throwable> error = new
AtomicReference<>();
+ final Object budgetLock = new Object();
+ final CompletableFuture<SourceReadResult> result = new
CompletableFuture<>();
+ final ConcurrentLinkedDeque<SourceBlockDescriptor> descriptors
= new ConcurrentLinkedDeque<>();
+
+ JobConf job = new
JobConf(ConfigurationManager.getCachedJobConf());
+ Path path = new Path(request.path);
+
+ Path[] files;
+ AtomicLongArray filePositions;
+ AtomicIntegerArray completed;
+
+ try {
+ FileSystem fs = IOUtilFunctions.getFileSystem(path,
job);
+ MatrixReader.checkValidInputFile(fs, path);
+
+ if(state == null) {
+ List<Path> seqFiles = new
ArrayList<>(Arrays.asList(IOUtilFunctions.getSequenceFilePaths(fs, path)));
+ files = seqFiles.toArray(Path[]::new);
+ filePositions = new
AtomicLongArray(files.length);
+ completed = new
AtomicIntegerArray(files.length);
+ }
+ else {
+ files = state.paths;
+ filePositions = state.filePositions;
+ completed = state.completed;
+ }
+ }
+ catch(IOException e) {
+ throw new DMLRuntimeException(e);
+ }
+
+ int activeTasks = 0;
+ for(int i = 0; i < files.length; i++)
+ if(completed.get(i) == 0)
+ activeTasks++;
+
+ final AtomicInteger remaining = new AtomicInteger(activeTasks);
+ boolean anyTask = activeTasks > 0;
+
+ for(int i = 0; i < files.length; i++) {
+ if(completed.get(i) == 1)
+ continue;
+ final int fileIdx = i;
+ try {
+ _readExec.submit(() -> {
+ try {
+ readSequenceFile(job,
files[fileIdx], request, fileIdx, filePositions, completed, stop,
+ budgetHit, bytesRead,
byteLimit, budgetLock, descriptors);
+ }
+ catch(Throwable t) {
+ error.compareAndSet(null, t);
+ stop.set(true);
+ }
+ finally {
+ if(remaining.decrementAndGet()
== 0)
+ completeResult(result,
bytesRead, budgetHit, error, request, files, filePositions,
+ completed,
descriptors);
+ }
+ });
+ }
+ catch(RejectedExecutionException e) {
+ error.compareAndSet(null, e);
+ stop.set(true);
+ if(remaining.decrementAndGet() == 0)
+ completeResult(result, bytesRead,
budgetHit, error, request, files, filePositions, completed,
+ descriptors);
+ break;
+ }
+ }
+
+ if(!anyTask) {
+ tryCloseTarget(request.target, true);
+ result.complete(new SourceReadResult(bytesRead.get(),
true, null, List.of()));
+ }
+
+ return result;
+ }
+
+ private void completeResult(CompletableFuture<SourceReadResult> future,
AtomicLong bytesRead, AtomicBoolean budgetHit,
+ AtomicReference<Throwable> error, SourceReadRequest request,
Path[] files, AtomicLongArray filePositions,
+ AtomicIntegerArray completed,
ConcurrentLinkedDeque<SourceBlockDescriptor> descriptors) {
+ Throwable err = error.get();
+ if (err != null) {
+ future.completeExceptionally(err instanceof Exception ?
err : new Exception(err));
+ return;
+ }
+
+ if (budgetHit.get()) {
+ if (!request.keepOpenOnLimit)
+ tryCloseTarget(request.target, false);
+ SourceReadContinuation cont = new
SourceReadState(request, files, filePositions, completed);
+ future.complete(new SourceReadResult(bytesRead.get(),
false, cont, new ArrayList<>(descriptors)));
+ return;
+ }
+
+ tryCloseTarget(request.target, true);
+ future.complete(new SourceReadResult(bytesRead.get(), true,
null, new ArrayList<>(descriptors)));
+ }
+
+ private void readSequenceFile(JobConf job, Path path, SourceReadRequest
request, int fileIdx,
+ AtomicLongArray filePositions, AtomicIntegerArray completed,
AtomicBoolean stop, AtomicBoolean budgetHit,
+ AtomicLong bytesRead, long byteLimit, Object budgetLock,
ConcurrentLinkedDeque<SourceBlockDescriptor> descriptors)
+ throws IOException {
+ MatrixIndexes key = new MatrixIndexes();
+ MatrixBlock value = new MatrixBlock();
+
+ try(SequenceFile.Reader reader = new SequenceFile.Reader(job,
SequenceFile.Reader.file(path))) {
+ long pos = filePositions.get(fileIdx);
+ if (pos > 0)
+ reader.seek(pos);
+
+ long ioStart = DMLScript.OOC_LOG_EVENTS ?
System.nanoTime() : 0;
+ while(!stop.get()) {
+ long recordStart = reader.getPosition();
+ if (!reader.next(key, value))
+ break;
+ long recordEnd = reader.getPosition();
+ long blockSize = value.getExactSerializedSize();
+ boolean shouldBreak = false;
+
+ synchronized(budgetLock) {
+ if (stop.get())
+ shouldBreak = true;
+ else if (bytesRead.get() + blockSize >
byteLimit) {
+ stop.set(true);
+ budgetHit.set(true);
+ shouldBreak = true;
+ }
+ bytesRead.addAndGet(blockSize);
+ }
+
+ MatrixIndexes outIdx = new MatrixIndexes(key);
+ MatrixBlock outBlk = new MatrixBlock(value);
+ IndexedMatrixValue imv = new
IndexedMatrixValue(outIdx, outBlk);
+ SourceBlockDescriptor descriptor = new
SourceBlockDescriptor(path.toString(), request.format, outIdx,
+ recordStart, (int)(recordEnd -
recordStart), blockSize);
+
+ if (request.target instanceof OOCSourceStream
src)
+ src.enqueue(imv, descriptor);
+ else
+ request.target.enqueue(imv);
+
+ descriptors.add(descriptor);
+ filePositions.set(fileIdx,
reader.getPosition());
+
+ if (DMLScript.OOC_LOG_EVENTS) {
+ long currTime = System.nanoTime();
+
OOCEventLog.onDiskReadEvent(_srcReadCallerId, ioStart, currTime, blockSize);
+ ioStart = currTime;
+ }
+
+ if (shouldBreak)
+ break; // Note that we knowingly go
over limit, which could result in READER_SIZE*8MB overshoot
+ }
+
+ if (!stop.get())
+ completed.set(fileIdx, 1);
+ }
+ }
+
+ private void
tryCloseTarget(org.apache.sysds.runtime.instructions.ooc.OOCStream<IndexedMatrixValue>
target, boolean close) {
+ if (close) {
+ try {
+ target.closeInput();
+ }
+ catch(Exception ignored) {
+ }
+ }
+ }
+
private void loadFromDisk(BlockEntry block) {
String key = block.getKey().toFileKey();
+ SourceBlockDescriptor src =
_sourceLocations.get(block.getKey());
+ if (src != null) {
+ loadFromSource(block, src);
+ return;
+ }
+
long ioDuration = 0;
// 1. find the blocks address (spill location)
SpillLocation sloc = _spillLocations.get(key);
@@ -207,6 +435,28 @@ public class OOCMatrixIOHandler implements OOCIOHandler {
}
}
+ private void loadFromSource(BlockEntry block, SourceBlockDescriptor
src) {
+ if (src.format != Types.FileFormat.BINARY)
+ throw new DMLRuntimeException("Unsupported format for
source read: " + src.format);
+
+ JobConf job = new
JobConf(ConfigurationManager.getCachedJobConf());
+ Path path = new Path(src.path);
+
+ MatrixIndexes ix = new MatrixIndexes();
+ MatrixBlock mb = new MatrixBlock();
+
+ try(SequenceFile.Reader reader = new SequenceFile.Reader(job,
SequenceFile.Reader.file(path))) {
+ reader.seek(src.offset);
+ if (!reader.next(ix, mb))
+ throw new DMLRuntimeException("Failed to read
source block at offset " + src.offset + " in " + src.path);
+ }
+ catch(IOException e) {
+ throw new DMLRuntimeException(e);
+ }
+
+ block.setDataUnsafe(new IndexedMatrixValue(ix, mb));
+ }
+
private void evictTask(CloseableQueue<Tuple2<BlockEntry,
CompletableFuture<Void>>> q) {
long byteCtr = 0;
@@ -276,8 +526,7 @@ public class OOCMatrixIOHandler implements OOCIOHandler {
catch(IOException | InterruptedException ex) {
throw new DMLRuntimeException(ex);
}
- catch(Exception e) {
- // TODO
+ catch(Exception ignored) {
}
finally {
IOUtilFunctions.closeSilently(dos);
@@ -356,4 +605,19 @@ public class OOCMatrixIOHandler implements OOCIOHandler {
return _count;
}
}
+
+ private static class SourceReadState implements SourceReadContinuation {
+ final SourceReadRequest request;
+ final Path[] paths;
+ final AtomicLongArray filePositions;
+ final AtomicIntegerArray completed;
+
+ SourceReadState(SourceReadRequest request, Path[] paths,
AtomicLongArray filePositions,
+ AtomicIntegerArray completed) {
+ this.request = request;
+ this.paths = paths;
+ this.filePositions = filePositions;
+ this.completed = completed;
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java
new file mode 100644
index 0000000000..c48aaa45ab
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/OOCSourceStream.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.stream;
+
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.ooc.cache.OOCIOHandler;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+public class OOCSourceStream extends SubscribableTaskQueue<IndexedMatrixValue>
{
+ private final ConcurrentHashMap<MatrixIndexes,
OOCIOHandler.SourceBlockDescriptor> _idx;
+
+ public OOCSourceStream() {
+ this._idx = new ConcurrentHashMap<>();
+ }
+
+ public void enqueue(IndexedMatrixValue value,
OOCIOHandler.SourceBlockDescriptor descriptor) {
+ if(descriptor == null)
+ throw new IllegalArgumentException("Source descriptor
must not be null");
+ MatrixIndexes key = new MatrixIndexes(descriptor.indexes);
+ _idx.put(key, descriptor);
+ super.enqueue(value);
+ }
+
+ @Override
+ public void enqueue(IndexedMatrixValue val) {
+ throw new UnsupportedOperationException("Use enqueue(value,
descriptor) for source streams");
+ }
+
+ public OOCIOHandler.SourceBlockDescriptor getDescriptor(MatrixIndexes
indexes) {
+ return _idx.get(indexes);
+ }
+}
diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java
new file mode 100644
index 0000000000..423c2b7f42
--- /dev/null
+++ b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Checks that a block registered with the cache scheduler as source-backed can be evicted on unpin
+ * and later reloaded from its source file.
+ */
+public class SourceBackedCacheSchedulerTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "SourceBackedCacheScheduler";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR
+		+ SourceBackedCacheSchedulerTest.class.getSimpleName() + "/";
+
+	private OOCMatrixIOHandler handler;
+	private OOCLRUCacheScheduler scheduler;
+
+	@Override
+	@Before
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+		handler = new OOCMatrixIOHandler();
+		scheduler = new OOCLRUCacheScheduler(handler, 0, Long.MAX_VALUE);
+	}
+
+	/**
+	 * Shuts down the scheduler and IO handler, then runs the base-class cleanup.
+	 * Declaring tearDown() here overrides AutomatedTestBase's @After method, and JUnit 4 does not run
+	 * an overridden superclass @After method on its own, so it is invoked explicitly.
+	 */
+	@Override
+	@After
+	public void tearDown() {
+		try {
+			if(scheduler != null)
+				scheduler.shutdown();
+			if(handler != null)
+				handler.shutdown();
+		}
+		finally {
+			super.tearDown();
+		}
+	}
+
+	@Test
+	public void testPutSourceBackedAndReload() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 23);
+		String fname = input("binary_src_cache");
+		writeBinaryMatrix(src, fname, blen);
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), Long.MAX_VALUE, true, target);
+
+		OOCIOHandler.SourceReadResult res = handler.scheduleSourceRead(req).get();
+		IndexedMatrixValue imv = target.dequeue();
+		// Pair the dequeued block with its own descriptor rather than assuming that the stream order
+		// matches the order of res.blocks.
+		OOCIOHandler.SourceBlockDescriptor desc = findDescriptor(res, imv.getIndexes());
+
+		BlockKey key = new BlockKey(11, 0);
+		BlockEntry entry = scheduler.putAndPinSourceBacked(key, imv,
+			((MatrixBlock) imv.getValue()).getExactSerializedSize(), desc);
+		org.junit.Assert.assertEquals(BlockState.WARM, entry.getState());
+
+		// unpinning is expected to evict the source-backed block: COLD state and no resident data
+		scheduler.unpin(entry);
+		org.junit.Assert.assertEquals(BlockState.COLD, entry.getState());
+		org.junit.Assert.assertNull(entry.getDataUnsafe());
+
+		// requesting the key again must reload the block contents from the source file
+		BlockEntry reloaded = scheduler.request(key).get();
+		IndexedMatrixValue reloadImv = (IndexedMatrixValue) reloaded.getData();
+		MatrixBlock expected = expectedBlock(src, desc.indexes, blen);
+		TestUtils.compareMatrices(expected, (MatrixBlock) reloadImv.getValue(), 1e-12);
+	}
+
+	/** Returns the descriptor in {@code res} whose indexes equal {@code idx}; fails the test if none does. */
+	private static OOCIOHandler.SourceBlockDescriptor findDescriptor(OOCIOHandler.SourceReadResult res,
+		org.apache.sysds.runtime.matrix.data.MatrixIndexes idx) {
+		for(OOCIOHandler.SourceBlockDescriptor d : res.blocks)
+			if(idx.equals(d.indexes))
+				return d;
+		org.junit.Assert.fail("No source descriptor for block " + idx);
+		return null; // unreachable: fail() throws
+	}
+
+	/** Writes {@code mb} to {@code fname} in binary block format with block size {@code blen}. */
+	private void writeBinaryMatrix(MatrixBlock mb, String fname, int blen) throws Exception {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, fname, mb.getNumRows(), mb.getNumColumns(), blen, mb.getNonZeros());
+	}
+
+	/** Slices the block with 1-based block indexes {@code idx} out of {@code src} (edge blocks clipped). */
+	private MatrixBlock expectedBlock(MatrixBlock src, org.apache.sysds.runtime.matrix.data.MatrixIndexes idx,
+		int blen) {
+		int rowStart = (int) ((idx.getRowIndex() - 1) * blen);
+		int colStart = (int) ((idx.getColumnIndex() - 1) * blen);
+		int rowEnd = Math.min(rowStart + blen - 1, src.getNumRows() - 1);
+		int colEnd = Math.min(colStart + blen - 1, src.getNumColumns() - 1);
+		return src.slice(rowStart, rowEnd, colStart, colEnd);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java
new file mode 100644
index 0000000000..e688bf0f1c
--- /dev/null
+++ b/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.ooc.cache;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Checks that a COLD block entry whose key was registered with a source location is read back by the
+ * IO handler directly from the source file.
+ */
+public class SourceBackedReadOOCIOHandlerTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "SourceBackedReadOOCIOHandler";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR
+		+ SourceBackedReadOOCIOHandlerTest.class.getSimpleName() + "/";
+
+	private OOCMatrixIOHandler handler;
+
+	@Override
+	@Before
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+		handler = new OOCMatrixIOHandler();
+	}
+
+	/**
+	 * Shuts down the IO handler, then runs the base-class cleanup.
+	 * Declaring tearDown() here overrides AutomatedTestBase's @After method, and JUnit 4 does not run
+	 * an overridden superclass @After method on its own, so it is invoked explicitly.
+	 */
+	@Override
+	@After
+	public void tearDown() {
+		try {
+			if(handler != null)
+				handler.shutdown();
+		}
+		finally {
+			super.tearDown();
+		}
+	}
+
+	@Test
+	public void testSourceBackedScheduleRead() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 17);
+		String fname = input("binary_src");
+		writeBinaryMatrix(src, fname, blen);
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), Long.MAX_VALUE, true, target);
+
+		OOCIOHandler.SourceReadResult res = handler.scheduleSourceRead(req).get();
+		org.junit.Assert.assertFalse(res.blocks.isEmpty());
+
+		// bind a fresh cache key to the first block's source location
+		OOCIOHandler.SourceBlockDescriptor desc = res.blocks.get(0);
+		BlockKey key = new BlockKey(7, 0);
+		handler.registerSourceLocation(key, desc);
+
+		// a COLD entry without data must be served from the registered source location
+		BlockEntry entry = new BlockEntry(key, desc.serializedSize, null);
+		entry.setState(BlockState.COLD);
+		handler.scheduleRead(entry).get();
+
+		IndexedMatrixValue imv = (IndexedMatrixValue) entry.getDataUnsafe();
+		org.junit.Assert.assertNotNull("source-backed read produced no data", imv);
+		MatrixBlock readBlock = (MatrixBlock) imv.getValue();
+		MatrixBlock expected = expectedBlock(src, desc.indexes, blen);
+		TestUtils.compareMatrices(expected, readBlock, 1e-12);
+	}
+
+	/** Writes {@code mb} to {@code fname} in binary block format with block size {@code blen}. */
+	private void writeBinaryMatrix(MatrixBlock mb, String fname, int blen) throws Exception {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, fname, mb.getNumRows(), mb.getNumColumns(), blen, mb.getNonZeros());
+	}
+
+	/** Slices the block with 1-based block indexes {@code idx} out of {@code src} (edge blocks clipped). */
+	private MatrixBlock expectedBlock(MatrixBlock src, org.apache.sysds.runtime.matrix.data.MatrixIndexes idx,
+		int blen) {
+		int rowStart = (int) ((idx.getRowIndex() - 1) * blen);
+		int colStart = (int) ((idx.getColumnIndex() - 1) * blen);
+		int rowEnd = Math.min(rowStart + blen - 1, src.getNumRows() - 1);
+		int colEnd = Math.min(colStart + blen - 1, src.getNumColumns() - 1);
+		return src.slice(rowStart, rowEnd, colStart, colEnd);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SourceReadOOCIOHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SourceReadOOCIOHandlerTest.java
new file mode 100644
index 0000000000..34dd01d662
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/SourceReadOOCIOHandlerTest.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.ooc.cache.OOCIOHandler;
+import org.apache.sysds.runtime.ooc.cache.OOCMatrixIOHandler;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Tests for streaming a binary matrix file through OOCMatrixIOHandler#scheduleSourceRead, including
+ * budget-limited reads that are resumed via continueSourceRead.
+ */
+public class SourceReadOOCIOHandlerTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "SourceReadOOCIOHandler";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR
+		+ SourceReadOOCIOHandlerTest.class.getSimpleName() + "/";
+
+	private OOCMatrixIOHandler handler;
+
+	@Override
+	@Before
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+		handler = new OOCMatrixIOHandler();
+	}
+
+	/**
+	 * Shuts down the IO handler, then runs the base-class cleanup.
+	 * Declaring tearDown() here overrides AutomatedTestBase's @After method, and JUnit 4 does not run
+	 * an overridden superclass @After method on its own, so it is invoked explicitly.
+	 */
+	@Override
+	@After
+	public void tearDown() {
+		try {
+			if(handler != null)
+				handler.shutdown();
+		}
+		finally {
+			super.tearDown();
+		}
+	}
+
+	/** An unbounded read must stream every block, report EOF, and return one descriptor per block. */
+	@Test
+	public void testSourceReadCompletes() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
+		String fname = input("binary_full");
+		writeBinaryMatrix(src, fname, blen);
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), Long.MAX_VALUE, true, target);
+
+		OOCIOHandler.SourceReadResult res = handler.scheduleSourceRead(req).get();
+		// Drain after EOF
+		MatrixBlock reconstructed = drainToMatrix(target, rows, cols, blen);
+
+		TestUtils.compareMatrices(src, reconstructed, 1e-12);
+		org.junit.Assert.assertTrue(res.eof);
+		org.junit.Assert.assertNull(res.continuation);
+		org.junit.Assert.assertNotNull(res.blocks);
+		org.junit.Assert.assertEquals((rows / blen) * (cols / blen), res.blocks.size());
+		org.junit.Assert.assertTrue(res.blocks.stream().allMatch(b -> b.indexes != null));
+	}
+
+	/** A read with a tiny budget must stop early with a continuation; resuming it must reach EOF. */
+	@Test
+	public void testSourceReadStopsOnBudgetAndContinues() throws Exception {
+		getAndLoadTestConfiguration(TEST_NAME);
+		final int rows = 4;
+		final int cols = 4;
+		final int blen = 2;
+
+		MatrixBlock src = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 13);
+		String fname = input("binary_budget");
+		writeBinaryMatrix(src, fname, blen);
+
+		// Serialized size of an *empty* blen x blen block (no values are ever set on it), i.e. presumably
+		// little more than a block header. Every block of src is dense (sparsity 1.0), so this budget is
+		// exhausted before the whole file is read and the first request must return a continuation.
+		long emptyBlockSize = new MatrixBlock(blen, blen, false).getExactSerializedSize();
+		long budget = emptyBlockSize + 1;
+
+		SubscribableTaskQueue<IndexedMatrixValue> target = new SubscribableTaskQueue<>();
+		OOCIOHandler.SourceReadRequest req = new OOCIOHandler.SourceReadRequest(fname, Types.FileFormat.BINARY,
+			rows, cols, blen, src.getNonZeros(), budget, true, target);
+
+		OOCIOHandler.SourceReadResult first = handler.scheduleSourceRead(req).get();
+		org.junit.Assert.assertFalse(first.eof);
+		org.junit.Assert.assertNotNull(first.continuation);
+		org.junit.Assert.assertNotNull(first.blocks);
+
+		OOCIOHandler.SourceReadResult second = handler.continueSourceRead(first.continuation, Long.MAX_VALUE).get();
+		org.junit.Assert.assertTrue(second.eof);
+		org.junit.Assert.assertNull(second.continuation);
+		org.junit.Assert.assertNotNull(second.blocks);
+		// both partial reads together must cover every block exactly once
+		org.junit.Assert.assertEquals((rows / blen) * (cols / blen), first.blocks.size() + second.blocks.size());
+
+		MatrixBlock reconstructed = drainToMatrix(target, rows, cols, blen);
+		TestUtils.compareMatrices(src, reconstructed, 1e-12);
+	}
+
+	/** Writes {@code mb} to {@code fname} in binary block format with block size {@code blen}. */
+	private void writeBinaryMatrix(MatrixBlock mb, String fname, int blen) throws Exception {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, fname, mb.getNumRows(), mb.getNumColumns(), blen, mb.getNonZeros());
+	}
+
+	/**
+	 * Dequeues blocks until {@link LocalTaskQueue#NO_MORE_TASKS} and assembles them into a
+	 * {@code rows x cols} matrix, placing each block by its 1-based block indexes and block size {@code blen}.
+	 */
+	private MatrixBlock drainToMatrix(SubscribableTaskQueue<IndexedMatrixValue> target, int rows, int cols,
+		int blen) {
+		List<IndexedMatrixValue> blocks = new ArrayList<>();
+		IndexedMatrixValue tmp;
+		while((tmp = target.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) {
+			blocks.add(tmp);
+		}
+
+		MatrixBlock out = new MatrixBlock(rows, cols, false);
+		for(IndexedMatrixValue imv : blocks) {
+			int rowOffset = (int) ((imv.getIndexes().getRowIndex() - 1) * blen);
+			int colOffset = (int) ((imv.getIndexes().getColumnIndex() - 1) * blen);
+			((MatrixBlock) imv.getValue()).putInto(out, rowOffset, colOffset, true);
+		}
+		out.recomputeNonZeros();
+		return out;
+	}
+}