This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 4cb9321  [SYSTEMDS-2991,2994] Initial workload analysis for 
compression planning
4cb9321 is described below

commit 4cb932168d7d6b334922b79c05e959817b6400b1
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sun May 30 23:19:37 2021 +0200

    [SYSTEMDS-2991,2994] Initial workload analysis for compression planning
    
    This patch introduces an initial workload representation (the workload
    tree, WTree for short) as well as workload analyzer for compression
    planning that returns a WTree for every compression candidate. In
    detail, we make a pass to determine candidates (currently only preads of
    certain sizes), and then for every candidate an additional pass to
    construct the WTree across statement blocks, HOP DAGs, and function
    calls. All of that is integrated as a new IPA cleanup pass (executed
    only once with access to the entire program), which, except for specific
    tests, is still disabled by default.
---
 .../ipa/IPAPassCompressionWorkloadAnalysis.java    |  57 +++++
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |  14 +-
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |   7 +
 .../hops/rewrite/RewriteCompressedReblock.java     |   6 +-
 .../sysds/runtime/compress/workload/WTreeNode.java | 124 +++++++++
 .../compress/workload/WorkloadAnalyzer.java        | 279 +++++++++++++++++++++
 .../functions/compress/WorkloadAnalysisTest.java   |  88 +++++++
 .../functions/compress/WorkloadAnalysisLm.dml      |  32 +++
 .../functions/compress/WorkloadAnalysisMlogreg.dml |  32 +++
 9 files changed, 632 insertions(+), 7 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java
 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java
new file mode 100644
index 0000000..0c4bd7c
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.ipa;
+
+import java.util.Map;
+import java.util.Map.Entry;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.lops.Compression.CompressConfig;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.runtime.compress.workload.WTreeNode;
+import org.apache.sysds.runtime.compress.workload.WorkloadAnalyzer;
+
+/**
+ * This IPA cleanup pass obtains workload summaries (workload trees) for all
+ * HOP candidates amenable to compression, as a basis for workload-aware
+ * compression planning. The pass only analyzes the program and does not
+ * modify it.
+ */
+public class IPAPassCompressionWorkloadAnalysis extends IPAPass
+{
+       /**
+        * Applicable only if the CLA workload analysis flag is set and the
+        * configured compressed linear algebra mode is enabled.
+        */
+       @Override
+       public boolean isApplicable(FunctionCallGraph fgraph) {
+               return InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS
+                       && CompressConfig.valueOf(ConfigurationManager.getDMLConfig()
+                               .getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase()).isEnabled();
+       }
+       
+       /**
+        * Obtains the workload trees of all compression candidates and, for
+        * now, prints them to stdout.
+        * 
+        * @return always false, because this read-only analysis does not change
+        *   the program (presumably, per the IPAPass contract, true would request
+        *   a rebuild of the function call graph)
+        */
+       @Override
+       public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
+               //obtain CLA workload analysis for all applicable operators
+               Map<Long, WTreeNode> map = WorkloadAnalyzer.getAllCandidateWorkloads(prog);
+               
+               //TODO influence compression planning, for now just printing
+               for( Entry<Long, WTreeNode> e : map.entrySet() )
+                       System.out.println(e.getValue());
+               
+               //read-only analysis: program and function call graph are unchanged
+               return false;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 35e04f4..9dcfd53 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -57,7 +57,9 @@ import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.utils.Explain;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 
 /**
@@ -96,6 +98,7 @@ public class InterProceduralAnalysis
        protected static final boolean ELIMINATE_DEAD_CODE            = true; 
//remove dead code (e.g., assigments) not used later on
        protected static final boolean FORWARD_SIMPLE_FUN_CALLS       = true; 
//replace a call to a simple forwarding function with the function itself
        protected static final boolean FLAG_NONDETERMINISM            = true; 
//flag functions which directly or transitively contain non-deterministic calls
+       public    static       boolean CLA_WORKLOAD_ANALYSIS          = false; 
//obtain workload for workload-aware compression
        
        private final DMLProgram _prog;
        private final StatementBlock _sb;
@@ -235,11 +238,14 @@ public class InterProceduralAnalysis
                                _fgraph = new FunctionCallGraph(_prog);
                }
                
-               //cleanup pass: remove unused functions
+               //cleanup passes: remove unused functions, CLA workload 
extraction
                FunctionCallGraph graph2 = new FunctionCallGraph(_prog);
-               IPAPass rmFuns = new IPAPassRemoveUnusedFunctions();
-               if( rmFuns.isApplicable(graph2) )
-                       rmFuns.rewriteProgram(_prog, graph2, null);
+               List<IPAPass> fpasses = Arrays.asList(
+                       new IPAPassRemoveUnusedFunctions(),
+                       new IPAPassCompressionWorkloadAnalysis());
+               for(IPAPass pass : fpasses)
+                       if( pass.isApplicable(graph2) )
+                               pass.rewriteProgram(_prog, graph2, null);
        }
        
        public Set<String> analyzeSubProgram() {
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 2c2b018..e23af65 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1120,6 +1120,13 @@ public class HopRewriteUtils
                return ret;
        }
        
+       /**
+        * Indicates whether the given hop is a data operator whose operation
+        * type is any of the given types.
+        * 
+        * @param hop   hop to check
+        * @param types accepted data operation types
+        * @return true if {@code isData(hop, type)} holds for at least one type
+        */
+       public static boolean isData(Hop hop, OpOpData... types) {
+               for( OpOpData t : types )
+                       if( isData(hop, t) )
+                               return true;
+               return false;
+       }
+       
        public static boolean isData(Hop hop, OpOpData type) {
                return hop instanceof DataOp && ((DataOp)hop).getOp()==type;
        }
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
index 0f26fa1..cc57961 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
@@ -77,7 +77,7 @@ public class RewriteCompressedReblock extends 
StatementBlockRewriteRule {
                // parse compression config
                DMLConfig conf = ConfigurationManager.getDMLConfig();
                CompressConfig compress = 
CompressConfig.valueOf(conf.getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase());
-
+               
                // perform compressed reblock rewrite
                if(compress.isEnabled()) {
                        Hop.resetVisitStatus(sb.getHops());
@@ -126,14 +126,14 @@ public class RewriteCompressedReblock extends 
StatementBlockRewriteRule {
                hop.setVisited();
        }
 
-       private static boolean satisfiesSizeConstraintsForCompression(Hop hop) {
+       /**
+        * Checks the size constraints for compression: the hop needs at least
+        * one column (dim2 >= 1) and either at least 1000 rows with fewer than
+        * 100 columns, or a row-to-column ratio (integer division) of at least
+        * 75. Public so that the compression workload analysis can reuse it.
+        */
+       public static boolean satisfiesSizeConstraintsForCompression(Hop hop) {
                if(hop.getDim2() >= 1) {
                        return (hop.getDim1() >= 1000 && hop.getDim2() < 100) 
|| hop.getDim1() / hop.getDim2() >= 75;
                }
                return false;
        }
 
-       private static boolean satisfiesCompressionCondition(Hop hop) {
+       public static boolean satisfiesCompressionCondition(Hop hop) {
                boolean satisfies = false;
                if(satisfiesSizeConstraintsForCompression(hop))
                        satisfies |= HopRewriteUtils.isData(hop, 
OpOpData.PERSISTENTREAD);
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/workload/WTreeNode.java 
b/src/main/java/org/apache/sysds/runtime/compress/workload/WTreeNode.java
new file mode 100644
index 0000000..aa64ad1
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WTreeNode.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.workload;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysds.hops.Hop;
+
+/**
+ * A workload tree is a compact representation of the operations
+ * on a compressed matrix and derived intermediates, including
+ * the basic control structure and inlined functions. Every node
+ * may hold child nodes as well as the compressed operations that
+ * are directly attached to it (e.g., the operations of a basic
+ * block or of an if/loop predicate).
+ * 
+ * TODO links to categories (not yet represented)
+ * TODO separate classes for inner and leaf nodes?
+ */
+public class WTreeNode
+{
+       /** Node types: program root, function call, control structures, and basic blocks. */
+       public enum WTNodeType{
+               MAIN,
+               FCALL,
+               IF,
+               WHILE,
+               FOR,
+               PARFOR,
+               BASIC_BLOCK;
+               
+               /** @return true for the loop types WHILE, FOR, and PARFOR */
+               public boolean isLoop() {
+                       return this == WHILE ||
+                               this == FOR || this == PARFOR;
+               }
+       }
+       
+       private final WTNodeType _type;
+       private final List<WTreeNode> _childs = new ArrayList<>();
+       private final List<Hop> _cops = new ArrayList<>();
+       //source line range of the node, -1 if not set
+       private int _beginLine = -1;
+       private int _endLine = -1;
+       
+       public WTreeNode(WTNodeType type) {
+               _type = type;
+       }
+       
+       public WTNodeType getType() {
+               return _type;
+       }
+       
+       /** @return the live (mutable) list of child nodes */
+       public List<WTreeNode> getChildNodes() {
+               return _childs;
+       }
+       
+       public void addChild(WTreeNode node) {
+               _childs.add(node);
+       }
+       
+       /** @return the live (mutable) list of compressed operations of this node */
+       public List<Hop> getCompressedOps() {
+               return _cops;
+       }
+       
+       public void addCompressedOp(Hop hop) {
+               _cops.add(hop);
+       }
+       
+       public void setLineNumbers(int begin, int end) {
+               _beginLine = begin;
+               _endLine = end;
+       }
+       
+       /**
+        * Creates a textual representation of this (sub) tree, where every
+        * line is prefixed with "--" per nesting level.
+        * 
+        * @param level nesting level of this node
+        * @return explanation of this node, its compressed operations,
+        *   and (recursively) its child nodes
+        */
+       public String explain(int level) {
+               StringBuilder sb = new StringBuilder();
+               //append node summary
+               appendIndentation(sb, level);
+               sb.append(_type.name());
+               if( _beginLine>=0 && _endLine>=0 ) {
+                       sb.append(" (lines ");
+                       sb.append(_beginLine);
+                       sb.append("-");
+                       sb.append(_endLine);
+                       sb.append(")");
+               }
+               sb.append("\n");
+               //append compressed operations of this node; these are shown
+               //for nodes with children too (e.g., if/loop predicates)
+               for( Hop hop : _cops ) {
+                       appendIndentation(sb, level+1);
+                       sb.append(hop.toString());
+                       sb.append("\n");
+               }
+               //append child nodes
+               for( WTreeNode n : _childs )
+                       sb.append(n.explain(level+1));
+               return sb.toString();
+       }
+       
+       /** Appends the indentation ("--" per level) for the given nesting level. */
+       private static void appendIndentation(StringBuilder sb, int level) {
+               for( int i=0; i<level; i++ )
+                       sb.append("--");
+       }
+       
+       @Override
+       public String toString() {
+               StringBuilder sb = new StringBuilder("Workload Tree:\n");
+               sb.append("--------------------------------------------------------------------------------\n");
+               sb.append(this.explain(1));
+               sb.append("--------------------------------------------------------------------------------\n");
+               return sb.toString();
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
 
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
new file mode 100644
index 0000000..f9aad7c
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.workload;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.hops.rewrite.RewriteCompressedReblock;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.ParForStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.compress.workload.WTreeNode.WTNodeType;
+
+/**
+ * Workload analyzer for workload-aware compression planning. It determines
+ * the compression candidates of a program (hops that satisfy
+ * {@link RewriteCompressedReblock#satisfiesCompressionCondition(Hop)}) and,
+ * for every candidate, a workload tree (WTree) of the operations on the
+ * candidate and its compressed intermediates across statement blocks,
+ * HOP DAGs, and function calls.
+ */
+public class WorkloadAnalyzer {
+       
+       /**
+        * Obtains the pruned workload trees of all compression candidates.
+        * 
+        * @param prog DML program
+        * @return map from candidate hop ID to the candidate's pruned workload tree
+        */
+       public static Map<Long, WTreeNode> getAllCandidateWorkloads(DMLProgram prog) {
+               // extract all compression candidates from program
+               List<Hop> candidates = getCandidates(prog);
+               
+               // for each candidate, create pruned workload tree
+               // TODO memoization of processed subtree if overlap
+               Map<Long, WTreeNode> map = new HashMap<>();
+               for( Hop cand : candidates ) {
+                       // candidates inside functions are reported once per call site,
+                       // but their workload tree is identical, so build it only once
+                       if( map.containsKey(cand.getHopID()) )
+                               continue;
+                       WTreeNode tree = createWorkloadTree(prog, cand);
+                       pruneWorkloadTree(tree);
+                       map.put(cand.getHopID(), tree);
+               }
+               
+               return map;
+       }
+       
+       /**
+        * Collects all compression candidates of the program by traversing
+        * statement blocks, HOP DAGs, and called functions (skipping recursive
+        * calls that are already on the current call stack).
+        * 
+        * @param prog DML program
+        * @return list of candidate hops; a hop inside a function may occur
+        *   multiple times if the function is called from multiple sites
+        */
+       public static List<Hop> getCandidates(DMLProgram prog) {
+               List<Hop> candidates = new ArrayList<>();
+               for( StatementBlock sb : prog.getStatementBlocks() )
+                       getCandidates(sb, prog, candidates, new HashSet<>());
+               return candidates;
+       }
+       
+       /**
+        * Creates the (unpruned) workload tree of a single candidate.
+        * 
+        * @param prog DML program
+        * @param candidate compression candidate
+        * @return workload tree rooted at a MAIN node
+        */
+       public static WTreeNode createWorkloadTree(DMLProgram prog, Hop candidate) {
+               WTreeNode main = new WTreeNode(WTNodeType.MAIN);
+               //TODO generalize, below line assumes only pread candidates (at bottom on DAGs)
+               Set<String> compressed = new HashSet<>();
+               compressed.add(candidate.getName());
+               for( StatementBlock sb : prog.getStatementBlocks() )
+                       main.addChild(createWorkloadTree(sb, prog, compressed, new HashSet<>()));
+               return main;
+       }
+       
+       /**
+        * Recursively removes all sub trees without compressed operations.
+        * 
+        * @param node root of the (sub) tree to prune
+        * @return true if the given node has neither children nor compressed
+        *   operations after pruning and can thus be removed by its parent
+        */
+       public static boolean pruneWorkloadTree(WTreeNode node) {
+               //recursively process sub trees
+               Iterator<WTreeNode> iter = node.getChildNodes().iterator();
+               while( iter.hasNext() ) {
+                       if( pruneWorkloadTree(iter.next()) )
+                               iter.remove();
+               }
+               
+               //indicate that node can be removed
+               return node.getChildNodes().isEmpty()
+                       && node.getCompressedOps().isEmpty();
+       }
+       
+       private static void getCandidates(StatementBlock sb, DMLProgram prog, List<Hop> cands, Set<String> fStack) {
+               if(sb instanceof FunctionStatementBlock) {
+                       FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
+                       FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
+                       for(StatementBlock csb : fstmt.getBody())
+                               getCandidates(csb, prog, cands, fStack);
+               }
+               else if(sb instanceof WhileStatementBlock) {
+                       WhileStatementBlock wsb = (WhileStatementBlock) sb;
+                       WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
+                       for(StatementBlock csb : wstmt.getBody())
+                               getCandidates(csb, prog, cands, fStack);
+               }
+               else if(sb instanceof IfStatementBlock) {
+                       IfStatementBlock isb = (IfStatementBlock) sb;
+                       IfStatement istmt = (IfStatement)isb.getStatement(0);
+                       for(StatementBlock csb : istmt.getIfBody())
+                               getCandidates(csb, prog, cands, fStack);
+                       for(StatementBlock csb : istmt.getElseBody())
+                               getCandidates(csb, prog, cands, fStack);
+               }
+               else if(sb instanceof ForStatementBlock) { //incl parfor
+                       ForStatementBlock fsb = (ForStatementBlock) sb;
+                       ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+                       for(StatementBlock csb : fstmt.getBody())
+                               getCandidates(csb, prog, cands, fStack);
+               }
+               else { //generic (last-level)
+                       if( sb.getHops() == null )
+                               return;
+                       Hop.resetVisitStatus(sb.getHops());
+                       for(Hop hop : sb.getHops())
+                               getCandidates(hop, prog, cands, fStack);
+                       Hop.resetVisitStatus(sb.getHops());
+               }
+       }
+       
+       private static void getCandidates(Hop hop, DMLProgram prog, List<Hop> cands, Set<String> fStack) {
+               if( hop.isVisited() )
+                       return;
+               
+               //evaluate and add candidates (type and size)
+               if( RewriteCompressedReblock.satisfiesCompressionCondition(hop) )
+                       cands.add(hop);
+               
+               //recursively process children (inputs)
+               for( Hop c : hop.getInput() )
+                       getCandidates(c, prog, cands, fStack);
+               
+               //process function calls with awareness of the current
+               //call stack to avoid endless loops in recursive functions
+               if( hop instanceof FunctionOp ) {
+                       String fkey = ((FunctionOp) hop).getFunctionKey();
+                       FunctionStatementBlock fsb = getFunctionBody(prog, fkey, fStack);
+                       if( fsb != null ) {
+                               fStack.add(fkey);
+                               getCandidates(fsb, prog, cands, fStack);
+                               fStack.remove(fkey);
+                       }
+               }
+               
+               hop.setVisited();
+       }
+       
+       private static WTreeNode createWorkloadTree(StatementBlock sb, DMLProgram prog, Set<String> compressed, Set<String> fStack) {
+               WTreeNode node = null;
+               if(sb instanceof FunctionStatementBlock) {
+                       FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
+                       FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
+                       node = new WTreeNode(WTNodeType.FCALL);
+                       for(StatementBlock csb : fstmt.getBody())
+                               node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
+               }
+               else if(sb instanceof WhileStatementBlock) {
+                       WhileStatementBlock wsb = (WhileStatementBlock) sb;
+                       WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
+                       node = new WTreeNode(WTNodeType.WHILE);
+                       createWorkloadTree(wsb.getPredicateHops(), prog, node, compressed, fStack);
+                       for(StatementBlock csb : wstmt.getBody())
+                               node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
+               }
+               else if(sb instanceof IfStatementBlock) {
+                       IfStatementBlock isb = (IfStatementBlock) sb;
+                       IfStatement istmt = (IfStatement)isb.getStatement(0);
+                       node = new WTreeNode(WTNodeType.IF);
+                       createWorkloadTree(isb.getPredicateHops(), prog, node, compressed, fStack);
+                       for(StatementBlock csb : istmt.getIfBody())
+                               node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
+                       for(StatementBlock csb : istmt.getElseBody())
+                               node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
+               }
+               else if(sb instanceof ForStatementBlock) { //incl parfor
+                       ForStatementBlock fsb = (ForStatementBlock) sb;
+                       ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+                       node = new WTreeNode(sb instanceof ParForStatementBlock ? WTNodeType.PARFOR:WTNodeType.FOR);
+                       createWorkloadTree(fsb.getFromHops(), prog, node, compressed, fStack);
+                       createWorkloadTree(fsb.getToHops(), prog, node, compressed, fStack);
+                       createWorkloadTree(fsb.getIncrementHops(), prog, node, compressed, fStack);
+                       for(StatementBlock csb : fstmt.getBody())
+                               node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
+               }
+               else { //generic (last-level)
+                       node = new WTreeNode(WTNodeType.BASIC_BLOCK);
+                       if( sb.getHops() != null ) {
+                               Hop.resetVisitStatus(sb.getHops());
+                               //process hop DAG to collect operations
+                               Set<Long> compressed2 = new HashSet<>();
+                               for(Hop hop : sb.getHops())
+                                       createWorkloadTree(hop, prog, node, compressed, compressed2, fStack);
+                               //maintain hop DAG outputs (compressed or not compressed)
+                               for(Hop hop : sb.getHops()) {
+                                       if( hop instanceof FunctionOp )
+                                               addFunctionCallNode((FunctionOp) hop, prog, node, compressed2, fStack);
+                                       else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE) )
+                                               propagateTransientWrite(hop, compressed, compressed2);
+                               }
+                               Hop.resetVisitStatus(sb.getHops());
+                       }
+               }
+               node.setLineNumbers(sb.getBeginLine(), sb.getEndLine());
+               return node;
+       }
+       
+       /**
+        * Adds an FCALL child node for the function called by the given
+        * function op, unless the function is already on the call stack
+        * (recursion) or has no function statement block. Compressed call
+        * arguments are mapped by position to the formal parameter names.
+        */
+       private static void addFunctionCallNode(FunctionOp fop, DMLProgram prog, WTreeNode parent, Set<Long> compressed2, Set<String> fStack) {
+               String fkey = fop.getFunctionKey();
+               FunctionStatementBlock fsb = getFunctionBody(prog, fkey, fStack);
+               if( fsb == null )
+                       return;
+               fStack.add(fkey);
+               FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
+               //handle propagation of compressed intermediates into functions
+               Set<String> fCompressed = new HashSet<>();
+               List<DataIdentifier> fArgs = fstmt.getInputParams();
+               //guard against calls with fewer inputs than formal parameters
+               //(e.g., if default values are not materialized as inputs)
+               int numArgs = Math.min(fArgs.size(), fop.getInput().size());
+               for( int i=0; i<numArgs; i++ )
+                       if( compressed2.contains(fop.getInput(i).getHopID()) )
+                               fCompressed.add(fArgs.get(i).getName());
+               parent.addChild(createWorkloadTree(fsb, prog, fCompressed, fStack));
+               fStack.remove(fkey);
+       }
+       
+       /**
+        * Propagates the compression status of a transient write to its
+        * variable name for subsequent statement blocks: the name is marked
+        * compressed if and only if the written hop is compressed.
+        */
+       private static void propagateTransientWrite(Hop hop, Set<String> compressed, Set<Long> compressed2) {
+               if( compressed2.contains(hop.getHopID()) )
+                       compressed.add(hop.getName());
+               else
+                       compressed.remove(hop.getName());
+       }
+       
+       /**
+        * Returns the statement block of the called function, or null if the
+        * function is already on the current call stack (recursion) or the
+        * program holds no statement block for the function key (presumably
+        * for calls without a DML function body; such calls are skipped).
+        */
+       private static FunctionStatementBlock getFunctionBody(DMLProgram prog, String fkey, Set<String> fStack) {
+               if( fStack.contains(fkey) )
+                       return null;
+               return prog.getFunctionStatementBlock(fkey);
+       }
+       
+       private static void createWorkloadTree(Hop hop, DMLProgram prog, WTreeNode parent, Set<String> compressed, Set<String> fStack) {
+               if( hop == null )
+                       return;
+               hop.resetVisitStatus();
+               createWorkloadTree(hop, prog, parent, compressed, new HashSet<>(), fStack); //see below
+               hop.resetVisitStatus();
+       }
+       
+       private static void createWorkloadTree(Hop hop, DMLProgram prog, WTreeNode parent, Set<String> compressed, Set<Long> compressed2, Set<String> fStack) {
+               if( hop == null || hop.isVisited() )
+                       return;
+               
+               //recursively process children (inputs first for propagation of compression status)
+               for( Hop c : hop.getInput() )
+                       createWorkloadTree(c, prog, parent, compressed, compressed2, fStack);
+               
+               //map statement block propagation to hop propagation
+               if( HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD)
+                       && compressed.contains(hop.getName()) ) {
+                       compressed2.add(hop.getHopID());
+               }
+               
+               //collect operations on compressed intermediates or inputs
+               //if any input is compressed we collect this hop as a compressed operation
+               if( hop.getInput().stream().anyMatch(h -> compressed2.contains(h.getHopID())) ) {
+                       if(!HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, //all, but data ops
+                               OpOpData.TRANSIENTREAD, OpOpData.TRANSIENTWRITE) )
+                       {
+                               parent.addCompressedOp(hop);
+                       }
+                       
+                       //if the output size also qualifies for compression, we propagate this status
+                       if( RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop)
+                               && hop.getDataType().isMatrix() )
+                       {
+                               compressed2.add(hop.getHopID());
+                       }
+               }
+               
+               hop.setVisited();
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/WorkloadAnalysisTest.java b/src/test/java/org/apache/sysds/test/functions/compress/WorkloadAnalysisTest.java
new file mode 100644
index 0000000..5585617
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/compress/WorkloadAnalysisTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.compress;
+
+import java.io.File;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.ipa.InterProceduralAnalysis;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests for the compression workload analysis, enabled via
+ * {@link InterProceduralAnalysis#CLA_WORKLOAD_ANALYSIS}, on scripts that
+ * call the {@code multiLogReg} and {@code lm} builtins.
+ */
+public class WorkloadAnalysisTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME1 = "WorkloadAnalysisMlogreg";
+       private final static String TEST_NAME2 = "WorkloadAnalysisLm";
+       private final static String TEST_DIR = "functions/compress/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + WorkloadAnalysisTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B"}));
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"B"}));
+       }
+
+       @Test
+       public void testMlogregCP() {
+               runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID);
+       }
+       
+       @Test
+       public void testLmCP() {
+               runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID);
+       }
+
+       /**
+        * Runs the given test script with compression workload analysis enabled
+        * on random inputs X (10000 x 20) and y (10000 x 1, rounded). The previous
+        * execution mode and analysis flag are restored afterwards.
+        *
+        * @param testname name of the test configuration and DML script
+        * @param mode     execution mode for the run
+        */
+       private void runWorkloadAnalysisTest(String testname, ExecMode mode)
+       {
+               ExecMode oldPlatform = setExecMode(mode);
+               boolean oldFlag = InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS;
+               
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(testname));
+                       
+                       InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS = true;
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{"-args", input("X"), input("y"), output("B") };
+
+                       double[][] X = getRandomMatrix(10000, 20, 0, 1, 1.0, 7);
+                       writeInputMatrixWithMTD("X", X, false);
+                       double[][] y = TestUtils.round(getRandomMatrix(10000, 1, 1, 2, 1.0, 3));
+                       writeInputMatrixWithMTD("y", y, false);
+
+                       runTest(true, false, null, -1);
+                       //TODO check for compressed operations 
+                       //(right now test only checks that the workload analysis does not crash)
+               }
+               finally {
+                       //restore static/global state so subsequent tests are unaffected
+                       resetExecMode(oldPlatform);
+                       InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS = oldFlag;
+               }
+       }
+       
+       /**
+        * Uses the compression config template located in the {@code force}
+        * subfolder of the test script directory.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               return new File(SCRIPT_DIR + TEST_DIR + "force", "SystemDS-config-compress.xml");
+       }
+}
diff --git a/src/test/scripts/functions/compress/WorkloadAnalysisLm.dml b/src/test/scripts/functions/compress/WorkloadAnalysisLm.dml
new file mode 100644
index 0000000..ccaec92
--- /dev/null
+++ b/src/test/scripts/functions/compress/WorkloadAnalysisLm.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Workload analysis test script: standardizes X and trains a
+# linear regression model via lm().
+
+X = read($1);
+y = read($2);
+
+# center and scale the columns of X (explicit alternative to
+# scale(X=X, scale=TRUE, center=TRUE))
+X = X - colMeans(X);
+X = X / sqrt(colVars(X));
+
+# empty loop that forces a statement block boundary; presumably so the
+# analysis has to propagate X across statement blocks
+while(FALSE){}
+B = lm(X=X, y=y, verbose=TRUE);
+
+write(B, $3)
diff --git a/src/test/scripts/functions/compress/WorkloadAnalysisMlogreg.dml b/src/test/scripts/functions/compress/WorkloadAnalysisMlogreg.dml
new file mode 100644
index 0000000..7f0a76a
--- /dev/null
+++ b/src/test/scripts/functions/compress/WorkloadAnalysisMlogreg.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Workload analysis test script: standardizes X and trains a
+# multinomial logistic regression model via multiLogReg().
+
+X = read($1);
+y = read($2);
+
+# center and scale the columns of X (explicit alternative to
+# scale(X=X, scale=TRUE, center=TRUE))
+X = X - colMeans(X);
+X = X / sqrt(colVars(X));
+
+# empty loop that forces a statement block boundary; presumably so the
+# analysis has to propagate X across statement blocks
+while(FALSE){}
+B = multiLogReg(X=X, Y=y, verbose=TRUE);
+
+write(B, $3)

Reply via email to