This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1329f3db21 [SYSTEMDS-3525] Binary Inplace Operations
1329f3db21 is described below
commit 1329f3db21b21cbf470ebdef63ca61af1e0a64a8
Author: baunsgaard <[email protected]>
AuthorDate: Fri Apr 21 12:28:41 2023 +0200
[SYSTEMDS-3525] Binary Inplace Operations
This commit initializes the in-place logic for binary operations.
Initially, this is only used in the very specific case of division by a
vector that does not contain NaN or zero, where the input is not used by
any other operator.
Additionally, this commit adds a parameterized test that verifies equivalent
behavior of the in-place operations and the normal operations.
Closes #1808
---
.../java/org/apache/sysds/hops/AggUnaryOp.java | 45 ++--
src/main/java/org/apache/sysds/hops/BinaryOp.java | 93 ++++++--
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 12 +
src/main/java/org/apache/sysds/lops/Binary.java | 14 +-
.../sysds/runtime/instructions/Instruction.java | 2 +-
.../instructions/cp/BinaryCPInstruction.java | 2 +-
.../cp/BinaryMatrixMatrixCPInstruction.java | 62 +++--
.../runtime/matrix/data/LibMatrixBincell.java | 180 +++++++++++++--
.../runtime/matrix/operators/BinaryOperator.java | 159 ++++++++++++-
.../matrix/BinaryOperationInPlaceTest.java | 251 ++++++++++++++++++++-
.../BinaryOperationInPlaceTestParameterized.java | 190 ++++++++++++++++
11 files changed, 921 insertions(+), 89 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 468caa707e..f60ee6dd15 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -337,6 +337,28 @@ public class AggUnaryOp extends MultiThreadedHop
}
+ private boolean inputAlreadySpark(){
+ return (!(getInput(0) instanceof DataOp) //input is not
checkpoint
+ && getInput(0).optFindExecType() == ExecType.SPARK);
+ }
+
+ private boolean inputOnlyRDD(){
+ return (getInput(0) instanceof DataOp &&
((DataOp)getInput(0)).hasOnlyRDD());
+ }
+
+ private boolean onlyOneParent(){
+ return getInput(0).getParent().size()==1;
+ }
+
+ private boolean allParentsSpark(){
+ return getInput(0).getParent().stream().filter(h -> h != this)
+ .allMatch(h -> h.optFindExecType(false)
== ExecType.SPARK);
+ }
+
+ private boolean inputDoesNotRequireAggregation(){
+ return !requiresAggregation(getInput(0), _direction);
+ }
+
@Override
protected ExecType optFindExecType(boolean transitive) {
@@ -351,17 +373,14 @@ public class AggUnaryOp extends MultiThreadedHop
}
else
{
- if ( OptimizerUtils.isMemoryBasedOptLevel() )
- {
+ if ( OptimizerUtils.isMemoryBasedOptLevel()) {
_etype = findExecTypeByMemEstimate();
}
// Choose CP, if the input dimensions are below
threshold or if the input is a vector
- else if ( getInput().get(0).areDimsBelowThreshold() ||
getInput().get(0).isVector() )
- {
+ else if(getInput().get(0).areDimsBelowThreshold() ||
getInput().get(0).isVector()) {
_etype = ExecType.CP;
}
- else
- {
+ else {
_etype = REMOTE;
}
@@ -372,14 +391,12 @@ public class AggUnaryOp extends MultiThreadedHop
//spark-specific decision refinement (execute unary aggregate
w/ spark input and
//single parent also in spark because it's likely cheap and
reduces data transfer)
//we also allow multiple parents, if all other parents are
already in Spark mode
- if( transitive && _etype == ExecType.CP && _etypeForced !=
ExecType.CP
- && ((!(getInput(0) instanceof DataOp) //input is not
checkpoint
- && getInput(0).optFindExecType() ==
ExecType.SPARK)
- || (getInput(0) instanceof DataOp &&
((DataOp)getInput(0)).hasOnlyRDD()))
- && (getInput(0).getParent().size()==1 //uagg is only
parent, or
- || getInput(0).getParent().stream().filter(h ->
h != this)
- .allMatch(h -> h.optFindExecType(false)
== ExecType.SPARK)
- || !requiresAggregation(getInput(0),
_direction)) ) //w/o agg
+
+ boolean shouldEvaluateIfSpark = transitive && _etype ==
ExecType.CP && _etypeForced != ExecType.CP;
+
+ if( shouldEvaluateIfSpark
+ && (inputAlreadySpark() || inputOnlyRDD())
+ && (onlyOneParent() || allParentsSpark() ||
inputDoesNotRequireAggregation() ))
{
//pull unary aggregate into spark
_etype = ExecType.SPARK;
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 2346eeebfe..04585d7dc4 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -19,6 +19,8 @@
package org.apache.sysds.hops;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
@@ -27,6 +29,7 @@ import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDnn;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
@@ -52,21 +55,21 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-/* Binary (cell operations): aij + bij
+/** Binary (cell operations): aij + bij
* Properties:
* Symbol: *, -, +, ...
* 2 Operands
* Semantic: align indices (sort), then perform operation
*/
-
public class BinaryOp extends MultiThreadedHop {
- // private static final Log LOG =
LogFactory.getLog(BinaryOp.class.getName());
+ protected static final Log LOG =
LogFactory.getLog(BinaryOp.class.getName());
//we use the full remote memory budget (but reduced by sort buffer),
public static final double APPEND_MEM_MULTIPLIER = 1.0;
private OpOp2 op;
private boolean outer = false;
+ private boolean inplace = false;
public static AppendMethod FORCED_APPEND_METHOD = null;
public static MMBinaryMethod FORCED_BINARY_METHOD = null;
@@ -126,6 +129,10 @@ public class BinaryOp extends MultiThreadedHop {
public boolean isOuter(){
return outer;
}
+
+ public boolean isInplace(){
+ return inplace;
+ }
@Override
public boolean isGPUEnabled() {
@@ -435,7 +442,7 @@ public class BinaryOp extends MultiThreadedHop {
else { //general case
tmp = new Binary(getInput(0).constructLops(),
getInput(1).constructLops(),
op, getDataType(), getValueType(), et,
-
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
+
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), inplace);
}
setOutputDimensions(tmp);
@@ -477,7 +484,7 @@ public class BinaryOp extends MultiThreadedHop {
else
binary = new
Binary(getInput(0).constructLops(), getInput(1).constructLops(),
op, getDataType(),
getValueType(), et,
-
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
+
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), inplace);
setOutputDimensions(binary);
setLineNumbers(binary);
@@ -700,6 +707,44 @@ public class BinaryOp extends MultiThreadedHop {
return true;
}
+ private static boolean isReplace(Hop h) {
+ return h instanceof ParameterizedBuiltinOp && //
+ ((ParameterizedBuiltinOp) h).getOp() ==
ParamBuiltinOp.REPLACE;
+ }
+
+ private static boolean isReplaceWithPattern(ParameterizedBuiltinOp h,
double pattern, double replace) {
+ Hop pat = h.getParameterHop("pattern");
+ Hop rep = h.getParameterHop("replacement");
+ if(pat instanceof LiteralOp && rep instanceof LiteralOp) {
+ double patOb = ((LiteralOp) pat).getDoubleValue();
+ double repOb = ((LiteralOp) rep).getDoubleValue();
+ return ((Double.isNaN(pattern) && Double.isNaN(patOb))
// is both NaN
+ || Double.compare(pattern, patOb) == 0) // Is
equivalent pattern
+ && Double.compare(replace, repOb) == 0; // is
equivalent replace.
+ }
+ return false;
+ }
+
+ private static boolean doesNotContainNanAndInf(Hop p1) {
+ if(isReplace(p1)) {
+ Hop p2 = p1.getInput().get(0);
+ if(isReplace(p2)) {
+ ParameterizedBuiltinOp pp1 =
(ParameterizedBuiltinOp) p1;
+ ParameterizedBuiltinOp pp2 =
(ParameterizedBuiltinOp) p2;
+ return (isReplaceWithPattern(pp1, Double.NaN,
1) && isReplaceWithPattern(pp2, 0, 1)) ||
+ (isReplaceWithPattern(pp2, Double.NaN,
1) && isReplaceWithPattern(pp1, 0, 1));
+ }
+ }
+ return false;
+ }
+
+ private boolean memOfInputIsLessThanBudget() {
+ final double in1Memory = getInput().get(0).getMemEstimate();
+ final double in2Memory = getInput().get(1).getMemEstimate();
+ final double budget = OptimizerUtils.getLocalMemBudget();
+ return in1Memory + in2Memory < budget;
+ }
+
@Override
protected ExecType optFindExecType(boolean transitive) {
@@ -755,20 +800,34 @@ public class BinaryOp extends MultiThreadedHop {
checkAndSetInvalidCPDimsAndSize();
}
- //spark-specific decision refinement (execute unary scalar w/
spark input and
- //single parent also in spark because it's likely cheap and
reduces intermediates)
- if( transitive && _etype == ExecType.CP && _etypeForced !=
ExecType.CP && _etypeForced != ExecType.FED
- && getDataType().isMatrix() && (dt1.isScalar() ||
dt2.isScalar())
- && supportsMatrixScalarOperations()
//scalar operations
- && !(getInput().get(dt1.isScalar()?1:0) instanceof
DataOp) //input is not checkpoint
- &&
getInput().get(dt1.isScalar()?1:0).getParent().size()==1 //unary scalar is
only parent
- &&
!HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar()?1:0)) //single
block triggered exec
- && getInput().get(dt1.isScalar()?1:0).optFindExecType()
== ExecType.SPARK )
- {
- //pull unary scalar operation into spark
+ //spark-specific decision refinement (execute unary scalar w/
spark input and
+ // single parent also in spark because it's likely cheap and
reduces intermediates)
+ if(transitive && _etype == ExecType.CP && _etypeForced !=
ExecType.CP && _etypeForced != ExecType.FED &&
+ getDataType().isMatrix() // output should be a matrix
+ && (dt1.isScalar() || dt2.isScalar()) // one side
should be scalar
+ && supportsMatrixScalarOperations() // scalar operations
+ && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof
DataOp) // input is not checkpoint
+ && getInput().get(dt1.isScalar() ? 1 :
0).getParent().size() == 1 // unary scalar is only parent
+ &&
!HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) //
single block triggered exec
+ && getInput().get(dt1.isScalar() ? 1 :
0).optFindExecType() == ExecType.SPARK) {
+ // pull unary scalar operation into spark
_etype = ExecType.SPARK;
}
-
+
+ if( transitive && _etypeForced != ExecType.SPARK &&
_etypeForced != ExecType.FED && //
+ getDataType().isMatrix() // Output is a matrix
+ && op == OpOp2.DIV // Operation is division
+ && dt1.isMatrix() // Left hand side is a Matrix
+ // right hand side is a scalar or a vector.
+ && (dt2.isScalar() || (dt2.isMatrix() &
getInput().get(1).isVector())) //
+ && memOfInputIsLessThanBudget() //
+ && getInput().get(0).getExecType() != ExecType.SPARK //
Is not already a spark operation
+ && doesNotContainNanAndInf(getInput().get(1)) //
Guaranteed not to densify the operation
+ ) {
+ inplace = true;
+ _etype = ExecType.CP;
+ }
+
//ensure cp exec type for single-node operations
if ( op == OpOp2.SOLVE ) {
if (isGPUEnabled())
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 4404579894..59b957ac5b 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -734,6 +734,18 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
_etype = ExecType.CP;
}
+ // If previous instructions were in spark force aggregating
+ // parameterized operations to be executed in spark
+ if(transitive && _etype == ExecType.CP && _etypeForced !=
ExecType.CP) {
+ switch(_op) {
+ case CONTAINS:
+ if(getTargetHop().optFindExecType() ==
ExecType.SPARK)
+ _etype = ExecType.SPARK;
+ default:
+ // Do not change execution type.
+ }
+ }
+
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java
b/src/main/java/org/apache/sysds/lops/Binary.java
index 6949188564..c1c233e4f2 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -30,14 +30,14 @@ import org.apache.sysds.common.Types.ValueType;
/**
- * Lop to perform binary operation. Both inputs must be matrices or vectors.
- * Example - A = B + C, where B and C are matrices or vectors.
+ * Lop to perform binary operation. Both inputs must be matrices, vectors or
scalars.
+ * Example - A = B + C.
*/
-
public class Binary extends Lop
{
private OpOp2 operation;
private final int _numThreads;
+ private final boolean inplace;
/**
* Constructor to perform a binary operation.
@@ -55,9 +55,14 @@ public class Binary extends Lop
}
public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType
vt, ExecType et, int k) {
+ this(input1, input2, op, dt, vt, et, k, false);
+ }
+
+ public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType
vt, ExecType et, int k, boolean inplace) {
super(Lop.Type.Binary, dt, vt);
init(input1, input2, op, dt, vt, et);
_numThreads = k;
+ this.inplace = inplace;
}
private void init(Lop input1, Lop input2, OpOp2 op, DataType dt,
ValueType vt, ExecType et) {
@@ -107,6 +112,9 @@ public class Binary extends Lop
else if( getExecType() == ExecType.FED )
ret = InstructionUtils.concatOperands(ret,
String.valueOf(_numThreads), _fedOutput.name());
+ if (getExecType() == ExecType.CP && inplace)
+ ret = InstructionUtils.concatOperands(ret, "InPlace");
+
return ret;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index e0fbecaaea..3190bad650 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -38,7 +38,7 @@ public abstract class Instruction
FEDERATED
}
- private static final Log LOG =
LogFactory.getLog(Instruction.class.getName());
+ protected static final Log LOG =
LogFactory.getLog(Instruction.class.getName());
protected final Operator _optr;
protected Instruction(Operator _optr){
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index fddb8301a9..28b8775ebd 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -67,7 +67,7 @@ public abstract class BinaryCPInstruction extends
ComputationCPInstruction {
private static String[] parseBinaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand out) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(instr);
- InstructionUtils.checkNumFields ( parts, 3, 4, 5 );
+ InstructionUtils.checkNumFields ( parts, 3, 4, 5, 6 );
in1.split(parts[1]);
in2.split(parts[2]);
out.split(parts[3]);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 565210b585..20119ceacd 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -23,19 +23,34 @@ import
org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
+import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixMatrixCPInstruction extends BinaryCPInstruction {
+ private final boolean inplace;
+
protected BinaryMatrixMatrixCPInstruction(Operator op, CPOperand in1,
CPOperand in2, CPOperand out, String opcode,
String istr) {
super(CPType.Binary, op, in1, in2, out, opcode, istr);
if(op instanceof BinaryOperator) {
- String[] parts =
InstructionUtils.getInstructionParts(istr);
- ((BinaryOperator)
op).setNumThreads(Integer.parseInt(parts[parts.length - 1]));
+ final String[] parts =
InstructionUtils.getInstructionParts(istr);
+ if(parts.length == 5) {
+ ((BinaryOperator)
op).setNumThreads(Integer.parseInt(parts[parts.length - 1]));
+ inplace = false;
+ }
+ else {
+ ((BinaryOperator)
op).setNumThreads(Integer.parseInt(parts[parts.length - 2]));
+ if(parts[parts.length - 1].equals("InPlace"))
+ inplace = true;
+ else
+ inplace = false;
+ }
}
+ else
+ inplace = false;
}
@Override
@@ -49,25 +64,36 @@ public class BinaryMatrixMatrixCPInstruction extends
BinaryCPInstruction {
MatrixBlock retBlock;
- if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode())
&& !compressedLeft && !compressedRight)
- retBlock =
LibCommonsMath.matrixMatrixOperations(inBlock1, inBlock2, getOpcode());
+ if(inplace && (compressedLeft || compressedRight))
+ LOG.error("Not supporting inplace compressed binary
operations yet");
+
+ if(inplace && !(compressedLeft || compressedRight)) {
+ inBlock1 = LibMatrixBincell.bincellOpInPlace(inBlock1,
inBlock2, (BinaryOperator) _optr);
+ // Release the memory occupied by input matrices
+ ec.releaseMatrixInput(input1.getName(),
input2.getName());
+ // Cleanup the inplace metadata input.
+ ec.removeVariable(input1.getName());
+ retBlock = inBlock1;
+ }
else {
- // Perform computation using input matrices, and
produce the result matrix
- BinaryOperator bop = (BinaryOperator) _optr;
- if(!compressedLeft && compressedRight)
- retBlock = ((CompressedMatrixBlock)
inBlock2).binaryOperationsLeft(bop, inBlock1, new MatrixBlock());
- else
- retBlock = inBlock1.binaryOperations(bop,
inBlock2, new MatrixBlock());
+
if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) &&
!compressedLeft && !compressedRight)
+ retBlock =
LibCommonsMath.matrixMatrixOperations(inBlock1, inBlock2, getOpcode());
+ else {
+ // Perform computation using input matrices,
and produce the result matrix
+ BinaryOperator bop = (BinaryOperator) _optr;
+ if(!compressedLeft && compressedRight)
+ retBlock = ((CompressedMatrixBlock)
inBlock2).binaryOperationsLeft(bop, inBlock1, new MatrixBlock());
+ else
+ retBlock =
inBlock1.binaryOperations(bop, inBlock2, new MatrixBlock());
+ }
+ // Release the memory occupied by input matrices
+ ec.releaseMatrixInput(input1.getName(),
input2.getName());
+ // Ensure right dense/sparse output representation
(guarded by released input memory)
+ if(checkGuardedRepresentationChange(inBlock1, inBlock2,
retBlock))
+ retBlock.examSparsity();
}
- // Release the memory occupied by input matrices
- ec.releaseMatrixInput(input1.getName(), input2.getName());
-
- // Ensure right dense/sparse output representation (guarded by
released input memory)
- if(checkGuardedRepresentationChange(inBlock1, inBlock2,
retBlock))
- retBlock.examSparsity();
-
// Attach result matrix with MatrixObject associated with
output_name
ec.setMatrixOutput(output.getName(), retBlock);
}
-}
\ No newline at end of file
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index 14f83e7de2..33d7bb2da5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -38,6 +38,7 @@ import org.apache.sysds.runtime.data.SparseBlockFactory;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.data.SparseRowVector;
+import org.apache.sysds.runtime.functionobjects.And;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.Divide;
@@ -451,7 +452,7 @@ public class LibMatrixBincell {
|| (clen1 == 1 && rlen2
== 1 ) ); //VV
if( !isValid ) {
- throw new RuntimeException("Block sizes are not matched
for binary " +
+ throw new DMLRuntimeException("Block sizes are not
matched for binary " +
"cell operations: " + rlen1 + "x" +
clen1 + " vs " + rlen2 + "x" + clen2);
}
}
@@ -1589,27 +1590,69 @@ public class LibMatrixBincell {
}
private static void safeBinaryInPlace(MatrixBlock m1ret, MatrixBlock
m2, BinaryOperator op) {
- //early abort on skip and empty
- if( (m1ret.isEmpty() && m2.isEmpty() )
- || (op.fn instanceof Plus && m2.isEmpty())
- || (op.fn instanceof Minus && m2.isEmpty()))
+ // early abort on skip and empty
+ final boolean PoM = op.fn instanceof Plus || op.fn instanceof
Minus;
+ if((m1ret.isEmpty() && m2.isEmpty()) || (PoM && m2.isEmpty())) {
+ final boolean isEquals = op.fn instanceof Equals ||
op.fn instanceof LessThanEquals ||
+ op.fn instanceof GreaterThanEquals;
+
+ if(isEquals)
+ m1ret.reset(m1ret.rlen, m1ret.clen, 1);
return; // skip entire empty block
- //special case: start aggregation
- else if( op.fn instanceof Plus && m1ret.isEmpty() ){
+ }
+ else if(m2.isEmpty() && // empty other side
+ (op.fn instanceof Multiply || (op.fn instanceof And))) {
+ m1ret.reset(m1ret.rlen, m1ret.clen, 0);
+ return;
+ }
+
+ if(m1ret.getNumRows() > 1 && m2.getNumRows() == 1)
+ safeBinaryInPlaceMatrixRowVector(m1ret, m2, op);
+ else
+ safeBinaryInPlaceMatrixMatrix(m1ret, m2, op);
+ }
+
+ private static void safeBinaryInPlaceMatrixRowVector(MatrixBlock m1ret,
MatrixBlock m2, BinaryOperator op) {
+ if(m1ret.sparse) {
+ if(m2.isInSparseFormat() && !op.isRowSafeLeft(m2))
+ throw new DMLRuntimeException("Invalid row
safety of inplace row operation: " + op);
+ else if(m2.isEmpty())
+ safeBinaryInPlaceSparseConst(m1ret, 0.0, op);
+ else if(m2.sparse)
+ throw new NotImplementedException("Not made
sparse vector inplace to sparse " + op);
+ else
+ safeBinaryInPlaceSparseVector(m1ret, m2, op);
+ }
+ else {
+ if(!m1ret.isAllocated()) {
+ LOG.warn("Allocating inplace output block");
+ m1ret.allocateBlock();
+ }
+
+ if(m2.isEmpty())
+ safeBinaryInPlaceDenseConst(m1ret, 0.0, op);
+ else if(m2.sparse)
+ throw new NotImplementedException("Not made
sparse vector inplace to dense " + op);
+ else
+ safeBinaryInPlaceDenseVector(m1ret, m2, op);
+ }
+ }
+
+ private static void safeBinaryInPlaceMatrixMatrix(MatrixBlock m1ret,
MatrixBlock m2, BinaryOperator op) {
+ if(op.fn instanceof Plus && m1ret.isEmpty()) {
m1ret.copy(m2);
- return;
+ return;
}
-
if(m1ret.sparse && m2.sparse)
safeBinaryInPlaceSparse(m1ret, m2, op);
else if(!m1ret.sparse && !m2.sparse)
safeBinaryInPlaceDense(m1ret, m2, op);
else if(m2.sparse && (op.fn instanceof Plus || op.fn instanceof
Minus))
safeBinaryInPlaceDenseSparseAdd(m1ret, m2, op);
- else //GENERIC
+ else
safeBinaryInPlaceGeneric(m1ret, m2, op);
}
-
+
private static void safeBinaryInPlaceSparse(MatrixBlock m1ret,
MatrixBlock m2, BinaryOperator op) {
//allocation and preparation (note: for correctness and
performance, this
//implementation requires the lhs in MCSR and hence we
explicitly convert)
@@ -1625,6 +1668,9 @@ public class LibMatrixBincell {
final int rlen = m1ret.rlen;
final int clen = m1ret.clen;
+ final boolean compact = (op.fn instanceof Multiply || op.fn
instanceof And );
+ final boolean mcsr = c instanceof SparseBlockMCSR;
+
if( c!=null && b!=null ) {
for(int r=0; r<rlen; r++) {
if(c.isEmpty(r) && b.isEmpty(r))
@@ -1645,6 +1691,8 @@ public class LibMatrixBincell {
mergeForSparseBinary(op, old.values(),
old.indexes(), 0,
old.size(), b.values(r),
b.indexes(r), b.pos(r), b.size(r), r, m1ret);
}
+ if(compact && mcsr && !c.isEmpty(r))
+ c.get(r).compact();
}
}
else if( c == null ) { //lhs empty
@@ -1660,31 +1708,81 @@ public class LibMatrixBincell {
if( c.isEmpty(r) ) continue;
zeroRightForSparseBinary(op, r, m1ret);
}
+
}
m1ret.recomputeNonZeros();
}
+ private static void safeBinaryInPlaceSparseConst(MatrixBlock m1ret,
double m2, BinaryOperator op) {
+ if(m1ret.isEmpty()) // early termination... it is empty and
safe... just stop.
+ return;
+ final SparseBlock sb = m1ret.getSparseBlock();
+ final int rlen = m1ret.rlen;
+ for(int r = 0; r < rlen; r++) {
+ if(sb.isEmpty(r))
+ continue;
+ final int apos = sb.pos(r);
+ final int alen = sb.size(r) + apos;
+ final double[] avals = sb.values(r);
+ for(int k = apos; k < alen; k++)
+ avals[k] = op.fn.execute(avals[k], m2);
+ }
+ }
+
+ private static void safeBinaryInPlaceSparseVector(MatrixBlock m1ret,
MatrixBlock m2, BinaryOperator op) {
+
+ if(m1ret.isEmpty()) // early termination... it is empty and
safe... just stop.
+ return;
+ final SparseBlock sb = m1ret.getSparseBlock();
+ final double[] b = m2.getDenseBlockValues();
+ final int rlen = m1ret.rlen;
+
+ final boolean compact = (op.fn instanceof Multiply || op.fn
instanceof And) //
+ && op.isIntroducingZerosRight(m2);
+ final boolean mcsr = sb instanceof SparseBlockMCSR;
+ for(int r = 0; r < rlen; r++) {
+ if(sb.isEmpty(r))
+ continue;
+ final int apos = sb.pos(r);
+ final int alen = sb.size(r) + apos;
+ final double[] avals = sb.values(r);
+ final int[] aix = sb.indexes(r);
+ for(int k = apos; k < alen; k++)
+ avals[k] = op.fn.execute(avals[k], b[aix[k]]);
+
+ if(compact && mcsr) {
+ SparseRow sr = sb.get(r);
+ if(sr instanceof SparseRowVector)
+ ((SparseRowVector)
sr).setSize(avals.length);
+ sr.compact();
+ }
+ }
+ if(compact && !mcsr) {
+ ((SparseBlockCSR) sb).compact();
+ }
+ }
+
private static void safeBinaryInPlaceDense(MatrixBlock m1ret,
MatrixBlock m2, BinaryOperator op) {
- //prepare outputs
+ // prepare outputs
m1ret.allocateDenseBlock();
DenseBlock a = m1ret.getDenseBlock();
DenseBlock b = m2.getDenseBlock();
final int rlen = m1ret.rlen;
final int clen = m1ret.clen;
-
+
long lnnz = 0;
- if( m2.isEmptyBlock(false) ) {
- for(int r=0; r<rlen; r++) {
+ if(m2.isEmptyBlock(false)) {
+ for(int r = 0; r < rlen; r++) {
double[] avals = a.values(r);
- for(int c=0, ix=a.pos(r); c<clen; c++, ix++) {
+ for(int c = 0, ix = a.pos(r); c < clen; c++,
ix++) {
double tmp = op.fn.execute(avals[ix],
0);
- lnnz += (avals[ix] = tmp) != 0 ? 1: 0;
+ lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
}
}
}
- else if( op.fn instanceof Plus ) {
- for(int r=0; r<rlen; r++) {
+ else if(op.fn instanceof Plus) {
+ for(int r = 0; r < rlen; r++) {
int aix = a.pos(r), bix = b.pos(r);
double[] avals = a.values(r), bvals =
b.values(r);
LibMatrixMult.vectAdd(bvals, avals, bix, aix,
clen);
@@ -1692,15 +1790,53 @@ public class LibMatrixBincell {
}
}
else {
- for(int r=0; r<rlen; r++) {
+ for(int r = 0; r < rlen; r++) {
double[] avals = a.values(r), bvals =
b.values(r);
- for(int c=0, ix=a.pos(r); c<clen; c++, ix++) {
+ for(int c = 0, ix = a.pos(r); c < clen; c++,
ix++) {
double tmp = op.fn.execute(avals[ix],
bvals[ix]);
lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
}
}
}
-
+
+ m1ret.setNonZeros(lnnz);
+ }
+
+ private static void safeBinaryInPlaceDenseConst(MatrixBlock m1ret,
double m2, BinaryOperator op) {
+ // prepare outputs
+ m1ret.allocateDenseBlock();
+ DenseBlock a = m1ret.getDenseBlock();
+ final int rlen = m1ret.rlen;
+ final int clen = m1ret.clen;
+
+ long lnnz = 0;
+ for(int r = 0; r < rlen; r++) {
+ double[] avals = a.values(r);
+ for(int c = 0, ix = a.pos(r); c < clen; c++, ix++) {
+ double tmp = op.fn.execute(avals[ix], m2);
+ lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
+ }
+ }
+
+ m1ret.setNonZeros(lnnz);
+ }
+
+ private static void safeBinaryInPlaceDenseVector(MatrixBlock m1ret,
MatrixBlock m2, BinaryOperator op) {
+ // prepare outputs
+ m1ret.allocateDenseBlock();
+ DenseBlock a = m1ret.getDenseBlock();
+ double[] b = m2.getDenseBlockValues();
+ final int rlen = m1ret.rlen;
+ final int clen = m1ret.clen;
+
+ long lnnz = 0;
+ for(int r = 0; r < rlen; r++) {
+ double[] avals = a.values(r);
+ for(int c = 0, ix = a.pos(r); c < clen; c++, ix++) {
+ double tmp = op.fn.execute(avals[ix], b[ix %
clen]);
+ lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
+ }
+ }
m1ret.setNonZeros(lnnz);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index f1b98cdb8b..992ac5bee9 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -21,6 +21,7 @@
package org.apache.sysds.runtime.matrix.operators;
import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.And;
import org.apache.sysds.runtime.functionobjects.BitwAnd;
import org.apache.sysds.runtime.functionobjects.BitwOr;
@@ -29,6 +30,7 @@ import org.apache.sysds.runtime.functionobjects.BitwShiftR;
import org.apache.sysds.runtime.functionobjects.BitwXor;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Equals;
import org.apache.sysds.runtime.functionobjects.GreaterThan;
@@ -50,12 +52,22 @@ import org.apache.sysds.runtime.functionobjects.Power;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.functionobjects.Xor;
+/**
+ * BinaryOperator class for operations that have two inputs.
+ *
+ * For instance
+ *
+ * <pre>
+ *BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject());
+ *double r = op.execute(5.0, 8.2)
+ * </pre>
+ */
public class BinaryOperator extends MultiThreadedOperator {
private static final long serialVersionUID = -2547950181558989209L;
public final ValueFunction fn;
public final boolean commutative;
-
+
public BinaryOperator(ValueFunction p) {
this(p, 1);
}
@@ -120,6 +132,12 @@ public class BinaryOperator extends MultiThreadedOperator {
return commutative;
}
+ /**
+ * Check if the operation returns zeros if the input zero.
+ *
+ * @param row The values to check
+ * @return If the output is always zero if other value is zero
+ */
public boolean isRowSafeLeft(double[] row){
for(double v : row)
if(0 != fn.execute(v, 0))
@@ -127,6 +145,78 @@ public class BinaryOperator extends MultiThreadedOperator {
return true;
}
+ /**
+ * Check if the operation returns zeros if the input zero.
+ *
+ * @param row The values to check
+ * @return If the output is always zero if other value is zero
+ */
+ public boolean isRowSafeLeft(MatrixBlock row) {
+ if(row.isEmpty())
+ return 0 == fn.execute(0.0, 0.0);
+ else if(row.isInSparseFormat()) {
+ if(0 != fn.execute(0.0, 0.0))
+ return false;
+ SparseBlock sb = row.getSparseBlock();
+ if(sb.isEmpty(0))
+ return true;
+ return isRowSafeLeft(sb.values(0));
+ }
+ else
+ return isRowSafeLeft(row.getDenseBlockValues());
+ }
+
+ /**
+ * Check if the operation returns zeros if the input is contained in
row.
+ *
+ * @param row The values to check
+ * @return If the output contains zeros
+ */
+ public boolean isIntroducingZerosLeft(MatrixBlock row) {
+ if(row.isEmpty())
+ return introduceZeroLeft(0.0);
+ else if(row.isInSparseFormat()) {
+ if(introduceZeroLeft(0.0))
+ return true;
+ SparseBlock sb = row.getSparseBlock();
+ if(sb.isEmpty(0))
+ return false;
+ return isIntroducingZerosLeft(sb.values(0));
+ }
+ else
+ return
isIntroducingZerosLeft(row.getDenseBlockValues());
+ }
+
+ /**
+ * Check if the operation returns zeros if the input is contained in
row.
+ *
+ * @param row The values to check
+ * @return If the output contains zeros
+ */
+ public boolean isIntroducingZerosLeft(double[] row) {
+ for(double v : row)
+ if(introduceZeroLeft(v))
+ return true;
+ return false;
+ }
+
+ /**
+ * Check if zero is returned at arbitrary input. The verification is
done via two different values that hopefully do
+ * not return 0 in both instances unless the operation really have a
tendency to return zero.
+ *
+ * @param v The value to check if returns zero
+ * @return if the evaluation return zero
+ */
+ private boolean introduceZeroLeft(double v) {
+ return 0 == fn.execute(v, 11.42) && 0 == fn.execute(v, -11.22);
+ }
+
+ /**
+ * Check if the operation returns zeros if the input zero.
+ *
+ * @param row The values to check
+ * @return If the output is always zero if other value is zero
+ */
public boolean isRowSafeRight(double[] row){
for(double v : row)
if(0 != fn.execute(0, v))
@@ -134,6 +224,73 @@ public class BinaryOperator extends MultiThreadedOperator {
return true;
}
+ /**
+ * Check if the operation returns zero if the input is zero.
+ *
+ * @param row The values to check
+ * @return True if the output is always zero if the other value is zero
+ */
+ public boolean isRowSafeRight(MatrixBlock row) {
+ if(row.isEmpty())
+ return 0 == fn.execute(0.0, 0.0);
+ else if(row.isInSparseFormat()) {
+ if(0 != fn.execute(0.0, 0.0))
+ return false;
+ SparseBlock sb = row.getSparseBlock();
+ if(sb.isEmpty(0))
+ return true;
+ return isRowSafeRight(sb.values(0));
+ }
+ else
+ return isRowSafeRight(row.getDenseBlockValues());
+ }
+
+ /**
+ * Check if the operation returns zeros if the input is contained in
row.
+ *
+ * @param row The values to check
+ * @return If the output contains zeros
+ */
+ public boolean isIntroducingZerosRight(MatrixBlock row){
+ if(row.isEmpty())
+ return introduceZeroRight(0.0);
+ else if(row.isInSparseFormat()){
+ if (introduceZeroRight(0.0))
+ return true;
+ SparseBlock sb = row.getSparseBlock();
+ if(sb.isEmpty(0))
+ return false;
+ return isIntroducingZerosRight(sb.values(0));
+ }
+ else
+ return
isIntroducingZerosRight(row.getDenseBlockValues());
+ }
+
+ /**
+ * Check if the operation returns zeros if the input is contained in
row.
+ *
+ * @param row The values to check
+ * @return If the output contains zeros
+ */
+ public boolean isIntroducingZerosRight(double[] row){
+ for(double v : row)
+ if( introduceZeroRight(v))
+ return true;
+
+ return false;
+ }
+
+ /**
+ * Check if zero is returned for arbitrary input. The verification is
done via two different values that hopefully do
+ * not return 0 in both instances unless the operation really has a
tendency to return zero.
+ *
+ * @param v The value to check for returning zero
+ * @return True if both evaluations return zero
+ */
+ private boolean introduceZeroRight(double v) {
+ return 0 == fn.execute(11.42, v) && 0 == fn.execute(-11.22, v);
+ }
+
@Override
public String toString() {
return "BinaryOperator("+fn.getClass().getSimpleName()+")";
diff --git
a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
index 68fcd1be44..bf8258efc6 100644
---
a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
+++
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
@@ -19,6 +19,14 @@
package org.apache.sysds.test.component.matrix;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.LessThan;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Or;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -30,36 +38,255 @@ public class BinaryOperationInPlaceTest {
public void testPlus() {
MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 1);
MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 2);
- execute(m1,m2);
+ executePlus(m1, m2);
}
@Test
public void testPlus_emptyInplace() {
- MatrixBlock m1 = new MatrixBlock(10,10,false);
+ MatrixBlock m1 = new MatrixBlock(10, 10, false);
MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 2);
- execute(m1,m2);
+ executePlus(m1, m2);
}
- @Test
- public void testPlus_emptyOther(){
+ @Test
+ public void testPlus_emptyOther() {
MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 1);
- MatrixBlock m2 = new MatrixBlock(10,10,false);
- execute(m1,m2);
+ MatrixBlock m2 = new MatrixBlock(10, 10, false);
+ executePlus(m1, m2);
}
- @Test
+ @Test
public void testPlus_emptyInplace_butAllocatedDense() {
- MatrixBlock m1 = new MatrixBlock(10,10,false);
+ MatrixBlock m1 = new MatrixBlock(10, 10, false);
m1.allocateDenseBlock();
MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 2);
- execute(m1,m2);
+ executePlus(m1, m2);
+ }
+
+ @Test
+ public void testDivide() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 2);
+ executeDivide(m1, m2);
+ }
+
+ @Test
+ public void testDivide_matrixVector() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 10, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 10, 0,
10, 1.0, 2);
+ executeDivide(m1, m2);
+ }
+
+ @Test(expected = DMLRuntimeException.class)
+ public void testDivide_Invalid_1() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 10, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 11, 0,
10, 1.0, 2);
+ executeDivide(m1, m2);
+ }
+
+ @Test(expected = DMLRuntimeException.class)
+ public void testDivide_Invalid_2() {
+ try {
+
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100,
10, 0, 10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1,
9, 0, 10, 1.0, 2);
+ executeDivide(m1, m2);
+ }
+ catch(DMLRuntimeException e) {
+ throw e;
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testDivide_matrixVector_emptyVector() {
+ try {
+
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100,
10, 0, 10, 1.0, 1);
+ MatrixBlock m2 = new MatrixBlock(1, 10, 0.0);
+ executeDivide(m1, m2);
+ }
+ catch(DMLRuntimeException e) {
+ throw e;
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
+ @Test
+ public void testDivide_matrixVector_sparseBoth() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 1000,
0, 10, 0.2, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 1000, 0,
10, 0.2, 2);
+ m1.examSparsity();
+ m2.examSparsity();
+ executeDivide(m1, m2);
+ }
+
+ @Test
+ public void testDivide_matrixVector_oneEmpty() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 0.2, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 10, 0,
10, 0.0, 2);
+ m1.examSparsity();
+ m2.examSparsity();
+ executeDivide(m1, m2);
+ }
+
+ @Test
+ public void testOr_matrixMatrix_denseDense() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 2);
+ executeOr(m1, m2);
+ }
+
+ @Test
+ public void testOr_matrixMatrix_denseSparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 2);
+ executeOr(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_denseDense() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0,
10, 1.0, 2);
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_denseSparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 2);
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_denseEmpty() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.0, 2);
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_EmptyDense() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 2);
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_EmptySparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, .1, 2);
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_EmptyEmpty() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.0, 2);
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testPlus_matrixMatrix_DenseSparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 2);
+ executePlus(m1, m2);
+ }
+
+ @Test
+ public void testMinus_matrixMatrix_DenseSparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 2);
+ executeMinus(m1, m2);
+ }
+
+ @Test
+ public void testMinus_matrixMatrix_DenseDense() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 2);
+ executeMinus(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_DenseDense() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 2);
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testLT_matrixMatrix_DenseSparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 1); // dense left operand (modified in place)
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 2); // sparse right operand (sparsity 0.1), as the test name says
+ executeLT(m1, m2);
 }
- private void execute(MatrixBlock m1, MatrixBlock m2){
+ @Test
+ public void testLT_matrixMatrix_SparseDense() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 2);
+ assertTrue(m1.isInSparseFormat());
+ executeLT(m1, m2);
+ }
+
+ @Test
+ public void testPlus_matrixMatrix_SparseDense() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 1.0, 2);
+ assertTrue(m1.isInSparseFormat());
+ executePlus(m1, m2);
+ }
+
+ @Test
+ public void testPlus_matrixMatrix_SparseSparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 2);
+ assertTrue(m1.isInSparseFormat());
+ executePlus(m1, m2);
+ }
+
+ @Test
+ public void testDiv_matrixMatrix_SparseSparse() {
+ MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 1);
+ MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0,
10, 0.1, 2);
+ assertTrue(m1.isInSparseFormat());
+ executeDivide(m1, m2);
+ }
+
+ private void executeDivide(MatrixBlock m1, MatrixBlock m2) {
+ BinaryOperator op = new
BinaryOperator(Divide.getDivideFnObject());
+ testInplace(m1, m2, op);
+ }
+
+ private void executePlus(MatrixBlock m1, MatrixBlock m2) {
BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject());
+ testInplace(m1, m2, op);
+ }
+
+ private void executeMinus(MatrixBlock m1, MatrixBlock m2) {
+ BinaryOperator op = new
BinaryOperator(Minus.getMinusFnObject());
+ testInplace(m1, m2, op);
+ }
+
+ private void executeOr(MatrixBlock m1, MatrixBlock m2) {
+ BinaryOperator op = new BinaryOperator(Or.getOrFnObject());
+ testInplace(m1, m2, op);
+ }
+
+ private void executeLT(MatrixBlock m1, MatrixBlock m2) {
+ BinaryOperator op = new
BinaryOperator(LessThan.getLessThanFnObject());
+ testInplace(m1, m2, op);
+ }
+
+ private void testInplace(MatrixBlock m1, MatrixBlock m2, BinaryOperator
op) {
MatrixBlock ret1 = m1.binaryOperations(op, m2);
m1.binaryOperationsInPlace(op, m2);
-
TestUtils.compareMatricesBitAvgDistance(ret1, m1, 0, 0, "Result
is incorrect for inplace op");
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java
new file mode 100644
index 0000000000..05f29afcc5
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.And;
+import org.apache.sysds.runtime.functionobjects.BitwAnd;
+import org.apache.sysds.runtime.functionobjects.BitwOr;
+import org.apache.sysds.runtime.functionobjects.BitwShiftL;
+import org.apache.sysds.runtime.functionobjects.BitwShiftR;
+import org.apache.sysds.runtime.functionobjects.BitwXor;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Equals;
+import org.apache.sysds.runtime.functionobjects.GreaterThan;
+import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
+import org.apache.sysds.runtime.functionobjects.IntegerDivide;
+import org.apache.sysds.runtime.functionobjects.LessThan;
+import org.apache.sysds.runtime.functionobjects.LessThanEquals;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.MinusNz;
+import org.apache.sysds.runtime.functionobjects.Modulus;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.NotEquals;
+import org.apache.sysds.runtime.functionobjects.Or;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.Power;
+import org.apache.sysds.runtime.functionobjects.Xor;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(value = Parameterized.class)
+public class BinaryOperationInPlaceTestParameterized {
+ protected static final Log LOG =
LogFactory.getLog(BinaryOperationInPlaceTestParameterized.class.getName());
+
+ private final MatrixBlock left;
+ private final MatrixBlock right;
+ private final BinaryOperator op;
+
+ public BinaryOperationInPlaceTestParameterized(MatrixBlock left,
MatrixBlock right, BinaryOperator op) {
+ this.left = new MatrixBlock();
+ this.right = right;
+ this.op = op;
+ this.left.copy(left);
+ }
+
+ @Parameters
+ public static Collection<Object[]> data() {
+ List<Object[]> tests = new ArrayList<>();
+
+ try {
+ double[] sparsities = new double[] {0.0, 0.001, 0.1,
0.5, 1.0};
+
+ BinaryOperator[] operators = new BinaryOperator[] {//
+ new BinaryOperator(Plus.getPlusFnObject()), //
+ new BinaryOperator(Minus.getMinusFnObject()), //
+ new BinaryOperator(Or.getOrFnObject()), //
+ new
BinaryOperator(LessThan.getLessThanFnObject()), //
+ new
BinaryOperator(LessThanEquals.getLessThanEqualsFnObject()), //
+ new
BinaryOperator(GreaterThan.getGreaterThanFnObject()), //
+ new
BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject()), //
+ new
BinaryOperator(Multiply.getMultiplyFnObject()), //
+ new BinaryOperator(Modulus.getFnObject()), //
+ new
BinaryOperator(IntegerDivide.getFnObject()), //
+ new BinaryOperator(Equals.getEqualsFnObject()),
//
+ new
BinaryOperator(NotEquals.getNotEqualsFnObject()), //
+ new BinaryOperator(And.getAndFnObject()), //
+ new BinaryOperator(Xor.getXorFnObject()), //
+ new
BinaryOperator(BitwAnd.getBitwAndFnObject()), //
+ new BinaryOperator(BitwOr.getBitwOrFnObject()),
//
+ new
BinaryOperator(BitwXor.getBitwXorFnObject()), //
+ new
BinaryOperator(BitwShiftL.getBitwShiftLFnObject()), //
+ new
BinaryOperator(BitwShiftR.getBitwShiftRFnObject()), //
+ new BinaryOperator(Power.getPowerFnObject()), //
+ new
BinaryOperator(MinusNz.getMinusNzFnObject()), //
+ // Builtin
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MIN)), //
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MAX)), //
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.LOG)), //
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.LOG_NZ)), //
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MAXINDEX)), //
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MININDEX)), //
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)), //
+ new
BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.CUMMIN)),//
+ };
+
+ for(double rightSparsity : sparsities) {
+ MatrixBlock right =
TestUtils.generateTestMatrixBlock(100, 100, 0, 10, rightSparsity, 2);
+ MatrixBlock rightV =
TestUtils.generateTestMatrixBlock(1, 100, 0, 10, rightSparsity, 2);
+ for(double leftSparsity : sparsities) {
+ MatrixBlock left =
TestUtils.generateTestMatrixBlock(100, 100, 0, 10, leftSparsity, 2);
+ for(BinaryOperator op : operators) {
+ tests.add(new Object[] {left,
right, op});
+ tests.add(new Object[] {left,
rightV, op});
+ }
+ }
+ }
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail("failed constructing tests");
+ }
+
+ return tests;
+ }
+
+ @Test
+ public void testInplace() {
+ try {
+ final int lrb = left.getNumRows();
+ final int lcb = left.getNumColumns();
+ final int rrb = right.getNumRows();
+ final int rcb = right.getNumColumns();
+
+ final double lspb = left.getSparsity();
+ final double rspb = right.getSparsity();
+
+ final MatrixBlock ret1 = left.binaryOperations(op,
right);
+
+ assertEquals(lrb, left.getNumRows());
+ assertEquals(lcb, left.getNumColumns());
+ assertEquals(rrb, right.getNumRows());
+ assertEquals(rcb, right.getNumColumns());
+
+ left.binaryOperationsInPlace(op, right);
+
+ assertEquals(lrb, left.getNumRows());
+ assertEquals(lcb, left.getNumColumns());
+ assertEquals(rrb, right.getNumRows());
+ assertEquals(rcb, right.getNumColumns());
+ TestUtils.compareMatricesBitAvgDistance(ret1, left, 0,
0, "Result is incorrect for inplace \n" + op + " "
+ + lspb + " " + rspb + " (" + lrb + "," + lcb +
")" + " (" + rrb + "," + rcb + ")");
+ }
+ catch(DMLRuntimeException e) {
+ if(e.getMessage().contains("Invalid row safety of
inplace row operation: ")) {
+ if(op.fn instanceof Divide || //
+ op.fn instanceof Plus || //
+ op.fn instanceof Minus || //
+ op.fn instanceof Or)
+ return;
+ }
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ catch(NotImplementedException e) {
+ // TODO fix the not implemented instances.
+ if(e.getMessage().contains("Not made sparse vector
inplace"))
+ return;
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+}