This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c17fcb2ceb [SYSTEMDS-3762] Cumulative Row Aggregates and Rewrites
c17fcb2ceb is described below
commit c17fcb2ceb76df2816fbd2c140be4ca837fa040d
Author: Vi Vuong <[email protected]>
AuthorDate: Fri Jul 25 15:02:07 2025 +0200
[SYSTEMDS-3762] Cumulative Row Aggregates and Rewrites
Closes #2279.
Co-authored-by: Vien Thanh Thai <[email protected]>
Co-authored-by: Sandra Shi <[email protected]>
---
docs/site/dml-language-reference.md | 3 +-
.../java/org/apache/sysds/common/Builtins.java | 1 +
src/main/java/org/apache/sysds/common/Opcodes.java | 3 +
src/main/java/org/apache/sysds/common/Types.java | 4 +-
.../RewriteAlgebraicSimplificationStatic.java | 23 ++
.../sysds/parser/BuiltinFunctionExpression.java | 1 +
.../org/apache/sysds/parser/DMLTranslator.java | 1 +
.../sysds/runtime/functionobjects/Builtin.java | 3 +-
.../runtime/instructions/InstructionUtils.java | 2 +
.../spark/CumulativeOffsetSPInstruction.java | 2 +
.../spark/UnaryMatrixSPInstruction.java | 195 +++++++++++++++
.../sysds/runtime/matrix/data/LibMatrixAgg.java | 107 ++++++++-
.../RewriteSimplifyTransposedCumsumTest.java | 261 +++++++++++++++++++++
.../functions/unary/matrix/FullRowcumsumTest.java | 242 +++++++++++++++++++
.../rewrite/RewriteSimplifyTransposedCumsum.R | 36 +++
.../rewrite/RewriteSimplifyTransposedCumsum.dml | 25 ++
.../scripts/functions/unary/matrix/Rowcumsum.R | 36 +++
.../scripts/functions/unary/matrix/Rowcumsum.dml | 26 ++
18 files changed, 967 insertions(+), 4 deletions(-)
diff --git a/docs/site/dml-language-reference.md
b/docs/site/dml-language-reference.md
index abafb74a5a..264b3c6a2b 100644
--- a/docs/site/dml-language-reference.md
+++ b/docs/site/dml-language-reference.md
@@ -702,7 +702,8 @@ quantile () | The p-quantile for a random variable X is the
value x such that Pr
quantile () | Returns a column matrix with list of all quantiles requested in
P. | Input: (X <(n x 1) matrix>, [W <(n x 1) matrix>),] P <(q x
1) matrix>) <br/> Output: matrix | quantile(X, P) <br/> quantile(X, W, P)
median() | Computes the median in a given column matrix of values | Input: (X
<(n x 1) matrix>, [W <(n x 1) matrix>),]) <br/> Output:
<scalar> | median(X) <br/> median(X,W)
rowSums() <br/> rowMeans() <br/> rowVars() <br/> rowSds() <br/> rowMaxs()
<br/> rowMins() | Row-wise computations -- for each row, compute the
sum/mean/variance/stdDev/max/min of cell value | Input: matrix <br/> Output: (n
x 1) matrix | rowSums(X) <br/> rowMeans(X) <br/> rowVars(X) <br/> rowSds(X)
<br/> rowMaxs(X) <br/> rowMins(X)
-cumsum() | Column prefix-sum (For row-prefix sum, use cumsum(t(X)) | Input:
matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6",
rows=3, cols=2) <br/> B = cumsum(A) <br/> The output matrix B = [[1, 2], [4,
6], [9, 12]]
+cumsum() | Column prefix-sum | Input: matrix <br/> Output: matrix of the same
dimensions | A = matrix("1 2 3 4 5 6", rows=3, cols=2) <br/> B = cumsum(A)
<br/> The output matrix B = [[1, 2], [4, 6], [9, 12]]
+rowcumsum() | Row prefix-sum | Input: matrix <br/> Output: matrix of the same
dimensions | A = matrix("1 2 3 4 5 6", rows=2, cols=3) <br/> B = rowcumsum(A)
<br/> The output matrix B = [[1, 3, 6], [4, 9, 15]]
cumprod() | Column prefix-prod (For row-prefix prod, use cumprod(t(X)) |
Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4
5 6", rows=3, cols=2) <br/> B = cumprod(A) <br/> The output matrix B = [[1, 2],
[3, 8], [15, 48]]
cummin() | Column prefix-min (For row-prefix min, use cummin(t(X)) | Input:
matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2",
rows=3, cols=2) <br/> B = cummin(A) <br/> The output matrix B = [[3, 4], [1,
4], [1, 2]]
cummax() | Column prefix-max (For row-prefix max, use cummax(t(X)) | Input:
matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2",
rows=3, cols=2) <br/> B = cummax(A) <br/> The output matrix B = [[3, 4], [3,
6], [5, 6]]
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 423679d038..fe75aec6a0 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -291,6 +291,7 @@ public enum Builtins {
ROLL("roll", false),
ROUND("round", false),
ROW_COUNT_DISTINCT("rowCountDistinct",false),
+ ROWCUMSUM("rowcumsum", false),
ROWINDEXMAX("rowIndexMax", false),
ROWINDEXMIN("rowIndexMin", false),
ROWMAX("rowMaxs", false),
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java
b/src/main/java/org/apache/sysds/common/Opcodes.java
index fd5c6bfd12..64a6c7dd27 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -34,6 +34,7 @@ public enum Opcodes {
UAKP("uak+", InstructionType.AggregateUnary),
UARKP("uark+", InstructionType.AggregateUnary),
UACKP("uack+", InstructionType.AggregateUnary),
+ UARCKP("uarck+", InstructionType.AggregateUnary),
UASQKP("uasqk+", InstructionType.AggregateUnary),
UARSQKP("uarsqk+", InstructionType.AggregateUnary),
UACSQKP("uacsqk+", InstructionType.AggregateUnary),
@@ -151,6 +152,7 @@ public enum Opcodes {
CEIL("ceil", InstructionType.Unary),
FLOOR("floor", InstructionType.Unary),
UCUMKP("ucumk+", InstructionType.Unary),
+ UROWCUMKP("urowcumk+", InstructionType.Unary),
UCUMM("ucum*", InstructionType.Unary),
UCUMKPM("ucumk+*", InstructionType.Unary),
UCUMMIN("ucummin", InstructionType.Unary),
@@ -383,6 +385,7 @@ public enum Opcodes {
UCUMACMIN("ucumacmin", InstructionType.CumsumAggregate),
UCUMACMAX("ucumacmax", InstructionType.CumsumAggregate),
BCUMOFFKP("bcumoffk+", InstructionType.CumsumOffset),
+ BROWCUMOFFKP("browcumoffk+", InstructionType.CumsumOffset),
BCUMOFFM("bcumoff*", InstructionType.CumsumOffset),
BCUMOFFPM("bcumoff+*", InstructionType.CumsumOffset),
BCUMOFFMIN("bcumoffmin", InstructionType.CumsumOffset),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index fc6e1610ca..cc7f6eb377 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -547,7 +547,7 @@ public interface Types {
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
CUMSUMPROD, DET, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP,
FLOOR, INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
- MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT,
STOP, _EVICT,
+ MEDIAN, PREFETCH, PRINT, ROUND, ROWCUMSUM, SIN, SINH, SIGN,
SOFTMAX, SQRT, STOP, _EVICT,
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
//fused ML-specific operators for performance
SPROP, //sample proportion: P * (1 - P)
@@ -591,6 +591,7 @@ public interface Types {
case MULT2: return
Opcodes.MULT2.toString();
case NOT: return
Opcodes.NOT.toString();
case POW2: return
Opcodes.POW2.toString();
+ case ROWCUMSUM: return
Opcodes.UROWCUMKP.toString();
case TYPEOF: return
Opcodes.TYPEOF.toString();
default: return
name().toLowerCase();
}
@@ -610,6 +611,7 @@ public interface Types {
case "ucummin": return CUMMIN;
case "ucum*": return CUMPROD;
case "ucumk+": return CUMSUM;
+ case "urowcumk+": return ROWCUMSUM;
case "ucumk+*": return CUMSUMPROD;
case "detectSchema": return DETECTSCHEMA;
case "*2": return MULT2;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index fd4445cf44..2ae1550257 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -205,6 +205,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyNotOverComparisons(hop, hi, i);
//e.g., !(A>B) -> (A<=B)
hi = simplifyMatrixScalarPMOperation(hop, hi, i);
//e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
//hi = removeUnecessaryPPred(hop, hi, i);
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
+ hi = simplifyTransposedCumsum(hop, hi, i);
//e.g., t(cumsum(t(X))) -> rowcumsum(X)
//process childs recursively after rewrites (to
investigate pattern newly created by rewrites)
if( !descendFirst )
@@ -214,6 +215,28 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hop.setVisited();
}
+ private static Hop simplifyTransposedCumsum( Hop parent, Hop hi, int
pos )
+ {
+ //e.g., t(cumsum(t(X))) -> rowcumsum(X); hi is the candidate outer transpose at input position pos of parent, and the rewrite is only applied if the cumsum hop has exactly one parent
+ if( HopRewriteUtils.isTransposeOperation(hi)
+ && hi.getInput(0) instanceof UnaryOp
+ && ((UnaryOp)hi.getInput(0)).getOp() ==
OpOp1.CUMSUM
+ && hi.getInput(0).getParent().size() == 1
+ &&
HopRewriteUtils.isTransposeOperation(hi.getInput(0).getInput(0), 1)) //inner
transpose with single consumer
+ {
+ UnaryOp cumsum=(UnaryOp)hi.getInput(0);
+ Hop innerMatrix = cumsum.getInput(0).getInput(0);
+ //innerMatrix is X (the input of the inner transpose); create rowcumsum(X) and hand it to parent in place of hi at position pos
+ UnaryOp rowcumsumOp =
HopRewriteUtils.createUnary(innerMatrix, OpOp1.ROWCUMSUM);
+ HopRewriteUtils.replaceChildReference(parent,hi,
rowcumsumOp, pos);
+ //from here on hi refers to the new hop, so the begin line in the debug message is read from rowcumsumOp
+ hi = rowcumsumOp;
+ LOG.debug("Applied simplifyTransposedCumsum (line
"+hi.getBeginLine()+").");
+ }
+ //returns the new rowcumsum hop if the pattern matched, otherwise the unmodified hi
+ return hi;
+ }
+
private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int
pos) {
if (!(hi instanceof BinaryOp))
return hi;
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index ae582b052b..540b522a8b 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1034,6 +1034,7 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
break;
case CUMSUM:
+ case ROWCUMSUM:
case CUMPROD:
case CUMSUMPROD:
case CUMMIN:
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index c4f7f672ab..6827bcc4bf 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2616,6 +2616,7 @@ public class DMLTranslator
case CEIL:
case FLOOR:
case CUMSUM:
+ case ROWCUMSUM:
case CUMPROD:
case CUMSUMPROD:
case CUMMIN:
diff --git
a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 6b196489ea..8e9aef9466 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -49,7 +49,7 @@ public class Builtin extends ValueFunction
public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH,
ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL,
LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
- STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD,
INVERSE, SPROP, SIGMOID, EVAL, LIST,
+ STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX,
CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF,
DROP_INVALID_TYPE,
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
@@ -95,6 +95,7 @@ public class Builtin extends ValueFunction
String2BuiltinCode.put( "ceil" , BuiltinCode.CEIL);
String2BuiltinCode.put( "floor" , BuiltinCode.FLOOR);
String2BuiltinCode.put( "ucumk+" , BuiltinCode.CUMSUM);
+ String2BuiltinCode.put( "urowcumk+" , BuiltinCode.ROWCUMSUM);
String2BuiltinCode.put( "ucum*" , BuiltinCode.CUMPROD);
String2BuiltinCode.put( "ucumk+*", BuiltinCode.CUMSUMPROD);
String2BuiltinCode.put( "ucummin", BuiltinCode.CUMMIN);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index e244e9cd27..da3de02419 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -557,6 +557,8 @@ public class InstructionUtils {
Builtin f = (Builtin)uop.fn;
if( f.getBuiltinCode()==BuiltinCode.CUMSUM )
return
parseBasicAggregateUnaryOperator(Opcodes.UACKP.toString()) ;
+ else if( f.getBuiltinCode()==BuiltinCode.ROWCUMSUM )
+ return
parseBasicAggregateUnaryOperator(Opcodes.UARCKP.toString()) ;
else if( f.getBuiltinCode()==BuiltinCode.CUMPROD )
return
parseBasicAggregateUnaryOperator(Opcodes.UACM.toString()) ;
else if( f.getBuiltinCode()==BuiltinCode.CUMMIN )
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
index 9f469922ad..61b61b1533 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
@@ -56,6 +56,8 @@ public class CumulativeOffsetSPInstruction extends
BinarySPInstruction {
if (Opcodes.BCUMOFFKP.toString().equals(opcode))
_uop = new
UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
+ else if (Opcodes.BROWCUMOFFKP.toString().equals(opcode))
+ _uop = new
UnaryOperator(Builtin.getBuiltinFnObject("urowcumk+"));
else if (Opcodes.BCUMOFFM.toString().equals(opcode))
_uop = new
UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
else if (Opcodes.BCUMOFFPM.toString().equals(opcode)) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java
index e2653f4310..eebeef0b2e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java
@@ -21,16 +21,27 @@ package org.apache.sysds.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import scala.Serializable;
+import scala.Tuple2;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
public class UnaryMatrixSPInstruction extends UnarySPInstruction {
@@ -61,6 +72,190 @@ public class UnaryMatrixSPInstruction extends
UnarySPInstruction {
updateUnaryOutputDataCharacteristics(sec);
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), input1.getName());
+
+ //FIXME: implement similar to cumsum through
+ // CumulativeAggregateSPInstruction (Spark)
+ // UnaryMatrixCPInstruction (local cumsum on aggregates)
+ // CumulativeOffsetSPInstruction (Spark)
+ if ( "urowcumk+".equals(getOpcode()) ) {
+
+ JavaPairRDD< MatrixIndexes, Tuple2<MatrixBlock,
MatrixBlock> > localRowcumsum = in.mapToPair( new LocalRowCumsumFunction() );
+
+ // Collect end-values of every block of every row for
offset calc by grouping by global row index
+ JavaPairRDD< Long, Iterable<Tuple3<Long, Long,
double[]>> > rowEndValues = localRowcumsum
+ .mapToPair( tuple2 -> {
+ // get index of block
+ MatrixIndexes indexes = tuple2._1;
+ // get cum matrix block
+ MatrixBlock localRowcumsumBlock =
tuple2._2._2;
+
+ // get row and column block index
+ long rowBlockIndex =
indexes.getRowIndex();
+ long colBlockIndex =
indexes.getColumnIndex();
+
+ // Save end value of every row of every
block (if block is empty save 0)
+ double[] endValues = new double[
localRowcumsumBlock.getNumRows() ];
+
+ for ( int i = 0; i <
localRowcumsumBlock.getNumRows(); i ++ ) {
+ if
(localRowcumsumBlock.getNumColumns() > 0)
+ endValues[i] =
localRowcumsumBlock.get(i, localRowcumsumBlock.getNumColumns() - 1);
+ else
+ endValues[i] = 0.0 ;
+ }
+ return new Tuple2<>(rowBlockIndex, new
Tuple3<>(rowBlockIndex, colBlockIndex, endValues));
+ }
+ ).groupByKey();
+
+ // compute offset for every block
+ List< Tuple2 <Tuple2<Long, Long>, double[]> >
offsetList = rowEndValues
+ .flatMapToPair(tuple2 -> {
+ Long rowBlockIndex = tuple2._1;
+ List< Tuple3<Long, Long, double[]> >
colValues = new ArrayList<>();
+ for ( Tuple3<Long, Long, double[]> cv :
tuple2._2 )
+ colValues.add(cv);
+
+ // sort blocks from one row by column
index
+
colValues.sort(Comparator.comparing(Tuple3::_2));
+
+ // get number of rows of a block by
counting amount of end (row) values of said block
+ int numberOfRows = 0;
+ if ( !colValues.isEmpty() ) {
+ Tuple3<Long, Long, double[]>
firstTuple = colValues.get(0);
+ double[] lastValuesArray =
firstTuple._3();
+ numberOfRows =
lastValuesArray.length;
+ }
+
+ List<Tuple2<Tuple2<Long, Long>,
double[]>> blockOffsets = new ArrayList<>();
+ double[] cumulativeOffsets = new
double[numberOfRows];
+ for (Tuple3<Long, Long, double[]>
colValue : colValues) {
+ Long colBlockIndex =
colValue._2();
+ double[] endValues =
colValue._3();
+
+ // copy current offsets
+ double[] currentOffsets =
cumulativeOffsets.clone();
+
+ // and save block indexes with
its offsets
+ blockOffsets.add( new
Tuple2<>(new Tuple2<>(rowBlockIndex, colBlockIndex), currentOffsets) );
+
+ for ( int i = 0; i <
numberOfRows && i < endValues.length; i++ ) {
+ cumulativeOffsets[i] +=
endValues[i];
+ }
+ }
+ return blockOffsets.iterator();
+ }
+ ).collect();
+
+ // convert list to map for easier access to offsets
+ Map< Tuple2<Long, Long>, double[] > offsetMap = new
HashMap<>();
+ for (Tuple2<Tuple2<Long, Long>, double[]> offset :
offsetList) {
+ offsetMap.put(offset._1, offset._2);
+ }
+
+ out = localRowcumsum.mapToPair( new
FinalRowCumsumFunction(offsetMap)) ;
+
+ updateUnaryOutputDataCharacteristics(sec);
+ sec.setRDDHandleForVariable(output.getName(), out);
+ sec.addLineageRDD(output.getName(), input1.getName());
+ }
+ }
+
+
+ /** Per-block step of rowcumsum: maps (index, block) to (index, (input block, row-wise running sums computed within that block only)). */
+ private static class LocalRowCumsumFunction implements PairFunction<
Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, Tuple2<MatrixBlock,
MatrixBlock> > {
+ private static final long serialVersionUID =
2388003441846068046L;
+
+ @Override
+ public Tuple2< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>
> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) {
+ // tuple2._1 is the block index, tuple2._2 the block whose rows are scanned
+
+ MatrixBlock inputBlock = tuple2._2;
+ MatrixBlock cumsumBlock = new MatrixBlock(
inputBlock.getNumRows(), inputBlock.getNumColumns(), false );
+ // the output block is created with the same number of rows and columns as the input block
+
+ for ( int i = 0; i < inputBlock.getNumRows(); i++ ) {
+ // the accumulator starts at zero for every row, so nothing is carried over between rows or from other blocks
+ KahanObject kbuff = new KahanObject(0, 0);
+ KahanPlus kplus =
KahanPlus.getKahanPlusFnObject();
+ // cell (i,j) receives the sum of cells (i,0..j), accumulated in kbuff through KahanPlus
+ for ( int j = 0; j <
inputBlock.getNumColumns(); j++ ) {
+
+ double val = inputBlock.get(i, j);
+ kplus.execute2(kbuff, val);
+ cumsumBlock.set(i, j, kbuff._sum);
+ }
+ }
+ // result: unchanged block index, paired with the input block and its local row-cumsum
block
+ return new Tuple2<>( tuple2._1, new
Tuple2<>(inputBlock, cumsumBlock) );
+ }
+ }
+
+
+
+ /** Final step of rowcumsum: adds a per-row offset, looked up by block index, to every cell of the block-local running sums. */
+ private static class FinalRowCumsumFunction implements
PairFunction<Tuple2< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> >,
MatrixIndexes, MatrixBlock> {
+ private static final long serialVersionUID =
-6738155890298916270L;
+ // map block indexes (row block index, column block index) to the row offsets; array position = row within the block
+ private final Map< Tuple2<Long, Long>, double[] > offsetMap;
+ // NOTE(review): the map is a field of this function object, so it presumably travels with the serialized function -- confirm its size is acceptable for large inputs
+ public FinalRowCumsumFunction(Map<Tuple2<Long, Long>, double[]>
offsetMap) {
+ this.offsetMap = offsetMap;
+ }
+ // call(): (index, (input block, local row-cumsum block)) -> (index, final row-cumsum block)
+
+ @Override
+ public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<
MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > tuple ) {
+
+ MatrixIndexes indexes = tuple._1;
+ MatrixBlock inputBlock = tuple._2._1;
+ MatrixBlock localRowCumsumBlock = tuple._2._2;
+
+ // key to get the row offset for this block
+ Tuple2<Long, Long> blockKey = new Tuple2<>(
indexes.getRowIndex(), indexes.getColumnIndex()) ;
+ double[] offsets = offsetMap.get(blockKey);
+ // inputBlock is only read for its dimensions below; the values come from localRowCumsumBlock
+ MatrixBlock cumsumBlock = new MatrixBlock(
inputBlock.getNumRows(), inputBlock.getNumColumns(), false );
+
+
+ for ( int i = 0; i < inputBlock.getNumRows(); i++ ) {
+ // rows without an entry (no array for this block, or an array with at most i entries) use offset 0
+ double rowOffset = 0.0;
+ if ( offsets != null && i < offsets.length ) {
+ rowOffset = offsets[i];
+ }
+ // plain double addition of the row offset and the block-local running sum
+ for ( int j = 0; j <
inputBlock.getNumColumns(); j++ ) {
+ double cumsumValue =
localRowCumsumBlock.get(i, j);
+ cumsumBlock.set(i, j, cumsumValue +
rowOffset);
+ }
+ }
+
+ // block index and final cumsum block
+ return new Tuple2<>(indexes, cumsumBlock);
+ }
+ }
+
+
+
+ // helper class: serializable holder for the second and third component of a triple; the first constructor argument is accepted but not stored
+ private static class Tuple3<Type1, Type2, Type3> implements
Serializable {
+
+ private static final long serialVersionUID = 123;
+ private final Type2 _2;
+ private final Type3 _3;
+ // there is no field or accessor for Type1; _1 below is ignored
+
+ public Tuple3( Type1 _1, Type2 _2, Type3 _3 ) {
+ this._2 = _2;
+ this._3 = _3;
+ }
+
+ public Type2 _2() {
+ return _2;
+ }
+
+ public Type3 _3() {
+ return _3;
+ }
}
private static class RDDMatrixBuiltinUnaryOp implements
Function<MatrixBlock,MatrixBlock>
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 3cba9fb8c5..59301db7ec 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -112,6 +112,7 @@ public class LibMatrixAgg {
SUM,
SUM_SQ,
CUM_KAHAN_SUM,
+ ROW_CUM_SUM,
CUM_MIN,
CUM_MAX,
CUM_PROD,
@@ -783,6 +784,7 @@ public class LibMatrixAgg {
BuiltinCode bfunc = ((Builtin) vfn).bFunc;
switch( bfunc ) {
case CUMSUM: return AggType.CUM_KAHAN_SUM;
+ case ROWCUMSUM: return AggType.ROW_CUM_SUM;
case CUMPROD: return AggType.CUM_PROD;
case CUMMIN: return AggType.CUM_MIN;
case CUMMAX: return AggType.CUM_MAX;
@@ -1548,6 +1550,12 @@ public class LibMatrixAgg {
d_ucumkp(in.getDenseBlock(), null,
out.getDenseBlock(), n, kbuff, kplus, rl, ru);
break;
}
+ case ROW_CUM_SUM: { //ROWCUMSUM
+ KahanObject kbuff = new KahanObject(0, 0);
+ KahanPlus kplus =
KahanPlus.getKahanPlusFnObject();
+ d_urowcumkp(in.getDenseBlock(), null,
out.getDenseBlock(), n, kbuff, kplus, rl, ru);
+ break;
+ }
case CUM_PROD: { //CUMPROD
d_ucumm(in.getDenseBlockValues(), null,
out.getDenseBlockValues(), n, rl, ru);
break;
@@ -1666,6 +1674,12 @@ public class LibMatrixAgg {
s_ucumkp(a, null, out.getDenseBlock(), m, n,
kbuff, kplus, rl, ru);
break;
}
+ case ROW_CUM_SUM: { //ROWCUMSUM
+ KahanObject kbuff = new KahanObject(0, 0);
+ KahanPlus kplus =
KahanPlus.getKahanPlusFnObject();
+ s_urowcumkp(a, null, out.getDenseBlock(), m, n,
kbuff, kplus, rl, ru);
+ break;
+ }
case CUM_PROD: { //CUMPROD
s_ucumm(a, null, out.getDenseBlockValues(), n,
rl, ru);
break;
@@ -1747,6 +1761,12 @@ public class LibMatrixAgg {
d_ucumkp(da, agg, dc, n, kbuff, kplus, rl, ru);
break;
}
+ case ROW_CUM_SUM: { //ROWCUMSUM
+ KahanObject kbuff = new KahanObject(0, 0);
+ KahanPlus kplus =
KahanPlus.getKahanPlusFnObject();
+ d_urowcumkp(da, agg, dc, n, kbuff, kplus, rl,
ru);
+ break;
+ }
case CUM_SUM_PROD: { //CUMSUMPROD
if( n != 2 )
throw new
DMLRuntimeException("Cumsumprod expects two-column input (n="+n+").");
@@ -1791,6 +1811,12 @@ public class LibMatrixAgg {
s_ucumkp(a, agg, dc, m, n, kbuff, kplus, rl,
ru);
break;
}
+ case ROW_CUM_SUM: { //ROWCUMSUM
+ KahanObject kbuff = new KahanObject(0, 0);
+ KahanPlus kplus =
KahanPlus.getKahanPlusFnObject();
+ s_urowcumkp(a, agg, dc, m, n, kbuff, kplus, rl,
ru);
+ break;
+ }
case CUM_SUM_PROD: { //CUMSUMPROD
if( n != 2 )
throw new
DMLRuntimeException("Cumsumprod expects two-column input (n="+n+").");
@@ -1821,6 +1847,7 @@ public class LibMatrixAgg {
case SUM:
case SUM_SQ:
case KAHAN_SUM:
+ case ROW_CUM_SUM:
case KAHAN_SUM_SQ: val = 0; break;
case MIN: val =
Double.POSITIVE_INFINITY; break;
case MAX: val =
Double.NEGATIVE_INFINITY; break;
@@ -1838,7 +1865,7 @@ public class LibMatrixAgg {
if(optype == AggType.KAHAN_SUM || optype == AggType.KAHAN_SUM_SQ
|| optype == AggType.SUM || optype ==
AggType.SUM_SQ
|| optype == AggType.MIN || optype ==
AggType.MAX || optype == AggType.PROD
- || optype == AggType.CUM_KAHAN_SUM || optype ==
AggType.CUM_PROD
+ || optype == AggType.CUM_KAHAN_SUM || optype ==
AggType.ROW_CUM_SUM || optype == AggType.CUM_PROD
|| optype == AggType.CUM_MIN || optype ==
AggType.CUM_MAX)
{
return out;
@@ -2099,6 +2126,39 @@ public class LibMatrixAgg {
c.set(i, csums.values(0));
}
}
+
+ /**
+ * ROWCUMSUM, opcode: urowcumk+, dense input.
+ * Writes, for each row i in [rl, ru), the running sums of a(i, 0..n-1) into c(i, 0..n-1).
+ * @param a input matrix
+ * @param agg optional start value per row (may be null); read at position i - rl, and 0 is used if that position is out of range
+ * @param c output matrix
+ * @param n number of columns scanned per row (upper bound of the column index)
+ * @param kbuff accumulator, reset to the row's start value before each row
+ * @param kplus addition function applied to kbuff for every cell
+ * @param rl row lower index (inclusive)
+ * @param ru row upper index (exclusive)
+ */
+ private static void d_urowcumkp( DenseBlock a, double[] agg, DenseBlock
c, int n, KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
+ //row-wise cumulative sum w/ optional row offsets (agg is read at i - rl, i.e., relative to rl)
+ for (int i = rl; i < ru; i++) {
+ double start = 0.0;
+ int localRow = i - rl;
+ if (agg != null) {
+ if (localRow >= 0 && localRow < agg.length) {
+ start = agg[localRow];
+ }
+ }
+ kbuff.set(start, 0);
+ //compute cumulative sum over row
+ for (int j = 0; j < n; j++) {
+ double val = a.get(i, j);
+ kplus.execute2(kbuff, val);
+ c.set(i, j, kbuff._sum);
+ }
+ //NOTE(review): agg is indexed relative to rl, not by the absolute row i -- confirm that callers pass an array aligned with rl
+ }
+ }
/**
* CUMSUMPROD, opcode: ucumk+*, dense input.
@@ -2750,6 +2810,51 @@ public class LibMatrixAgg {
c.set(i, csums.values(0));
}
}
+
+ /**
+ * ROWCUMSUM, opcode: urowcumk+, sparse input.
+ * Writes, for each row i in [rl, ru), the running sums over columns 0..n-1 into the dense output c.
+ * @param a input matrix
+ * @param agg optional start value per row (may be null); read at position i - rl, and 0 is used if that position is out of range
+ * @param c output matrix
+ * @param m not referenced in this method (presumably the number of rows -- TODO confirm)
+ * @param n number of columns written per row (upper bound of the column index)
+ * @param kbuff accumulator, reset to the row's start value before each non-empty row
+ * @param kplus addition function applied to kbuff for every stored value
+ * @param rl row lower index (inclusive)
+ * @param ru row upper index (exclusive)
+ */
+ private static void s_urowcumkp(SparseBlock a, double[] agg, DenseBlock
c, int m, int n, KahanObject kbuff, KahanPlus kplus, int rl, int ru) {
+ //scan rows and compute row-wise prefix sums; agg (if given) is read at i - rl as the row's start value
+ for (int i = rl; i < ru; i++) {
+ double start = 0.0;
+ int localRow = i - rl;
+ if (agg != null && localRow >= 0 && localRow <
agg.length)
+ start = agg[localRow];
+ if (!a.isEmpty(i)) {
+ double[] ain = a.values(i);
+ int[] aix = a.indexes(i);
+ int apos = a.pos(i);
+ int alen = a.size(i);
+ kbuff.set(start, 0);
+ int sparseIdx = 0;
+ //prefix sum over sparse row: visit every column j, consume the next stored entry when its column index equals j (assumes aix is ascending within the row -- TODO confirm), then write the running sum
+ for (int j = 0; j < n; j++) {
+ if (sparseIdx < alen && aix[apos +
sparseIdx] == j) {
+ kplus.execute2(kbuff, ain[apos
+ sparseIdx]);
+ start = kbuff._sum;
+ sparseIdx++;
+ }
+ c.set(i, j, start);
+ }
+ }
+ else {
+ //fill empty row with start value (0.0 unless agg supplied one)
+ for (int j = 0; j < n; j++)
+ c.set(i, j, start);
+ }
+ }
+ }
/**
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedCumsumTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedCumsumTest.java
new file mode 100644
index 0000000000..d526b21fc8
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedCumsumTest.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+
+public class RewriteSimplifyTransposedCumsumTest extends AutomatedTestBase{
+ private static final String TEST_NAME =
"RewriteSimplifyTransposedCumsum";
+ private static final String TEST_DIR = "functions/rewrite/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RewriteSimplifyTransposedCumsumTest.class.getSimpleName() + "/";
+
+ private static final double eps = 1e-10;
+
+ private static final int rowsMatrix = 1201;
+ private static final int colsMatrix = 1103;
+ private static final double spSparse = 0.1;
+ private static final double spDense = 0.9;
+
+ private enum InputType {
+ COL_VECTOR,
+ ROW_VECTOR,
+ MATRIX
+ }
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+ if (TEST_CACHE_ENABLED) {
+ setOutAndExpectedDeletionDisabled(true);
+ }
+ }
+
+ @BeforeClass
+ public static void init() {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @AfterClass
+ public static void cleanUp() {
+ if (TEST_CACHE_ENABLED) {
+ TestUtils.clearDirectory(TEST_DATA_DIR +
TEST_CLASS_DIR);
+ }
+ }
+
+ // dense cp
+ @Test
+ public void testRewriteMatrixDenseCPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, false,
ExecType.CP, false);
+ }
+ @Test
+ public void testRewriteMatrixDenseCP() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, false,
ExecType.CP, true);
+ }
+
+ @Test
+ public void testRewriteColVectorDenseCPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, false,
ExecType.CP, false);
+ }
+ @Test
+ public void testRewriteColVectorDenseCP(){
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, false,
ExecType.CP, true);
+ }
+
+ @Test
+ public void testRewriteRowVectorDenseCPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, false,
ExecType.CP, false);
+ }
+ @Test
+ public void testRewriteRowVectorDenseCP(){
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, false,
ExecType.CP, true);
+ }
+
+ // sparse cp
+ @Test
+ public void testRewriteMatrixSparseCPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, true,
ExecType.CP, false);
+ }
+ @Test
+ public void testRewriteMatrixSparseCP() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, true,
ExecType.CP, true);
+ }
+
+ @Test
+ public void testRewriteColVectorSparseCPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, true,
ExecType.CP, false);
+ }
+ @Test
+ public void testRewriteColVectorSparseCP() {
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, true,
ExecType.CP, true);
+ }
+
+ @Test
+ public void testRewriteRowVectorSparseCPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, true,
ExecType.CP, false);
+ }
+ @Test
+ public void testRewriteRowVectorSparseCP() {
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, true,
ExecType.CP, true);
+ }
+
+ // dense sp
+ @Test
+ public void testRewriteMatrixDenseSPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, false,
ExecType.SPARK, false);
+ }
+ @Test
+ public void testRewriteMatrixDenseSP() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, false,
ExecType.SPARK, true);
+ }
+
+ @Test
+ public void testRewriteColVectorDenseSPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, false,
ExecType.SPARK, false);
+ }
+ @Test
+ public void testRewriteColVectorDenseSP() {
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, false,
ExecType.SPARK, true);
+ }
+
+ @Test
+ public void testRewriteRowVectorDenseSPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, false,
ExecType.SPARK, false);
+ }
+ @Test
+ public void testRewriteRowVectorDenseSP() {
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, false,
ExecType.SPARK, true);
+ }
+
+ // sparse sp
+ @Test
+ public void testRewriteMatrixSparseSPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, true,
ExecType.SPARK, false);
+ }
+ @Test
+ public void testRewriteMatrixSparseSP() {
+ testRewriteSimplifyRowcumsum(InputType.MATRIX, true,
ExecType.SPARK, true);
+ }
+
+ @Test
+ public void testRewriteColVectorSparseSPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, true,
ExecType.SPARK, false);
+ }
+ @Test
+ public void testRewriteColVectorSparseSP() {
+ testRewriteSimplifyRowcumsum(InputType.COL_VECTOR, true,
ExecType.SPARK, true);
+ }
+
+ @Test
+ public void testRewriteRowVectorSparseSPNoRewrite() {
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, true,
ExecType.SPARK, false);
+ }
+ @Test
+ public void testRewriteRowVectorSparseSP() {
+ testRewriteSimplifyRowcumsum(InputType.ROW_VECTOR, true,
ExecType.SPARK, true);
+ }
+
+ private void testRewriteSimplifyRowcumsum(InputType type, boolean
sparse, ExecType instType, boolean rewrites) {
+
+ ExecMode platformOld = rtplatform;
+ switch( instType ){
+ case SPARK: rtplatform = ExecMode.SPARK; break;
+ default: rtplatform = ExecMode.HYBRID; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == ExecMode.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ //rewrites
+ boolean oldFlagRewrites =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+
+ try {
+ // Determine matrix dimensions based on InputType
+ int rows = (type == InputType.ROW_VECTOR) ? 1 :
rowsMatrix;
+ int cols = (type == InputType.COL_VECTOR) ? 1 :
colsMatrix;
+ double sparsity = (sparse) ? spSparse : spDense;
+
+ String TEST_CACHE_DIR = !TEST_CACHE_ENABLED ? "" :
+ type.ordinal() + "_" + sparsity + "/";
+
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config, TEST_CACHE_DIR);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-args",
input("A"), output("B")};
+
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
+
+ // create and write matrix
+ double[][] A = getRandomMatrix(rows, cols, -0.05, 1,
sparsity, 7);
+ writeInputMatrixWithMTD("A", A, true);
+
+ runTest(true, false, null, -1);
+ if( instType == ExecType.CP ) {
+ Assert.assertEquals("Unexpected number of
executed Spark jobs.", 0, Statistics.getNoOfExecutedSPInst());
+ }
+
+ runRScript(true);
+
+ //compare matrices
+ HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("B");
+ HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromExpectedDir("B");
+ TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
+
+ // Assertions for opcodes
+ if(rewrites) {
+ // rewrite is enabled: neither the transposes nor CUMSUM
are found, but the ROWCUMSUM operation is found
+
Assert.assertFalse(heavyHittersContainsString(Opcodes.TRANSPOSE.toString()) ||
heavyHittersContainsString("sp_r'"));
+
Assert.assertFalse(heavyHittersContainsString(Opcodes.UCUMKP.toString()) ||
heavyHittersContainsString("sp_bcumoffk+"));
+
Assert.assertTrue(heavyHittersContainsString(Opcodes.UROWCUMKP.toString()) ||
heavyHittersContainsString("sp_urowcumk+"));
+ } else {
+ // rewrite is disabled: both the transposes and
CUMSUM are found, but the ROWCUMSUM operation is not found
+
Assert.assertTrue(heavyHittersContainsString(Opcodes.TRANSPOSE.toString()) ||
heavyHittersContainsString("sp_r'"));
+
Assert.assertTrue(heavyHittersContainsString(Opcodes.UCUMKP.toString()) ||
heavyHittersContainsString("sp_bcumoffk+"));
+
Assert.assertFalse(heavyHittersContainsString(Opcodes.UROWCUMKP.toString()) ||
heavyHittersContainsString("sp_urowcumk+"));
+ }
+ }
+ finally {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
oldFlagRewrites;
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/unary/matrix/FullRowcumsumTest.java
b/src/test/java/org/apache/sysds/test/functions/unary/matrix/FullRowcumsumTest.java
new file mode 100644
index 0000000000..b18027783c
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/unary/matrix/FullRowcumsumTest.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.unary.matrix;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class FullRowcumsumTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "Rowcumsum";
+ private final static String TEST_DIR = "functions/unary/matrix/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FullRowcumsumTest.class.getSimpleName() + "/";
+
+ private final static double eps = 1e-10;
+
+ private final static int rowsMatrix = 1201;
+ private final static int colsMatrix = 1103;
+ private final static double spSparse = 0.1;
+ private final static double spDense = 0.9;
+
+ private enum InputType {
+ COL_VECTOR,
+ ROW_VECTOR,
+ MATRIX
+ }
+
+ @Override
+ public void setUp()
+ {
+ addTestConfiguration(TEST_NAME,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+
+ if (TEST_CACHE_ENABLED) {
+ setOutAndExpectedDeletionDisabled(true);
+ }
+ }
+
+ @BeforeClass
+ public static void init() {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @AfterClass
+ public static void cleanUp() {
+ if (TEST_CACHE_ENABLED) {
+ TestUtils.clearDirectory(TEST_DATA_DIR +
TEST_CLASS_DIR);
+ }
+ }
+
+ @Test
+ public void testRowcumsumColVectorDenseCP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorDenseCP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorDenseNoRewritesCP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, false,
ExecType.CP, false);
+ }
+
+ @Test
+ public void testRowcumsumColVectorDenseNoRewritesCP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, false,
ExecType.CP, false);
+ }
+
+ @Test
+ public void testRowcumsumMatrixDenseCP() {
+ runColAggregateOperationTest(InputType.MATRIX, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testRowcumsumColVectorSparseCP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, true,
ExecType.CP);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorSparseCP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, true,
ExecType.CP);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorSparseNoRewritesCP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, true,
ExecType.CP, false);
+ }
+
+ @Test
+ public void testRowcumsumColVectorSparseNoRewritesCP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, true,
ExecType.CP, false);
+ }
+
+ @Test
+ public void testRowcumsumMatrixSparseCP() {
+ runColAggregateOperationTest(InputType.MATRIX, true,
ExecType.CP);
+ }
+
+ @Test
+ public void testRowcumsumColVectorDenseSP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, false,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorDenseSP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, false,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorDenseNoRewritesSP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, false,
ExecType.SPARK, false);
+ }
+
+ @Test
+ public void testRowcumsumColVectorDenseNoRewritesSP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, false,
ExecType.SPARK, false);
+ }
+
+ @Test
+ public void testRowcumsumMatrixDenseSP() {
+ runColAggregateOperationTest(InputType.MATRIX, false,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testRowcumsumColVectorSparseSP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, true,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorSparseSP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, true,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testRowcumsumRowVectorSparseNoRewritesSP() {
+ runColAggregateOperationTest(InputType.ROW_VECTOR, true,
ExecType.SPARK, false);
+ }
+
+ @Test
+ public void testRowcumsumColVectorSparseNoRewritesSP() {
+ runColAggregateOperationTest(InputType.COL_VECTOR, true,
ExecType.SPARK, false);
+ }
+
+ @Test
+ public void testRowcumsumMatrixSparseSP() {
+ runColAggregateOperationTest(InputType.MATRIX, true,
ExecType.SPARK);
+ }
+
+ private void runColAggregateOperationTest( InputType type, boolean
sparse, ExecType instType) {
+ //by default we apply algebraic simplification rewrites
+ runColAggregateOperationTest(type, sparse, instType, true);
+ }
+
+ private void runColAggregateOperationTest( InputType type, boolean
sparse, ExecType instType, boolean rewrites)
+ {
+ ExecMode platformOld = rtplatform;
+ switch( instType ){
+ case SPARK: rtplatform = ExecMode.SPARK; break;
+ default: rtplatform = ExecMode.HYBRID; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == ExecMode.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ //rewrites
+ boolean oldFlagRewrites =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+ try
+ {
+ int cols = (type== InputType.COL_VECTOR) ? 1 :
colsMatrix;
+ int rows = (type== InputType.ROW_VECTOR) ? 1 :
rowsMatrix;
+ double sparsity = (sparse) ? spSparse : spDense;
+
+ String TEST_CACHE_DIR = !TEST_CACHE_ENABLED ? "" :
+ type.ordinal() + "_" + sparsity + "/";
+
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config, TEST_CACHE_DIR);
+
+ // This is for running the junit test the new way,
i.e., construct the arguments directly
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-explain", "-args",
input("A"), output("B") };
+
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
+
+ //generate actual dataset
+ double[][] A = getRandomMatrix(rows, cols, -0.05, 1,
sparsity, 7);
+ writeInputMatrixWithMTD("A", A, true);
+
+ runTest(true, false, null, -1);
+ if( instType==ExecType.CP ) //in CP no spark jobs
should be executed
+ Assert.assertEquals("Unexpected number of
executed MR jobs.", 0, Statistics.getNoOfExecutedSPInst());
+
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("B");
+ HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir("B");
+ TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
+ }
+ finally
+ {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
oldFlagRewrites;
+ }
+ }
+}
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedCumsum.R
b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedCumsum.R
new file mode 100644
index 0000000000..7c80e304f6
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedCumsum.R
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+if( ncol(A)>1 ){
+ B = t(apply(A, 1, cumsum));
+} else {
+ B = A;
+}
+
+writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep=""));
\ No newline at end of file
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedCumsum.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedCumsum.dml
new file mode 100644
index 0000000000..c2dff2fdec
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedCumsum.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = read($1);
+B = t(cumsum(t(A))); #this should trigger the rewrite: t(cumsum(t(A))) ->
rowcumsum(A)
+write(B, $2);
\ No newline at end of file
diff --git a/src/test/scripts/functions/unary/matrix/Rowcumsum.R
b/src/test/scripts/functions/unary/matrix/Rowcumsum.R
new file mode 100644
index 0000000000..7c80e304f6
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/Rowcumsum.R
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+if( ncol(A)>1 ){
+ B = t(apply(A, 1, cumsum));
+} else {
+ B = A;
+}
+
+writeMM(as(B, "CsparseMatrix"), paste(args[2], "B", sep=""));
\ No newline at end of file
diff --git a/src/test/scripts/functions/unary/matrix/Rowcumsum.dml
b/src/test/scripts/functions/unary/matrix/Rowcumsum.dml
new file mode 100644
index 0000000000..8b4b4fb0a9
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/Rowcumsum.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = read($1);
+B = rowcumsum(A);
+write(B, $2);
+