This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 07c69e6  [SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression
07c69e6 is described below

commit 07c69e62449a95ad889f9453bac8410d667fe689
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Jul 22 13:36:50 2021 +0200

    [SYSTEMDS-3069] Extended rewrites for splitting DAGs after compression
    
    This patch extends the existing 'split-DAG after data-dependent
    operators' rewrite and the IPA integration of workload-aware compression
    in order to allow recompilation according to compression results (e.g.,
    compile local instead of distributed operations for highly compressible
    data).
---
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |  3 +-
 .../RewriteSplitDagDataDependentOperators.java     | 34 +++++++++++++---------
 .../spark/AggregateUnarySPInstruction.java         |  2 +-
 .../compress/workload/WorkloadAlgorithmTest.java   | 11 ++-----
 4 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 309d823..0b47a19 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -241,7 +241,8 @@ public class InterProceduralAnalysis
                FunctionCallGraph graph2 = new FunctionCallGraph(_prog);
                List<IPAPass> fpasses = Arrays.asList(
                        new IPAPassRemoveUnusedFunctions(),
-                       new IPAPassCompressionWorkloadAnalysis());
+                       new IPAPassCompressionWorkloadAnalysis(), // workload-aware compression
+                       new IPAPassApplyStaticAndDynamicHopRewrites());  //split after compress
                for(IPAPass pass : fpasses)
                        if( pass.isApplicable(graph2) )
                                pass.rewriteProgram(_prog, graph2, null);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index fe00ae0..ecc3f39 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -34,6 +34,7 @@ import org.apache.sysds.common.Types.OpOpN;
 import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.common.Types.ReOrgOp;
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
@@ -42,6 +43,7 @@ import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.ParameterizedBuiltinOp;
 import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.lops.Compression.CompressConfig;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.parser.VariableSet;
@@ -75,7 +77,10 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
        {
                //DAG splits not required for forced single node
-               if( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
+               CompressConfig compress = CompressConfig.valueOf(ConfigurationManager
+                       .getDMLConfig().getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase());
+               if( (DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE
+                       && !(compress != CompressConfig.FALSE) )
                        || !HopRewriteUtils.isLastLevelStatementBlock(sb) )
                        return Arrays.asList(sb);
                
@@ -225,7 +230,8 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
                        return;
                
                //prevent unnecessary dag split (dims known or no consumer operations)
-               boolean noSplitRequired = ( hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true) );
+               boolean noSplitRequired = (HopRewriteUtils.hasOnlyWriteParents(hop, true, true)
+                       || hop.dimsKnown() || DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE);
                boolean investigateChilds = true;
                
                //collect data dependent operations (to be extended as necessary)
@@ -294,14 +300,8 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
                        }
                }
                
-               //#4 second-order eval function
-               if( HopRewriteUtils.isNary(hop, OpOpN.EVAL) && !noSplitRequired ) {
-                       cand.add(hop);
-                       investigateChilds = false;
-               }
-               
-               //#5 sql
-               if( hop instanceof DataOp && ((DataOp) hop).getOp() == OpOpData.SQLREAD && !noSplitRequired) {
+               //#4 other data dependent operators (default handling)
+               if( isBasicDataDependentOperator(hop, noSplitRequired) ) {
                        cand.add(hop);
                        investigateChilds = false;
                }
@@ -314,6 +314,14 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
                
                hop.setVisited();
        }
+       
+       private static boolean isBasicDataDependentOperator(Hop hop, boolean noSplitRequired) {
+               return (HopRewriteUtils.isNary(hop, OpOpN.EVAL) & !noSplitRequired)
+                       || (HopRewriteUtils.isData(hop, OpOpData.SQLREAD) & !noSplitRequired)
+                       || (hop.requiresCompression() & !HopRewriteUtils.hasOnlyWriteParents(hop, true, true));
+               //note: for compression we probe for write parents (part of noSplitRequired) directly
+               // because we want to split even if the dimensions are known 
+       }
 
        private static boolean hasTransientWriteParents( Hop hop ) {
                for( Hop p : hop.getParent() )
@@ -393,7 +401,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
                        for( Hop c : hop.getInput() )
                                rAddHopsToProbeSet(c, probeSet);
        
-               hop.setVisited();       
+               hop.setVisited();
        }
        
        /**
@@ -417,11 +425,11 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite
                                        rProbeAndAddHopsToCandidateSet(c, probeSet, candSet);
                                else
                                {
-                                       candSet.add(new Pair<>(hop,c)); 
+                                       candSet.add(new Pair<>(hop,c));
                                }
                        }
                
-               hop.setVisited();       
+               hop.setVisited();
        }
        
        private void collectCandidateChildOperators( ArrayList<Hop> cand, HashSet<Hop> candChilds )
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index cecbd3d..c135b01 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -212,7 +212,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction {
                
                @Override
                public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
-                       throws Exception 
+                       throws Exception
                {
                        MatrixIndexes ixIn = arg0._1();
                        MatrixBlock blkIn = arg0._2();
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index c257a57..eaa9a73 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -33,8 +33,6 @@ import org.junit.Test;
 
 public class WorkloadAlgorithmTest extends AutomatedTestBase {
 
-       // private static final Log LOG = LogFactory.getLog(WorkloadAnalysisTest.class.getName());
-
        private final static String TEST_NAME1 = "WorkloadAnalysisMLogReg";
        private final static String TEST_NAME2 = "WorkloadAnalysisLm";
        private final static String TEST_NAME3 = "WorkloadAnalysisPCA";
@@ -55,7 +53,6 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
                runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID, 2);
        }
 
-
        @Test
        public void testLmSP() {
                runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SPARK, 2);
@@ -80,12 +77,12 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
                ExecMode oldPlatform = setExecMode(mode);
 
                try {
-
                        loadTestConfiguration(getTestConfiguration(testname));
 
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[] {"-stats", "20", "-args", input("X"), input("y"), output("B")};
+                       programArgs = new String[] {"-explain","-stats",
+                               "20", "-args", input("X"), input("y"), output("B")};
 
                        double[][] X = TestUtils.round(getRandomMatrix(10000, 20, 0, 10, 1.0, 7));
                        writeInputMatrixWithMTD("X", X, false);
@@ -95,9 +92,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
                        }
                        writeInputMatrixWithMTD("y", y, false);
 
-                       String ret = runTest(null).toString();
-                       if(ret.contains("ERROR:"))
-                               fail(ret);
+                       runTest(null);
 
                        // check various additional expectations
                        long actualCompressionCount = mode == ExecMode.HYBRID ? 
Statistics

Reply via email to