This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a46189ce7b [SYSTEMDS-3805] Rewrite and runtime for scalar right 
indexing
a46189ce7b is described below

commit a46189ce7b72992b0597c9cea819abdb390c7b66
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Dec 11 18:07:02 2024 +0100

    [SYSTEMDS-3805] Rewrite and runtime for scalar right indexing
    
    This patch adds a new rewrite, as well as modifies existing rewrites
    and runtime instructions in order to perform scalar right indexing
    for operations like as.scalar(X[i,1]) which avoids unnecessary
    createvar and cast instructions. On a scenario of running the baseline
    (non-vectorized) exponential smoothing on 10M data points, the patch
    improved end-to-end performance from 22.3s to 12.2s (6.7s without
    statistics time measurements).
    
    alpha = 0.05
    r = as.scalar(X[1, 1])
    for(i in 2:nrow(X)) {
      r = alpha * as.scalar(X[i, 1]) + (1-alpha) * r
    }
    
    Total elapsed time:             22.348 sec.
    Total compilation time:         0.516 sec.
    Total execution time:           21.832 sec.
    Cache hits (Mem/Li/WB/FS/HDFS): 20000000/0/0/0/0.
    Cache writes (Li/WB/FS/HDFS):   1/0/0/0.
    Cache times (ACQr/m, RLS, EXP): 0.777/0.432/1.124/0.000 sec.
    HOP DAGs recompiled (PRED, SB): 0/0.
    HOP DAGs recompile time:        0.300 sec.
    Functions recompiled:           1.
    Functions recompile time:       0.002 sec.
    Total JIT compile time:         2.608 sec.
    Total JVM GC count:             1.
    Total JVM GC time:              0.018 sec.
    Heavy hitter instructions:
      1  rightIndex     4.894  10000000
      2  createvar      3.585  10000001
      3  rmvar          2.848  30000000
      4  castdts        2.242  10000000
      5  *              1.742  19999998
      6  +              0.898   9999999
      7  mvvar          0.751  10000002
      8  rand           0.213         1
      9  -              0.016         1
     10  print          0.000         1
     11  assignvar      0.000         2
    
    Total elapsed time:             12.589 sec.
    Total compilation time:         0.520 sec.
    Total execution time:           12.069 sec.
    Cache hits (Mem/Li/WB/FS/HDFS): 10000000/0/0/0/0.
    Cache writes (Li/WB/FS/HDFS):   1/0/0/0.
    Cache times (ACQr/m, RLS, EXP): 0.455/0.000/0.463/0.000 sec.
    HOP DAGs recompiled (PRED, SB): 0/0.
    HOP DAGs recompile time:        0.313 sec.
    Functions recompiled:           1.
    Functions recompile time:       0.002 sec.
    Total JIT compile time:         1.923 sec.
    Total JVM GC count:             1.
    Total JVM GC time:              0.011 sec.
    Heavy hitter instructions:
      1  rightIndex     3.046  10000000
      2  *              1.876  19999998
      3  rmvar          1.450  20000000
      4  +              0.954   9999999
      5  mvvar          0.801  10000002
      6  rand           0.213         1
      7  -              0.018         1
      8  print          0.000         1
      9  createvar      0.000         1
     10  assignvar      0.000         2
---
 .../java/org/apache/sysds/hops/IndexingOp.java     |  4 +
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |  2 +-
 .../RewriteAlgebraicSimplificationDynamic.java     |  2 +-
 .../RewriteAlgebraicSimplificationStatic.java      | 22 ++++++
 .../hops/rewrite/RewriteIndexingVectorization.java |  6 +-
 .../cp/MatrixIndexingCPInstruction.java            | 53 +++++++------
 .../instructions/cp/VariableCPInstruction.java     |  5 ++
 .../spark/MatrixIndexingSPInstruction.java         | 40 ++++++----
 .../rewrite/RewriteLoopVectorization.java          |  2 +
 .../rewrite/RewriteScalarRightIndexingTest.java    | 92 ++++++++++++++++++++++
 .../rewrite/RewriteScalarRightIndexing.dml         | 34 ++++++++
 11 files changed, 220 insertions(+), 42 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java 
b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 35215fa843..1756724e74 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -73,6 +73,10 @@ public class IndexingOp extends Hop
                setRowLowerEqualsUpper(passedRowsLEU);
                setColLowerEqualsUpper(passedColsLEU);
        }
+       
+       public boolean isScalarOutput() {
+               return isRowLowerEqualsUpper() && isColLowerEqualsUpper();
+       }
 
        public boolean isRowLowerEqualsUpper(){
                return _rowLowerEqualsUpper;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 68167ac3ae..aae2787cd3 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1332,7 +1332,7 @@ public class HopRewriteUtils {
        }
        
        public static boolean isUnnecessaryRightIndexing(Hop hop) {
-               if( !(hop instanceof IndexingOp) )
+               if( !(hop instanceof IndexingOp) || hop.isScalar() )
                        return false;
                //note: in addition to equal sizes, we also check a valid
                //starting row and column ranges of 1 in order to guard against
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 396c40d114..9c1f2174d0 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -241,7 +241,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
        
        private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, 
int pos)
        {
-               if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) ) {
+               if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) && 
!hi.isScalar() ) {
                        //remove unnecessary right indexing
                        Hop input = hi.getInput().get(0);
                        HopRewriteUtils.replaceChildReference(parent, hi, 
input, pos);
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5a79bdee33..d06f89d72e 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -174,6 +174,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyTraceMatrixMult(hop, hi, i);            
//e.g., trace(X%*%Y)->sum(X*t(Y));
                        hi = simplifySlicedMatrixMult(hop, hi, i);           
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
                        hi = simplifyListIndexing(hi);                       
//e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
+                       hi = simplifyScalarIndexing(hop, hi, i);             
//e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
                        hi = simplifyConstantSort(hop, hi, i);               
//e.g., order(matrix())->matrix/seq;
                        hi = simplifyOrderedSort(hop, hi, i);                
//e.g., order(matrix())->seq;
                        hi = fuseOrderOperationChain(hi);                    
//e.g., order(order(X,2),1) -> order(X,(12))
@@ -1508,6 +1509,27 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                return hi;
        }
 
+       private static Hop simplifyScalarIndexing(Hop parent, Hop hi, int pos)
+       {
+               //as.scalar(X[i,1]) -> X[i,1] w/ scalar output
+               if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) 
+                       && hi.getInput(0).getParent().size() == 1 // only 
consumer
+                       && hi.getInput(0) instanceof IndexingOp 
+                       && ((IndexingOp)hi.getInput(0)).isScalarOutput() 
+                       && hi.getInput(0).isMatrix() //no frame support yet 
+                       && !HopRewriteUtils.isData(parent, 
OpOpData.TRANSIENTWRITE)) 
+               {
+                       Hop hi2 = hi.getInput().get(0);
+                       hi2.setDataType(DataType.SCALAR); 
+                       hi2.setDim1(0); hi2.setDim2(0);
+                       HopRewriteUtils.replaceChildReference(parent, hi, hi2, 
pos);
+                       HopRewriteUtils.cleanupUnreferenced(hi);
+                       hi = hi2;
+                       LOG.debug("Applied simplifyScalarIndexing (line 
"+hi.getBeginLine()+").");
+               }
+               return hi;
+       }
+       
        private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
        {
                //order(matrix(7), indexreturn=FALSE) -> matrix(7)
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
index 9c04959ed5..6da9e52132 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
@@ -186,8 +186,8 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                ihops.add(ihop0);
                                for( Hop c : input.getParent() ){
                                        if( c != ihop0 && c instanceof 
IndexingOp && c.getInput().get(0) == input
-                                          && ((IndexingOp) 
c).isRowLowerEqualsUpper() 
-                                          && 
c.getInput().get(1)==ihop0.getInput().get(1) )
+                                               && ((IndexingOp) 
c).isRowLowerEqualsUpper() && !c.isScalar()
+                                               && 
c.getInput().get(1)==ihop0.getInput().get(1) )
                                        {
                                                ihops.add( c );
                                        }
@@ -225,7 +225,7 @@ public class RewriteIndexingVectorization extends 
HopRewriteRule
                                ihops.add(ihop0);
                                for( Hop c : input.getParent() ){
                                        if( c != ihop0 && c instanceof 
IndexingOp && c.getInput().get(0) == input
-                                          && ((IndexingOp) 
c).isColLowerEqualsUpper() 
+                                          && ((IndexingOp) 
c).isColLowerEqualsUpper() && !c.isScalar()
                                           && 
c.getInput().get(3)==ihop0.getInput().get(3) )
                                        {
                                                ihops.add( c );
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
index afbf7724ab..99473b7a49 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
@@ -52,39 +52,46 @@ public final class MatrixIndexingCPInstruction extends 
IndexingCPInstruction {
                String opcode = getOpcode();
                IndexRange ix = getIndexRange(ec);
                
-               //get original matrix
                MatrixObject mo = ec.getMatrixObject(input1.getName());
+               boolean inRange = ix.rowStart < mo.getNumRows() && ix.colStart 
< mo.getNumColumns();
                
                //right indexing
                if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
                {
-                       MatrixBlock resultBlock = null;
-                       
-                       if( mo.isPartitioned() ) //via data partitioning
-                               resultBlock = mo.readMatrixPartition(ix.add(1));
-                       else if( ix.isScalar() && ix.rowStart < mo.getNumRows() 
&& ix.colStart < mo.getNumColumns() ) {
+                       if( output.isScalar() && inRange ) { //SCALAR out
                                MatrixBlock matBlock = 
mo.acquireReadAndRelease();
-                               resultBlock = new MatrixBlock(
-                                       matBlock.get((int)ix.rowStart, 
(int)ix.colStart));
+                               ec.setScalarOutput(output.getName(),
+                                       new 
DoubleObject(matBlock.get((int)ix.rowStart, (int)ix.colStart)));
                        }
-                       else //via slicing the in-memory matrix
-                       {
-                               //execute right indexing operation (with 
shallow row copies for range
-                               //of entire sparse rows, which is safe due to 
copy on update)
-                               MatrixBlock matBlock = mo.acquireRead();
-                               resultBlock = matBlock.slice((int)ix.rowStart, 
(int)ix.rowEnd, 
-                                       (int)ix.colStart, (int)ix.colEnd, 
false, new MatrixBlock());
+                       else { //MATRIX out
+                               MatrixBlock resultBlock = null;
                                
-                               //unpin rhs input
-                               ec.releaseMatrixInput(input1.getName());
+                               if( mo.isPartitioned() ) //via data partitioning
+                                       resultBlock = 
mo.readMatrixPartition(ix.add(1));
+                               else if( ix.isScalar() && inRange ) {
+                                       MatrixBlock matBlock = 
mo.acquireReadAndRelease();
+                                       resultBlock = new MatrixBlock(
+                                               matBlock.get((int)ix.rowStart, 
(int)ix.colStart));
+                               }
+                               else //via slicing the in-memory matrix
+                               {
+                                       //execute right indexing operation 
(with shallow row copies for range
+                                       //of entire sparse rows, which is safe 
due to copy on update)
+                                       MatrixBlock matBlock = mo.acquireRead();
+                                       resultBlock = 
matBlock.slice((int)ix.rowStart, (int)ix.rowEnd, 
+                                               (int)ix.colStart, 
(int)ix.colEnd, false, new MatrixBlock());
+                                       
+                                       //unpin rhs input
+                                       ec.releaseMatrixInput(input1.getName());
+                                       
+                                       //ensure correct sparse/dense output 
representation
+                                       if( 
checkGuardedRepresentationChange(matBlock, resultBlock) )
+                                               resultBlock.examSparsity();
+                               }
                                
-                               //ensure correct sparse/dense output 
representation
-                               if( checkGuardedRepresentationChange(matBlock, 
resultBlock) )
-                                       resultBlock.examSparsity();
+                               //unpin output
+                               ec.setMatrixOutput(output.getName(), 
resultBlock);
                        }
-                       
-                       //unpin output
-                       ec.setMatrixOutput(output.getName(), resultBlock);
                }
                //left indexing
                else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE))
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 8826c41b80..3ae0a96f0b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -897,6 +897,11 @@ public class VariableCPInstruction extends CPInstruction 
implements LineageTrace
                                ec.setVariable(output.getName(), list.slice(0));
                                break;
                        }
+                       case SCALAR: {
+                               //for robustness in case rewrites added 
unnecessary as.scalars
+                               ec.setScalarOutput(output.getName(), 
ec.getScalarInput(getInput1()));
+                               break;
+                       }
                        default:
                                throw new DMLRuntimeException("Unsupported data 
type "
                                        + "in as.scalar(): 
"+getInput1().getDataType().name());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
index ac2d8f4f22..ceaaea2ded 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
@@ -35,6 +35,7 @@ import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
 import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -47,6 +48,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.IndexRange;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import scala.Function1;
@@ -103,26 +105,35 @@ public class MatrixIndexingSPInstruction extends 
IndexingSPInstruction {
                if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
                {
                        //update and check output dimensions
-                       DataCharacteristics mcOut = 
sec.getDataCharacteristics(output.getName());
+                       DataCharacteristics mcOut = output.isScalar() ? 
+                               new MatrixCharacteristics(1,1) :
+                               ec.getDataCharacteristics(output.getName());
                        mcOut.set(ru-rl+1, cu-cl+1, mcIn.getBlocksize(), 
mcIn.getBlocksize());
                        mcOut.setNonZerosBound(Math.min(mcOut.getLength(), 
mcIn.getNonZerosBound()));
                        checkValidOutputDimensions(mcOut);
                        
                        //execute right indexing operation 
(partitioning-preserving if possible)
                        JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = 
sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
-                       
-                       if( isSingleBlockLookup(mcIn, ixrange) ) {
-                               sec.setMatrixOutput(output.getName(), 
singleBlockIndexing(in1, mcIn, mcOut, ixrange));
-                       }
-                       else if( isMultiBlockLookup(in1, mcIn, mcOut, ixrange) 
) {
-                               sec.setMatrixOutput(output.getName(), 
multiBlockIndexing(in1, mcIn, mcOut, ixrange));
+               
+                       if( output.isScalar() ) { //SCALAR output
+                               MatrixBlock ret = singleBlockIndexing(in1, 
mcIn, mcOut, ixrange);
+                               sec.setScalarOutput(output.getName(), new 
DoubleObject(ret.get(0, 0)));
                        }
-                       else { //rdd output for general case
-                               JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
generalCaseRightIndexing(in1, mcIn, mcOut, ixrange, _aggType);
+                       else { //MATRIX output
                                
-                               //put output RDD handle into symbol table
-                               sec.setRDDHandleForVariable(output.getName(), 
out);
-                               sec.addLineageRDD(output.getName(), 
input1.getName());
+                               if( isSingleBlockLookup(mcIn, ixrange) ) {
+                                       sec.setMatrixOutput(output.getName(), 
singleBlockIndexing(in1, mcIn, mcOut, ixrange));
+                               }
+                               else if( isMultiBlockLookup(in1, mcIn, mcOut, 
ixrange) ) {
+                                       sec.setMatrixOutput(output.getName(), 
multiBlockIndexing(in1, mcIn, mcOut, ixrange));
+                               }
+                               else { //rdd output for general case
+                                       JavaPairRDD<MatrixIndexes,MatrixBlock> 
out = generalCaseRightIndexing(in1, mcIn, mcOut, ixrange, _aggType);
+                                       
+                                       //put output RDD handle into symbol 
table
+                                       
sec.setRDDHandleForVariable(output.getName(), out);
+                                       sec.addLineageRDD(output.getName(), 
input1.getName());
+                               }
                        }
                }
                //left indexing
@@ -178,12 +189,13 @@ public class MatrixIndexingSPInstruction extends 
IndexingSPInstruction {
                                sec.addLineageRDD(output.getName(), 
input2.getName());
                }
                else
-                       throw new DMLRuntimeException("Invalid opcode (" + 
opcode +") encountered in MatrixIndexingSPInstruction.");            
+                       throw new DMLRuntimeException("Invalid opcode (" + 
opcode +") encountered in MatrixIndexingSPInstruction.");
        }
 
 
        public static MatrixBlock 
inmemoryIndexing(JavaPairRDD<MatrixIndexes,MatrixBlock> in1,
-                                                  DataCharacteristics mcIn, 
DataCharacteristics mcOut, IndexRange ixrange) {
+               DataCharacteristics mcIn, DataCharacteristics mcOut, IndexRange 
ixrange)
+       {
                if( isSingleBlockLookup(mcIn, ixrange) ) {
                        return singleBlockIndexing(in1, mcIn, mcOut, ixrange);
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
index d9358fef30..927b0fd666 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
@@ -22,6 +22,7 @@ package org.apache.sysds.test.functions.rewrite;
 import java.util.HashMap;
 
 import org.junit.Assert;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -57,6 +58,7 @@ public class RewriteLoopVectorization extends 
AutomatedTestBase
        }
        
        @Test
+       @Ignore //FIXME: extend loop vectorization rewrite
        public void testLoopVectorizationSumRewrite() {
                testRewriteLoopVectorizationSum( TEST_NAME1, true );
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteScalarRightIndexingTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteScalarRightIndexingTest.java
new file mode 100644
index 0000000000..9a3792d29d
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteScalarRightIndexingTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
+
+public class RewriteScalarRightIndexingTest extends AutomatedTestBase
+{
+       private final static String TEST_DIR = "functions/rewrite/";
+       private final static String TEST_NAME = "RewriteScalarRightIndexing";
+       
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
RewriteScalarRightIndexingTest.class.getSimpleName() + "/";
+       
+       private final static int rows = 122;
+       
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A"}));
+       }
+
+       @Test
+       public void testScalarRightIndexingCP() {
+               runScalarRightIndexing(true, ExecType.CP);
+       }
+       
+       @Test
+       public void testScalarRightIndexingNoRewriteCP() {
+               runScalarRightIndexing(false, ExecType.CP);
+       }
+       
+       @Test
+       public void testScalarRightIndexingSpark() {
+               runScalarRightIndexing(true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testScalarRightIndexingNoRewriteSpark() {
+               runScalarRightIndexing(false, ExecType.SPARK);
+       }
+       
+       private void runScalarRightIndexing(boolean rewrite, ExecType instType) 
{
+               ExecMode platformOld = setExecMode(instType);
+               boolean flagOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain", "-stats", 
"-args",
+                               Long.toString(rows), output("A")};
+                       runTest(true, false, null, -1);
+                       
+                       Double ret = readDMLScalarFromOutputDir("A").get(new 
CellIndex(1,1));
+                       Assert.assertEquals(Double.valueOf(103.0383), ret, 
1e-4);
+                       if(rewrite) //w/o rewrite 122 casts
+                               
Assert.assertTrue(Statistics.getCPHeavyHitterCount("castdts")<=1);
+               }
+               finally {
+                       resetExecMode(platformOld);
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = flagOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteScalarRightIndexing.dml 
b/src/test/scripts/functions/rewrite/RewriteScalarRightIndexing.dml
new file mode 100644
index 0000000000..d0b76cd2d8
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarRightIndexing.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+nrow = $1;
+X = seq(1, nrow);
+
+alpha = 0.05
+
+r = as.scalar(X[1, 1])
+for(i in 2:nrow(X)) {
+  r = alpha * as.scalar(X[i, 1]) + (1-alpha) * r
+}
+
+write(r, $2);
+

Reply via email to