This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 58361b2dde [SYSTEMDS-3893] Basic out-of-core binary-read and acquire
primitive
58361b2dde is described below
commit 58361b2dde9ab2e361b20f27e46a352a90003c1a
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Jul 13 12:46:15 2025 +0200
[SYSTEMDS-3893] Basic out-of-core binary-read and acquire primitive
This patch introduces a basic integration of the out-of-core backend.
For reading, we use a dedicated reblock instruction which creates
a queue of blocks, spawns a thread for reading and immediately returns.
In addition, we extended the acquireRead functionality to collect such
streams of blocks whenever an operation requires the full matrix.
Based on these foundations, we can now add other OOC operations that
directly work with the input stream of blocks and produce either results
or create modified output streams.
---
.github/workflows/javaTests.yml | 2 +-
src/main/java/org/apache/sysds/common/Opcodes.java | 24 ++--
.../java/org/apache/sysds/hops/AggUnaryOp.java | 5 +-
src/main/java/org/apache/sysds/hops/BinaryOp.java | 3 +
src/main/java/org/apache/sysds/hops/DataOp.java | 6 +-
src/main/java/org/apache/sysds/hops/Hop.java | 4 +-
.../hops/rewrite/RewriteBlockSizeAndReblock.java | 7 +-
src/main/java/org/apache/sysds/lops/ReBlock.java | 4 +-
.../controlprogram/caching/CacheableData.java | 22 +++-
.../controlprogram/caching/FrameObject.java | 8 ++
.../controlprogram/caching/MatrixObject.java | 34 +++++-
.../controlprogram/caching/TensorObject.java | 9 ++
.../runtime/instructions/OOCInstructionParser.java | 5 +-
.../ooc/ComputationOOCInstruction.java | 48 ++++++++
.../runtime/instructions/ooc/OOCInstruction.java | 2 +-
.../instructions/ooc/ReblockOOCInstruction.java | 123 +++++++++++++++++++++
.../org/apache/sysds/runtime/io/MatrixReader.java | 3 +-
src/main/java/org/apache/sysds/utils/Explain.java | 6 +-
.../functions/ooc/SumScalarMultiplicationTest.java | 53 +++++----
19 files changed, 318 insertions(+), 50 deletions(-)
diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index d13b187fb2..c11f00ed4f 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -73,7 +73,7 @@ jobs:
"**.functions.builtin.part1.**",
"**.functions.builtin.part2.**",
"**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.iogen.**",
- "**.functions.dnn.**",
+ "**.functions.dnn.**,**.functions.ooc.**",
"**.functions.paramserv.**",
"**.functions.recompile.**,**.functions.misc.**",
"**.functions.mlcontext.**",
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java
b/src/main/java/org/apache/sysds/common/Opcodes.java
index a4081f9292..fd5c6bfd12 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -349,7 +349,7 @@ public enum Opcodes {
MAPMIN("mapmin", InstructionType.Binary),
//REBLOCK Instruction Opcodes
- RBLK("rblk", null, InstructionType.Reblock),
+ RBLK("rblk", null, InstructionType.Reblock, null,
InstructionType.Reblock),
CSVRBLK("csvrblk", InstructionType.CSVReblock),
LIBSVMRBLK("libsvmrblk", InstructionType.LIBSVMReblock),
@@ -398,24 +398,23 @@ public enum Opcodes {
// Constructors
Opcodes(String name, InstructionType type) {
- this._name = name;
- this._type = type;
- this._spType=null;
- this._fedType=null;
+ this(name, type, null, null, null);
}
Opcodes(String name, InstructionType type, InstructionType spType){
- this._name=name;
- this._type=type;
- this._spType=spType;
- this._fedType=null;
+ this(name, type, spType, null, null);
}
Opcodes(String name, InstructionType type, InstructionType spType,
InstructionType fedType){
+ this(name, type, spType, fedType, null);
+ }
+
+	/**
+	 * Canonical constructor; all shorter constructors delegate here and pass
+	 * null for unspecified backend-specific instruction types.
+	 *
+	 * @param name    opcode string
+	 * @param type    default instruction type
+	 * @param spType  Spark instruction type, or null to fall back to type
+	 * @param fedType federated instruction type, or null to fall back to type
+	 * @param oocType out-of-core instruction type, or null to fall back to type
+	 */
+ Opcodes(String name, InstructionType type, InstructionType spType,
InstructionType fedType, InstructionType oocType){
this._name=name;
this._type=type;
this._spType=spType;
this._fedType=fedType;
+ this._oocType=oocType;
}
// Fields
@@ -423,6 +422,7 @@ public enum Opcodes {
private final InstructionType _type;
private final InstructionType _spType;
private final InstructionType _fedType;
+ private final InstructionType _oocType;
private static final Map<String, Opcodes> _lookupMap = new HashMap<>();
@@ -451,6 +451,10 @@ public enum Opcodes {
public InstructionType getFedType(){
return _fedType != null ? _fedType : _type;
}
+
+	/**
+	 * Returns the instruction type used by the out-of-core (OOC) backend.
+	 *
+	 * @return the OOC-specific type if one was given, otherwise the generic type
+	 */
+	public InstructionType getOocType(){
+		if( _oocType == null )
+			return _type;
+		return _oocType;
+	}
public static InstructionType getTypeByOpcode(String opcode,
Types.ExecType type) {
if (opcode == null || opcode.trim().isEmpty()) {
@@ -463,6 +467,8 @@ public enum Opcodes {
return (op.getSpType() != null) ?
op.getSpType() : op.getType();
case FED:
return (op.getFedType() != null) ?
op.getFedType() : op.getType();
+ case OOC:
+ return (op.getOocType() != null) ?
op.getOocType() : op.getType();
default:
return op.getType();
}
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 0b2d62bbe3..2f5cb53acf 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -116,7 +116,7 @@ public class AggUnaryOp extends MultiThreadedHop
ExecType et = optFindExecType();
Hop input = getInput().get(0);
- if ( et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED )
+ if ( et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED || et == ExecType.OOC )
{
Lop agg1 = null;
if( isTernaryAggregateRewriteApplicable() ) {
@@ -409,6 +409,9 @@ public class AggUnaryOp extends MultiThreadedHop
else
setRequiresRecompileIfNecessary();
+ if( _etype == ExecType.OOC ) //TODO
+ setExecType(ExecType.CP);
+
return _etype;
}
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index a3ddb45ea6..f433931a52 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -854,6 +854,9 @@ public class BinaryOp extends MultiThreadedHop {
_etype = ExecType.CP;
}
+ if( _etype == ExecType.OOC ) //TODO
+ setExecType(ExecType.CP);
+
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java
b/src/main/java/org/apache/sysds/hops/DataOp.java
index 1ae8616001..eb0d1961cf 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -24,6 +24,7 @@ import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.OpOpData;
@@ -465,6 +466,9 @@ public class DataOp extends Hop {
}
else //READ
{
+ if( DMLScript.USE_OOC )
+ checkAndSetForcedPlatform();
+
//mark for recompile (forever)
if( ConfigurationManager.isDynamicRecompilation() &&
!dimsKnown(true) && letype==ExecType.SPARK
&& (_recompileRead || _requiresCheckpoint) )
@@ -473,7 +477,7 @@ public class DataOp extends Hop {
}
_etype = letype;
- if ( _etypeForced == ExecType.FED )
+ if ( _etypeForced == ExecType.FED || _etypeForced ==
ExecType.OOC )
_etype = _etypeForced;
}
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 68e5bc94c0..86749d44c1 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -256,6 +256,8 @@ public abstract class Hop implements ParseInfo {
{
if(DMLScript.USE_ACCELERATOR && DMLScript.FORCE_ACCELERATOR &&
isGPUEnabled())
_etypeForced = ExecType.GPU; // enabled with -gpu force
option
+ else if (DMLScript.USE_OOC)
+ _etypeForced = ExecType.OOC;
else if ( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
&& _etypeForced != ExecType.FED ) {
if(OptimizerUtils.isMemoryBasedOptLevel() &&
DMLScript.USE_ACCELERATOR && isGPUEnabled()) {
// enabled with -exec singlenode -gpu option
@@ -406,7 +408,7 @@ public abstract class Hop implements ParseInfo {
private void constructAndSetReblockLopIfRequired()
{
//determine execution type
- ExecType et = ExecType.CP;
+ ExecType et = DMLScript.USE_OOC ? ExecType.OOC : ExecType.CP;
if( DMLScript.getGlobalExecMode() != ExecMode.SINGLE_NODE
&& !(getDataType()==DataType.SCALAR) )
{
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
index 4e03e02f62..4b5eaa8a9a 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
@@ -29,6 +29,7 @@ import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
/**
@@ -80,8 +81,12 @@ public class RewriteBlockSizeAndReblock extends
HopRewriteRule
{
DataOp dop = (DataOp) hop;
+ if( DMLScript.USE_OOC && dop.getOp() ==
OpOpData.PERSISTENTREAD ) {
+ dop.setRequiresReblock(true);
+ dop.setBlocksize(blocksize);
+ }
// if block size does not match
- if( (dop.getDataType() == DataType.MATRIX &&
(dop.getBlocksize() != blocksize))
+ else if( (dop.getDataType() == DataType.MATRIX &&
(dop.getBlocksize() != blocksize))
||(dop.getDataType() == DataType.FRAME &&
OptimizerUtils.isSparkExecutionMode()
&& (dop.getFileFormat()==FileFormat.TEXT ||
dop.getFileFormat()==FileFormat.CSV)) )
{
diff --git a/src/main/java/org/apache/sysds/lops/ReBlock.java
b/src/main/java/org/apache/sysds/lops/ReBlock.java
index 2e2c9dc2fd..d92144d63e 100644
--- a/src/main/java/org/apache/sysds/lops/ReBlock.java
+++ b/src/main/java/org/apache/sysds/lops/ReBlock.java
@@ -46,8 +46,8 @@ public class ReBlock extends Lop {
_blocksize = blen;
_outputEmptyBlocks = outputEmptyBlocks;
- if(et == ExecType.SPARK)
- lps.setProperties(inputs, ExecType.SPARK);
+ if(et == ExecType.SPARK || et == ExecType.OOC)
+ lps.setProperties(inputs, et);
else
throw new LopsException("Incorrect execution type for
Reblock:" + et);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index eba22e7f15..e075b55e17 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -42,12 +42,14 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.IOUtilFunctions;
@@ -210,13 +212,15 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
private boolean _requiresLocalWrite = false; //flag if local write for
read obj
private boolean _isAcquireFromEmpty = false; //flag if read from status
empty
- //spark-specific handles
+ //backend-specific handles
//note: we use the abstraction of LineageObjects for two reasons: (1)
to keep track of cleanup
//for lazily evaluated RDDs, and (2) as abstraction for environments
that do not necessarily have spark libraries available
private RDDObject _rddHandle = null; //RDD handle
private BroadcastObject<T> _bcHandle = null; //Broadcast handle
protected HashMap<GPUContext, GPUObject> _gpuObjects = null; //Per
GPUContext object allocated on GPU
-
+ //TODO generalize for frames
+ private LocalTaskQueue<IndexedMatrixValue> _streamHandle = null;
+
private LineageItem _lineage = null;
/**
@@ -460,6 +464,10 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
public boolean hasBroadcastHandle() {
return _bcHandle != null && _bcHandle.hasBackReference();
}
+
+	/**
+	 * Returns the queue of indexed matrix blocks (OOC stream) attached to this
+	 * object, or null if no stream handle has been set.
+	 * <p>
+	 * Synchronized like {@link #setStreamHandle}: without it, a handle set by
+	 * one thread under the lock is not guaranteed to be visible to readers.
+	 *
+	 * @return the attached block stream, or null
+	 */
+	public synchronized LocalTaskQueue<IndexedMatrixValue> getStreamHandle() {
+		return _streamHandle;
+	}
@SuppressWarnings({ "rawtypes", "unchecked" })
public void setBroadcastHandle( BroadcastObject bc ) {
@@ -490,6 +498,10 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
public synchronized void removeGPUObject(GPUContext gCtx) {
_gpuObjects.remove(gCtx);
}
+
+	/**
+	 * Attaches a queue of indexed matrix blocks (OOC stream) to this object,
+	 * from which read acquisition may materialize the data.
+	 *
+	 * @param q the block stream, or null to detach the current stream
+	 */
+	public synchronized void setStreamHandle(LocalTaskQueue<IndexedMatrixValue> q) {
+		this._streamHandle = q;
+	}
// *********************************************
// *** ***
@@ -580,6 +592,9 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
//mark for initial local write despite
read operation
_requiresLocalWrite = false;
}
+ else if( getStreamHandle() != null ) {
+ _data = readBlobFromStream(
getStreamHandle() );
+ }
else if( getRDDHandle()==null ||
getRDDHandle().allowsShortCircuitRead() ) {
if( DMLScript.STATISTICS )
CacheStatistics.incrementHDFSHits();
@@ -1099,6 +1114,9 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
protected abstract T readBlobFromRDD(RDDObject rdd, MutableBoolean
status)
throws IOException;
+	/**
+	 * Materializes the in-memory data object by consuming the given stream of
+	 * indexed blocks; invoked during read acquisition if a stream handle is set.
+	 *
+	 * @param stream queue of indexed matrix blocks to consume
+	 * @return the assembled data object
+	 * @throws IOException if reading from the stream fails
+	 */
+ protected abstract T
readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream)
+ throws IOException;
+
// Federated read
protected T readBlobFromFederated(FederationMap fedMap) throws
IOException {
if( LOG.isDebugEnabled() ) //common if instructions keep
federated outputs
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
index 582bb64dd8..56cc276cd8 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
@@ -33,7 +33,9 @@ import
org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.FrameReaderFactory;
@@ -304,6 +306,12 @@ public class FrameObject extends CacheableData<FrameBlock>
//lazy evaluation of pending transformations.
SparkExecutionContext.writeFrameRDDtoHDFS(rdd, fname,
iimd.getFileFormat());
}
+
+	/**
+	 * Streams of blocks are not yet supported for frames (stream handles are
+	 * typed to matrix blocks). Fails explicitly instead of returning null,
+	 * which would leave the frame without data after read acquisition.
+	 *
+	 * @param stream queue of indexed matrix blocks (unused)
+	 * @return never returns normally
+	 * @throws UnsupportedOperationException always
+	 */
+	@Override
+	protected FrameBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException {
+		//TODO generalize stream handles for frames
+		throw new UnsupportedOperationException(
+			"Reading frames from OOC streams is not supported yet.");
+	}
@Override
protected FrameBlock reconstructByLineage(LineageItem li) throws
IOException {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index f58b315e68..e9204bdaed 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -42,7 +42,9 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.ReaderWriterFederated;
@@ -442,8 +444,10 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
// Read matrix and maintain meta data,
// if the MatrixObject is federated there is nothing extra to
read, and therefore only acquire read and release
int blen = mc.getBlocksize() <= 0 ?
ConfigurationManager.getBlocksize() : mc.getBlocksize();
- MatrixBlock newData = isFederated() ? acquireReadAndRelease() :
DataConverter.readMatrixFromHDFS(fname,
- iimd.getFileFormat(), rlen, clen, blen,
mc.getNonZeros(), getFileFormatProperties());
+ MatrixBlock newData =
+ isFederated() ? acquireReadAndRelease() :
+ DataConverter.readMatrixFromHDFS(fname,
iimd.getFileFormat(),
+ rlen, clen, blen, mc.getNonZeros(),
getFileFormatProperties());
if(iimd.getFileFormat() == FileFormat.CSV) {
_metaData = _metaData instanceof MetaDataFormat ? new
MetaDataFormat(newData.getDataCharacteristics(),
@@ -518,6 +522,32 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
return mb;
}
+
+
+	/**
+	 * Assembles the full matrix block by consuming all indexed blocks from the
+	 * given stream until it returns {@link LocalTaskQueue#NO_MORE_TASKS}. Each
+	 * block is placed at the offset derived from its 1-based block indexes and
+	 * the blocksize of this matrix object.
+	 *
+	 * @param stream queue of indexed matrix blocks
+	 * @return the assembled matrix block with maintained number of non-zeros
+	 * @throws IOException declared by the contract; failures are wrapped into
+	 *         DMLRuntimeException
+	 */
+	@Override
+	protected MatrixBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException {
+		MatrixBlock ret = new MatrixBlock((int)getNumRows(), (int)getNumColumns(), false);
+		final int blen = getBlocksize();
+		//nnz is maintained as long: the total may exceed the int range even
+		//if the number of rows and columns each fit into an int
+		long lnnz = 0;
+		try {
+			IndexedMatrixValue tmp = null;
+			while( (tmp = stream.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS ) {
+				//compute row/column offsets from the 1-based block indexes
+				final int rowOffset = (int) (tmp.getIndexes().getRowIndex() - 1) * blen;
+				final int colOffset = (int) (tmp.getIndexes().getColumnIndex() - 1) * blen;
+
+				//copy the values of this block into the output block
+				((MatrixBlock)tmp.getValue()).putInto(ret, rowOffset, colOffset, true);
+
+				//incremental maintenance of nnz
+				lnnz += tmp.getValue().getNonZeros();
+			}
+			ret.setNonZeros(lnnz);
+		}
+		catch(Exception ex) {
+			//restore the interrupt status if waiting on the queue was interrupted
+			if( ex instanceof InterruptedException )
+				Thread.currentThread().interrupt();
+			throw new DMLRuntimeException(ex);
+		}
+		return ret;
+	}
@Override
protected MatrixBlock readBlobFromFederated(FederationMap fedMap,
long[] dims) throws IOException {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
index 8908f55d06..d39ed8c8a9 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/TensorObject.java
@@ -30,8 +30,10 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -199,6 +201,13 @@ public class TensorObject extends
CacheableData<TensorBlock> {
//TODO rdd write
}
+
+	/**
+	 * Streams of blocks are not yet supported for tensors (stream handles are
+	 * typed to matrix blocks). Fails explicitly instead of returning null,
+	 * which would leave the tensor without data after read acquisition.
+	 *
+	 * @param stream queue of indexed matrix blocks (unused)
+	 * @return never returns normally
+	 * @throws UnsupportedOperationException always
+	 */
+	@Override
+	protected TensorBlock readBlobFromStream(LocalTaskQueue<IndexedMatrixValue> stream) throws IOException {
+		//TODO generalize stream handles for tensors
+		throw new UnsupportedOperationException(
+			"Reading tensors from OOC streams is not supported yet.");
+	}
+
@Override
protected TensorBlock reconstructByLineage(LineageItem li) throws
IOException {
return ((TensorObject) LineageRecomputeUtils
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index 191976f094..e0f84c5bd2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.InstructionType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG =
LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -44,7 +45,9 @@ public class OOCInstructionParser extends InstructionParser {
if(str == null || str.isEmpty())
return null;
switch(ooctype) {
-
+ case Reblock:
+ return
ReblockOOCInstruction.parseInstruction(str);
+
// TODO:
case AggregateUnary:
case Binary:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
new file mode 100644
index 0000000000..5552017493
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+/**
+ * Base class for out-of-core instructions that compute an output operand from
+ * input operands (fields for up to three inputs; unset inputs are null).
+ */
+public abstract class ComputationOOCInstruction extends OOCInstruction {
+	public CPOperand output;
+	public CPOperand input1;
+	public CPOperand input2;
+	public CPOperand input3;
+
+	/** Creates an instruction with a single input operand. */
+	protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, String opcode, String istr) {
+		this(type, op, in1, null, out, opcode, istr);
+	}
+
+	/** Creates an instruction with two input operands; the third stays null. */
+	protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+		super(type, op, opcode, istr);
+		this.input1 = in1;
+		this.input2 = in2;
+		this.input3 = null;
+		this.output = out;
+	}
+
+	/** @return the name of the output operand */
+	public String getOutputVariableName() {
+		return output.getName();
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 83cc972135..fe73e57fd2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction {
protected static final Log LOG =
LogFactory.getLog(OOCInstruction.class.getName());
+	/** Types of out-of-core (OOC) instructions, passed to the OOCInstruction constructors. */
public enum OOCType {
- AggregateUnary, Binary
+ Reblock, AggregateUnary, Binary
}
protected final OOCInstruction.OOCType _ooctype;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
new file mode 100644
index 0000000000..9a7059be51
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.concurrent.ExecutorService;
+
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.io.MatrixReader;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+/**
+ * Out-of-core reblock instruction that streams the blocks of a binary input
+ * matrix into a queue, which is attached to the output matrix object.
+ */
+public class ReblockOOCInstruction extends ComputationOOCInstruction {
+	//blocksize of the output (square blocks)
+	private final int blen;
+
+	private ReblockOOCInstruction(Operator op, CPOperand in, CPOperand out,
+		int br, int bc, String opcode, String instr)
+	{
+		super(OOCType.Reblock, op, in, out, opcode, instr);
+		//only square blocks are supported, hence a single blocksize
+		if( br != bc )
+			throw new DMLRuntimeException("Unsupported non-square blocksize for ReblockOOCInstruction: "+br+"x"+bc);
+		blen = br;
+	}
+
+	/**
+	 * Parses a reblock instruction with the parts opcode, input operand,
+	 * output operand, and blocksize.
+	 */
+	public static ReblockOOCInstruction parseInstruction(String str) {
+		String parts[] = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		if(!opcode.equals(Opcodes.RBLK.toString()))
+			throw new DMLRuntimeException("Incorrect opcode for ReblockOOCInstruction:" + opcode);
+
+		CPOperand in = new CPOperand(parts[1]);
+		CPOperand out = new CPOperand(parts[2]);
+		int blen=Integer.parseInt(parts[3]);
+		return new ReblockOOCInstruction(null, in, out, blen, blen, opcode, str);
+	}
+
+	/**
+	 * Sets the output characteristics, creates a queue of blocks, submits an
+	 * asynchronous task that reads the binary input into this queue, and
+	 * immediately returns after attaching the queue as stream handle of the
+	 * output matrix object.
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		//set the output characteristics
+		MatrixObject min = ec.getMatrixObject(input1);
+		DataCharacteristics mc = ec.getDataCharacteristics(input1.getName());
+		DataCharacteristics mcOut = ec.getDataCharacteristics(output.getName());
+		mcOut.set(mc.getRows(), mc.getCols(), blen, mc.getNonZeros());
+
+		//TODO support other formats than binary
+
+		//create queue, spawn task for asynchronous reading, and return
+		//(file name resolved here, not inside the asynchronous task)
+		final String fname = min.getFileName();
+		LocalTaskQueue<IndexedMatrixValue> q = new LocalTaskQueue<>();
+		ExecutorService pool = CommonThreadPool.get();
+		try {
+			pool.submit(() -> readBinaryBlock(q, fname));
+		}
+		finally {
+			pool.shutdown();
+		}
+
+		MatrixObject mout = ec.getMatrixObject(output);
+		mout.setStreamHandle(q);
+	}
+
+	/**
+	 * Reads all blocks of the binary sequence file(s) at the given path into
+	 * the queue and closes the queue input afterwards.
+	 * <p>
+	 * This method runs as an asynchronous task whose future is discarded, so
+	 * failures are logged explicitly; otherwise they would be lost silently.
+	 * NOTE: on failure the queue input is not closed, i.e., consumers are not
+	 * notified. TODO propagate failures to consumers of the stream.
+	 */
+	@SuppressWarnings("resource")
+	private void readBinaryBlock(LocalTaskQueue<IndexedMatrixValue> q, String fname) {
+		try {
+			//prepare file access
+			JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
+			Path path = new Path( fname );
+			FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
+
+			//check existence and non-empty file
+			MatrixReader.checkValidInputFile(fs, path);
+
+			//core reading
+			for( Path lpath : IOUtilFunctions.getSequenceFilePaths(fs, path) ) { //1..N files
+				//directly read from sequence files (individual partfiles)
+				try( SequenceFile.Reader reader = new SequenceFile
+					.Reader(job, SequenceFile.Reader.file(lpath)) )
+				{
+					MatrixIndexes key = new MatrixIndexes();
+					MatrixBlock value = new MatrixBlock();
+					//the value object is reused across next() calls, hence the copy
+					while( reader.next(key, value) )
+						q.enqueueTask(new IndexedMatrixValue(key, new MatrixBlock(value)));
+				}
+			}
+			q.closeInput();
+		}
+		catch(Exception ex) {
+			if( ex instanceof InterruptedException )
+				Thread.currentThread().interrupt();
+			LOG.error("Failed to read binary block stream from " + fname, ex);
+			throw new DMLRuntimeException(ex);
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java
b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java
index 893d665a24..245347e9cf 100644
--- a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java
+++ b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java
@@ -111,7 +111,7 @@ public abstract class MatrixReader
return ret;
}
- protected static void checkValidInputFile(FileSystem fs, Path path)
+ public static void checkValidInputFile(FileSystem fs, Path path)
throws IOException
{
//check non-existing file
@@ -121,7 +121,6 @@ public abstract class MatrixReader
//check for empty file
if( HDFSTool.isFileEmpty(fs, path) )
throw new EOFException("Empty input file "+
path.toString() +".");
-
}
protected static void sortSparseRowsParallel(MatrixBlock dest, long
rlen, int k, ExecutorService pool)
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java
b/src/main/java/org/apache/sysds/utils/Explain.java
index bcd17ef7f0..7f5fb4c06f 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -62,6 +62,7 @@ import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.spark.CSVReblockSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
@@ -837,8 +838,9 @@ public class Explain
private static String explainGenericInstruction( Instruction inst, int
level )
{
String tmp = null;
- if ( inst instanceof SPInstruction || inst instanceof
CPInstruction || inst instanceof GPUInstruction ||
- inst instanceof FEDInstruction )
+ if ( inst instanceof SPInstruction || inst instanceof
CPInstruction
+ || inst instanceof GPUInstruction || inst instanceof
FEDInstruction
+ || inst instanceof OOCInstruction)
tmp = inst.toString();
if( REPLACE_SPECIAL_CHARACTERS && tmp != null){
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
index d9d42c913b..3681b74f83 100644
---
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
@@ -21,14 +21,19 @@ package org.apache.sysds.test.functions.ooc;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import java.util.HashMap;
@@ -52,7 +57,6 @@ public class SumScalarMultiplicationTest extends
AutomatedTestBase {
* Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
*/
@Test
- @Ignore
public void testSumScalarMult() {
Types.ExecMode platformOld = rtplatform;
@@ -62,42 +66,43 @@ public class SumScalarMultiplicationTest extends
AutomatedTestBase {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-explain", "-stats",
"-ooc", "-args", input(INPUT_NAME), output(OUTPUT_NAME)};
-
- int rows = 3;
- int cols = 4;
- double sparsity = 0.8;
-
- double[][] X = getRandomMatrix(rows, cols, -1, 1,
sparsity, 7);
- writeInputMatrixWithMTD(INPUT_NAME, X, true);
-
+ programArgs = new String[] {"-explain", "-stats",
"-ooc",
+ "-args", input(INPUT_NAME),
output(OUTPUT_NAME)};
+
+ int rows = 3500, cols = 4;
+ MatrixBlock mb = MatrixBlock.randOperations(rows, cols,
1.0, -1, 1, "uniform", 7);
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY);
+ writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows,
cols, 1000, rows*cols);
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME+"mtd"),
ValueType.FP64,
+ new
MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY);
+
runTest(true, false, null, -1);
HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
- // only one entry
Double result = dmlfile.get(new
MatrixValue.CellIndex(1, 1));
-
double expected = 0.0;
for(int i = 0; i < rows; i++) {
for(int j = 0; j < cols; j++) {
- expected += X[i][j] * 7;
+ expected += mb.get(i, j) * 7;
}
}
Assert.assertEquals(expected, result, 1e-10);
String prefix = Instruction.OOC_INST_PREFIX;
-
- boolean usedOOCMult =
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
- Assert.assertTrue("OOC wasn't used for MULT",
usedOOCMult);
-
- boolean usedOOCSum =
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP);
- Assert.assertTrue("OOC wasn't used for SUM",
usedOOCSum);
-
+ Assert.assertTrue("OOC wasn't used for RBLK",
+ heavyHittersContainsString(prefix +
Opcodes.RBLK));
+
+// boolean usedOOCMult =
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
+// Assert.assertTrue("OOC wasn't used for MULT",
usedOOCMult);
+// boolean usedOOCSum =
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP);
+// Assert.assertTrue("OOC wasn't used for SUM",
usedOOCSum);
+ }
+ catch(Exception ex) {
+ Assert.fail(ex.getMessage());
}
finally {
- // reset
- rtplatform = platformOld;
+ resetExecMode(platformOld);
}
}
}