This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e705f893f7 [MINOR] Code cleanups in rewrites and tests
e705f893f7 is described below

commit e705f893f719632ef4afd990a908f2c51fbe0a3d
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 08:50:54 2024 +0100

    [MINOR] Code cleanups in rewrites and tests
---
 .../RewriteAlgebraicSimplificationDynamic.java     | 68 +++++++++++-----------
 ...iteElementwiseMultChainOptimizationAllTest.java | 15 +----
 ...ewriteElementwiseMultChainOptimizationTest.java | 15 +----
 3 files changed, 38 insertions(+), 60 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index c9a9745091..15207e87b5 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -243,7 +243,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
        {
                if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) && 
!hi.isScalar() ) {
                        //remove unnecessary right indexing
-                       Hop input = hi.getInput().get(0);
+                       Hop input = hi.getInput(0);
                        HopRewriteUtils.replaceChildReference(parent, hi, 
input, pos);
                        HopRewriteUtils.cleanupUnreferenced(hi);
                        hi = input;
@@ -258,8 +258,8 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
        {
                if( hi instanceof LeftIndexingOp && hi.getDataType() == 
DataType.MATRIX  ) //left indexing op
                {
-                       Hop input1 = hi.getInput().get(0); //lhs matrix
-                       Hop input2 = hi.getInput().get(1); //rhs matrix
+                       Hop input1 = hi.getInput(0); //lhs matrix
+                       Hop input2 = hi.getInput(1); //rhs matrix
                        
                        if(   input1.getNnz()==0 //nnz original known and empty
                           && input2.getNnz()==0  ) //nnz input known and empty
@@ -271,7 +271,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                                hi = hnew;
                                
                                LOG.debug("Applied removeEmptyLeftIndexing");
-                       }                       
+                       }
                }
                
                return hi;
@@ -281,19 +281,19 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
        {
                if( hi instanceof LeftIndexingOp  ) //left indexing op
                {
-                       Hop input = hi.getInput().get(1); //rhs matrix/frame
+                       Hop input = hi.getInput(1); //rhs matrix/frame
                        
                        if( HopRewriteUtils.isEqualSize(hi, input) ) //equal 
dims
                        {
                                //equal dims of left indexing input and output 
-> no need for indexing
                                
-                               //remove unnecessary right indexing             
                
+                               //remove unnecessary right indexing
                                HopRewriteUtils.replaceChildReference(parent, 
hi, input, pos);
                                HopRewriteUtils.cleanupUnreferenced(hi);
                                hi = input;
                                
                                LOG.debug("Applied 
removeUnnecessaryLeftIndexing");
-                       }                       
+                       }
                }
                
                return hi;
@@ -306,15 +306,15 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                //pattern1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame
                if( hi instanceof LeftIndexingOp                      //first 
lix 
                        && 
HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi)
-                       && hi.getInput().get(0) instanceof LeftIndexingOp 
//second lix  
+                       && hi.getInput(0) instanceof LeftIndexingOp //second 
lix        
                        && 
HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0))
-                       && hi.getInput().get(0).getParent().size()==1     
//first lix is single consumer
-                       && hi.getInput().get(0).getInput().get(0).getDim2() == 
2 ) //two column matrix
+                       && hi.getInput(0).getParent().size()==1     //first lix 
is single consumer
+                       && hi.getInput(0).getInput(0).getDim2() == 2 ) //two 
column matrix
                {
-                       Hop input2 = hi.getInput().get(1); //rhs matrix
-                       Hop pred2 = hi.getInput().get(4); //cl=cu
-                       Hop input1 = hi.getInput().get(0).getInput().get(1); 
//lhs matrix
-                       Hop pred1 = hi.getInput().get(0).getInput().get(4); 
//cl=cu
+                       Hop input2 = hi.getInput(1); //rhs matrix
+                       Hop pred2 = hi.getInput(4); //cl=cu
+                       Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
+                       Hop pred1 = hi.getInput(0).getInput(4); //cl=cu
                        
                        if( pred1 instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
                                && pred2 instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
@@ -332,15 +332,15 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                //pattern1: X[1,]=A; X[2,]=B -> X=rbind(A,B)
                if( !applied && hi instanceof LeftIndexingOp          //first 
lix 
                        && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi)
-                       && hi.getInput().get(0) instanceof LeftIndexingOp 
//second lix  
+                       && hi.getInput(0) instanceof LeftIndexingOp //second 
lix        
                        && 
HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0))
-                       && hi.getInput().get(0).getParent().size()==1     
//first lix is single consumer
-                       && hi.getInput().get(0).getInput().get(0).getDim1() == 
2 ) //two column matrix
+                       && hi.getInput(0).getParent().size()==1     //first lix 
is single consumer
+                       && hi.getInput(0).getInput(0).getDim1() == 2 ) //two 
column matrix
                {
-                       Hop input2 = hi.getInput().get(1); //rhs matrix
-                       Hop pred2 = hi.getInput().get(2); //rl=ru
-                       Hop input1 = hi.getInput().get(0).getInput().get(1); 
//lhs matrix
-                       Hop pred1 = hi.getInput().get(0).getInput().get(2); 
//rl=ru
+                       Hop input2 = hi.getInput(1); //rhs matrix
+                       Hop pred2 = hi.getInput(2); //rl=ru
+                       Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
+                       Hop pred1 = hi.getInput(0).getInput(2); //rl=ru
                        
                        if( pred1 instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
                                && pred2 instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
@@ -364,19 +364,19 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
        {
                if( hi instanceof UnaryOp && 
((UnaryOp)hi).isCumulativeUnaryOperation()  )
                {
-                       Hop input = hi.getInput().get(0); //input matrix
+                       Hop input = hi.getInput(0); //input matrix
                        
                        if(   HopRewriteUtils.isDimsKnown(input)  //dims input 
known
                       && input.getDim1()==1 ) //1 row
                        {
                                OpOp1 op = ((UnaryOp)hi).getOp();
                                
-                               //remove unnecessary unary cumsum operator      
                        
+                               //remove unnecessary unary cumsum operator
                                HopRewriteUtils.replaceChildReference(parent, 
hi, input, pos);
                                hi = input;
                                
                                LOG.debug("Applied 
removeUnnecessaryCumulativeOp: "+op);
-                       }                       
+                       }
                }
                
                return hi;
@@ -413,27 +413,27 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                if( hi instanceof BinaryOp  ) //binary cell operation 
                {
                        OpOp2 bop = ((BinaryOp)hi).getOp();
-                       Hop left = hi.getInput().get(0);
-                       Hop right = hi.getInput().get(1);
+                       Hop left = hi.getInput(0);
+                       Hop right = hi.getInput(1);
                        
                        //check for matrix-vector column replication: (A + b 
%*% ones) -> (A + b)
                        if(    HopRewriteUtils.isMatrixMultiply(right) //matrix 
mult with datagen
                                && 
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1)
-                               && right.getInput().get(0).getDim2() == 1 ) 
//column vector for mv binary
+                               && right.getInput(0).getDim2() == 1 ) //column 
vector for mv binary
                        {
                                //remove unnecessary outer product
-                               HopRewriteUtils.replaceChildReference(hi, 
right, right.getInput().get(0), 1 );
+                               HopRewriteUtils.replaceChildReference(hi, 
right, right.getInput(0), 1 );
                                HopRewriteUtils.cleanupUnreferenced(right);
                                
                                LOG.debug("Applied 
removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")");
                        }
                        //check for matrix-vector row replication: (A + ones 
%*% b) -> (A + b)
                        else if( HopRewriteUtils.isMatrixMultiply(right) 
//matrix mult with datagen
-                               && 
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1)
-                               && right.getInput().get(1).getDim1() == 1 ) 
//row vector for mv binary
+                               && 
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput(0), 1)
+                               && right.getInput(1).getDim1() == 1 ) //row 
vector for mv binary
                        {
                                //remove unnecessary outer product
-                               HopRewriteUtils.replaceChildReference(hi, 
right, right.getInput().get(1), 1 );
+                               HopRewriteUtils.replaceChildReference(hi, 
right, right.getInput(1), 1 );
                                HopRewriteUtils.cleanupUnreferenced(right);
                                
                                LOG.debug("Applied 
removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")");
@@ -442,11 +442,11 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        else if(HopRewriteUtils.isValidOuterBinaryOp(bop) 
                                && HopRewriteUtils.isMatrixMultiply(left)
                                && 
HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
-                               && (left.getInput().get(0).getDim2() == 1 
//outer product
-                                       || left.getInput().get(1).getDim1() == 
1)
+                               && (left.getInput(0).getDim2() == 1 //outer 
product
+                                       || left.getInput(1).getDim1() == 1)
                                && left.getDim1() != 1 && right.getDim1() == 1 
) //outer vector binary 
                        {
-                               Hop hnew = 
HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true);
+                               Hop hnew = 
HopRewriteUtils.createBinary(left.getInput(0), right, bop, true);
                                HopRewriteUtils.replaceChildReference(parent, 
hi, hnew, pos);
                                HopRewriteUtils.cleanupUnreferenced(hi);
                                
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java
index 78728d9a71..15b24534c1 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
 
 import org.junit.Assert;
 import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.common.Types.ExecType;
@@ -74,16 +73,7 @@ public class RewriteElementwiseMultChainOptimizationAllTest 
extends AutomatedTes
 
        private void testRewriteMatrixMultChainOp(String testname, boolean 
rewrites, ExecType et)
        {       
-               ExecMode platformOld = rtplatform;
-               switch( et ){
-                       case SPARK: rtplatform = ExecMode.SPARK; break;
-                       default: rtplatform = ExecMode.HYBRID; break;
-               }
-               
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == ExecMode.SPARK || rtplatform == 
ExecMode.HYBRID )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-               
+               ExecMode platformOld = setExecMode(et);
                boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
                OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
                
@@ -126,8 +116,7 @@ public class RewriteElementwiseMultChainOptimizationAllTest 
extends AutomatedTes
                }
                finally {
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       resetExecMode(platformOld);
                }
        }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java
index d60df3f665..6c6ede61d7 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
 
 import org.junit.Assert;
 import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.common.Types.ExecType;
@@ -73,16 +72,7 @@ public class RewriteElementwiseMultChainOptimizationTest 
extends AutomatedTestBa
 
        private void testRewriteMatrixMultChainOp(String testname, boolean 
rewrites, ExecType et)
        {       
-               ExecMode platformOld = rtplatform;
-               switch( et ){
-                       case SPARK: rtplatform = ExecMode.SPARK; break;
-                       default: rtplatform = ExecMode.HYBRID; break;
-               }
-               
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == ExecMode.SPARK || rtplatform == 
ExecMode.HYBRID )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-               
+               ExecMode platformOld = setExecMode(et);
                boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
                OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
                
@@ -119,8 +109,7 @@ public class RewriteElementwiseMultChainOptimizationTest 
extends AutomatedTestBa
                }
                finally {
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       resetExecMode(platformOld);
                }
        }
 }

Reply via email to