This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push: new 4cb9321 [SYSTEMDS-2991,2994] Initial workload analysis for compression planning 4cb9321 is described below commit 4cb932168d7d6b334922b79c05e959817b6400b1 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sun May 30 23:19:37 2021 +0200 [SYSTEMDS-2991,2994] Initial workload analysis for compression planning This patch introduces an initial workload representation (the workload tree, WTree for short) as well as a workload analyzer for compression planning that returns a WTree for every compression candidate. In detail, we make a pass to determine candidates (currently only persistent reads, i.e., preads, of certain sizes), and then for every candidate an additional pass to construct the WTree across statement blocks, HOP DAGs, and function calls. All of that is integrated as a new IPA cleanup pass (executed only once with access to the entire program), which, except for specific tests, is still disabled by default. --- .../ipa/IPAPassCompressionWorkloadAnalysis.java | 57 +++++ .../sysds/hops/ipa/InterProceduralAnalysis.java | 14 +- .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 7 + .../hops/rewrite/RewriteCompressedReblock.java | 6 +- .../sysds/runtime/compress/workload/WTreeNode.java | 124 +++++++++ .../compress/workload/WorkloadAnalyzer.java | 279 +++++++++++++++++++++ .../functions/compress/WorkloadAnalysisTest.java | 88 +++++++ .../functions/compress/WorkloadAnalysisLm.dml | 32 +++ .../functions/compress/WorkloadAnalysisMlogreg.dml | 32 +++ 9 files changed, 632 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java new file mode 100644 index 0000000..0c4bd7c --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassCompressionWorkloadAnalysis.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.ipa; + +import java.util.Map; +import java.util.Map.Entry; + +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; +import org.apache.sysds.lops.Compression.CompressConfig; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.runtime.compress.workload.WTreeNode; +import org.apache.sysds.runtime.compress.workload.WorkloadAnalyzer; + +/** + * This rewrite obtains workload summaries for all hops candidates amenable + * for compression as a basis for workload-aware compression planning. 
+ * + */ +public class IPAPassCompressionWorkloadAnalysis extends IPAPass +{ + @Override + public boolean isApplicable(FunctionCallGraph fgraph) { + return InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS + && CompressConfig.valueOf(ConfigurationManager.getDMLConfig() + .getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase()).isEnabled(); + } + + @Override + public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) { + //obtain CLA workload analysis for all applicable operators + Map<Long, WTreeNode> map = WorkloadAnalyzer.getAllCandidateWorkloads(prog); + + //TODO influence compression planning, for now just printing + for( Entry<Long, WTreeNode> e : map.entrySet() ) + System.out.println(e.getValue()); + + return map != null; + } +} diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java index 35e04f4..9dcfd53 100644 --- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java @@ -57,7 +57,9 @@ import org.apache.sysds.runtime.meta.MetaDataFormat; import org.apache.sysds.utils.Explain; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Set; /** @@ -96,6 +98,7 @@ public class InterProceduralAnalysis protected static final boolean ELIMINATE_DEAD_CODE = true; //remove dead code (e.g., assigments) not used later on protected static final boolean FORWARD_SIMPLE_FUN_CALLS = true; //replace a call to a simple forwarding function with the function itself protected static final boolean FLAG_NONDETERMINISM = true; //flag functions which directly or transitively contain non-deterministic calls + public static boolean CLA_WORKLOAD_ANALYSIS = false; //obtain workload for workload-aware compression private final DMLProgram _prog; private final StatementBlock _sb; @@ -235,11 +238,14 
@@ public class InterProceduralAnalysis _fgraph = new FunctionCallGraph(_prog); } - //cleanup pass: remove unused functions + //cleanup passes: remove unused functions, CLA workload extraction FunctionCallGraph graph2 = new FunctionCallGraph(_prog); - IPAPass rmFuns = new IPAPassRemoveUnusedFunctions(); - if( rmFuns.isApplicable(graph2) ) - rmFuns.rewriteProgram(_prog, graph2, null); + List<IPAPass> fpasses = Arrays.asList( + new IPAPassRemoveUnusedFunctions(), + new IPAPassCompressionWorkloadAnalysis()); + for(IPAPass pass : fpasses) + if( pass.isApplicable(graph2) ) + pass.rewriteProgram(_prog, graph2, null); } public Set<String> analyzeSubProgram() { diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index 2c2b018..e23af65 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -1120,6 +1120,13 @@ public class HopRewriteUtils return ret; } + public static boolean isData(Hop hop, OpOpData... 
types) { + boolean ret = false; + for( OpOpData type : types ) + ret |= isData(hop, type); + return ret; + } + public static boolean isData(Hop hop, OpOpData type) { return hop instanceof DataOp && ((DataOp)hop).getOp()==type; } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java index 0f26fa1..cc57961 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java @@ -77,7 +77,7 @@ public class RewriteCompressedReblock extends StatementBlockRewriteRule { // parse compression config DMLConfig conf = ConfigurationManager.getDMLConfig(); CompressConfig compress = CompressConfig.valueOf(conf.getTextValue(DMLConfig.COMPRESSED_LINALG).toUpperCase()); - + // perform compressed reblock rewrite if(compress.isEnabled()) { Hop.resetVisitStatus(sb.getHops()); @@ -126,14 +126,14 @@ public class RewriteCompressedReblock extends StatementBlockRewriteRule { hop.setVisited(); } - private static boolean satisfiesSizeConstraintsForCompression(Hop hop) { + public static boolean satisfiesSizeConstraintsForCompression(Hop hop) { if(hop.getDim2() >= 1) { return (hop.getDim1() >= 1000 && hop.getDim2() < 100) || hop.getDim1() / hop.getDim2() >= 75; } return false; } - private static boolean satisfiesCompressionCondition(Hop hop) { + public static boolean satisfiesCompressionCondition(Hop hop) { boolean satisfies = false; if(satisfiesSizeConstraintsForCompression(hop)) satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WTreeNode.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WTreeNode.java new file mode 100644 index 0000000..aa64ad1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WTreeNode.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.workload; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysds.hops.Hop; + +/** + * A workload tree is a compact representation of the operations + * on a compressed matrix and derived intermediates, including + * the basic control structure and inlined functions as well + * as links to categories + * + * TODO separate classes for inner and leaf nodes? 
+ */ +public class WTreeNode +{ + public enum WTNodeType{ + MAIN, + FCALL, + IF, + WHILE, + FOR, + PARFOR, + BASIC_BLOCK; + public boolean isLoop() { + return this == WHILE || + this == FOR || this == PARFOR; + } + } + + private final WTNodeType _type; + private final List<WTreeNode> _childs = new ArrayList<>(); + private final List<Hop> _cops = new ArrayList<>(); + private int _beginLine = -1; + private int _endLine = -1; + + public WTreeNode(WTNodeType type) { + _type = type; + } + + public WTNodeType getType() { + return _type; + } + + public List<WTreeNode> getChildNodes() { + return _childs; + } + + public void addChild(WTreeNode node) { + _childs.add(node); + } + + public List<Hop> getCompressedOps() { + return _cops; + } + + public void addCompressedOp(Hop hop) { + _cops.add(hop); + } + + public void setLineNumbers(int begin, int end) { + _beginLine = begin; + _endLine = end; + } + + public String explain(int level) { + StringBuilder sb = new StringBuilder(); + //append indentation + for( int i=0; i<level; i++ ) + sb.append("--"); + //append node summary + sb.append(_type.name()); + if( _beginLine>=0 && _endLine>=0 ) { + sb.append(" (lines "); + sb.append(_beginLine); + sb.append("-"); + sb.append(_endLine); + sb.append(")"); + } + sb.append("\n"); + //append child nodes + if( !_childs.isEmpty() ) + for( WTreeNode n : _childs ) + sb.append(n.explain(level+1)); + else if( !_cops.isEmpty() ) { + for( Hop hop : _cops ) { + for( int i=0; i<level+1; i++ ) + sb.append("--"); + sb.append(hop.toString()); + sb.append("\n"); + } + } + return sb.toString(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Workload Tree:\n"); + sb.append("--------------------------------------------------------------------------------\n"); + sb.append(this.explain(1)); + sb.append("--------------------------------------------------------------------------------\n"); + return sb.toString(); + } +} diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java new file mode 100644 index 0000000..f9aad7c --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.compress.workload; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.sysds.common.Types.OpOpData; +import org.apache.sysds.hops.FunctionOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.hops.rewrite.RewriteCompressedReblock; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DataIdentifier; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; +import org.apache.sysds.parser.IfStatement; +import org.apache.sysds.parser.IfStatementBlock; +import org.apache.sysds.parser.ParForStatementBlock; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; +import org.apache.sysds.runtime.compress.workload.WTreeNode.WTNodeType; + +public class WorkloadAnalyzer { + + public static Map<Long, WTreeNode> getAllCandidateWorkloads(DMLProgram prog) { + // extract all compression candidates from program + List<Hop> candidates = getCandidates(prog); + + // for each candidate, create pruned workload tree + // TODO memoization of processed subtree if overlap + Map<Long, WTreeNode> map = new HashMap<>(); + for( Hop cand : candidates ) { + WTreeNode tree = createWorkloadTree(prog, cand); + pruneWorkloadTree(tree); + map.put(cand.getHopID(), tree); + } + + return map; + } + + public static List<Hop> getCandidates(DMLProgram prog) { + List<Hop> candidates = new ArrayList<>(); + for( StatementBlock sb : prog.getStatementBlocks() ) + getCandidates(sb, prog, candidates, new HashSet<>()); + return candidates; + } + + public static WTreeNode createWorkloadTree(DMLProgram 
prog, Hop candidate) { + WTreeNode main = new WTreeNode(WTNodeType.MAIN); + //TODO generalize, below line assumes only pread candidates (at bottom on DAGs) + Set<String> compressed = new HashSet<>(); + compressed.add(candidate.getName()); + for( StatementBlock sb : prog.getStatementBlocks() ) + main.addChild(createWorkloadTree(sb, prog, compressed, new HashSet<>())); + return main; + } + + public static boolean pruneWorkloadTree(WTreeNode node) { + //recursively process sub trees + Iterator<WTreeNode> iter = node.getChildNodes().iterator(); + while( iter.hasNext() ) { + if( pruneWorkloadTree(iter.next()) ) + iter.remove(); + } + + //indicate that node can be removed + return node.getChildNodes().isEmpty() + && node.getCompressedOps().isEmpty(); + } + + private static void getCandidates(StatementBlock sb, DMLProgram prog, List<Hop> cands, Set<String> fStack) { + if(sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for(StatementBlock csb : fstmt.getBody()) + getCandidates(csb, prog, cands, fStack); + } + else if(sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + for(StatementBlock csb : wstmt.getBody()) + getCandidates(csb, prog, cands, fStack); + } + else if(sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + for(StatementBlock csb : istmt.getIfBody()) + getCandidates(csb, prog, cands, fStack); + for(StatementBlock csb : istmt.getElseBody()) + getCandidates(csb, prog, cands, fStack); + } + else if(sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + for(StatementBlock csb : fstmt.getBody()) + getCandidates(csb, prog, cands, fStack); + } + else { //generic 
(last-level) + if( sb.getHops() == null ) + return; + Hop.resetVisitStatus(sb.getHops()); + for(Hop hop : sb.getHops()) + getCandidates(hop, prog, cands, fStack); + Hop.resetVisitStatus(sb.getHops()); + } + } + + private static void getCandidates(Hop hop, DMLProgram prog, List<Hop> cands, Set<String> fStack) { + if( hop.isVisited() ) + return; + + //evaluate and add candidates (type and size) + if( RewriteCompressedReblock.satisfiesCompressionCondition(hop) ) + cands.add(hop); + + //recursively process children (inputs) + for( Hop c : hop.getInput() ) + getCandidates(c, prog, cands, fStack); + + //process function calls with awareness of the current + //call stack to avoid endless loops in recursive functions + if( hop instanceof FunctionOp ) { + FunctionOp fop = (FunctionOp) hop; + if( !fStack.contains(fop.getFunctionKey()) ) { + fStack.add(fop.getFunctionKey()); + getCandidates(prog.getFunctionStatementBlock(fop.getFunctionKey()), prog, cands, fStack); + fStack.remove(fop.getFunctionKey()); + } + } + + hop.setVisited(); + } + + private static WTreeNode createWorkloadTree(StatementBlock sb, DMLProgram prog, Set<String> compressed, Set<String> fStack) { + WTreeNode node = null; + if(sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + node = new WTreeNode(WTNodeType.FCALL); + for(StatementBlock csb : fstmt.getBody()) + node.addChild(createWorkloadTree(csb, prog, compressed, fStack)); + } + else if(sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + node = new WTreeNode(WTNodeType.WHILE); + createWorkloadTree(wsb.getPredicateHops(), prog, node, compressed, fStack); + for(StatementBlock csb : wstmt.getBody()) + node.addChild(createWorkloadTree(csb, prog, compressed, fStack)); + } + else if(sb instanceof IfStatementBlock) { + IfStatementBlock isb = 
(IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + node = new WTreeNode(WTNodeType.IF); + createWorkloadTree(isb.getPredicateHops(), prog, node, compressed, fStack); + for(StatementBlock csb : istmt.getIfBody()) + node.addChild(createWorkloadTree(csb, prog, compressed, fStack)); + for(StatementBlock csb : istmt.getElseBody()) + node.addChild(createWorkloadTree(csb, prog, compressed, fStack)); + } + else if(sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + node = new WTreeNode(sb instanceof ParForStatementBlock ? WTNodeType.PARFOR:WTNodeType.FOR); + createWorkloadTree(fsb.getFromHops(), prog, node, compressed, fStack); + createWorkloadTree(fsb.getToHops(), prog, node, compressed, fStack); + createWorkloadTree(fsb.getIncrementHops(), prog, node, compressed, fStack); + for(StatementBlock csb : fstmt.getBody()) + node.addChild(createWorkloadTree(csb, prog, compressed, fStack)); + } + else { //generic (last-level) + node = new WTreeNode(WTNodeType.BASIC_BLOCK); + if( sb.getHops() != null ) { + Hop.resetVisitStatus(sb.getHops()); + //process hop DAG to collect operations + Set<Long> compressed2 = new HashSet<>(); + for(Hop hop : sb.getHops()) + createWorkloadTree(hop, prog, node, compressed, compressed2, fStack); + //maintain hop DAG outputs (compressed or not compressed) + for(Hop hop : sb.getHops()) { + if( hop instanceof FunctionOp ) { + FunctionOp fop = (FunctionOp) hop; + if( !fStack.contains(fop.getFunctionKey()) ) { + fStack.add(fop.getFunctionKey()); + FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionKey()); + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + Set<String> fCompressed = new HashSet<>(); + //handle propagation of compressed intermediates into functions + List<DataIdentifier> fArgs = fstmt.getInputParams(); + for( int i=0; i<fArgs.size(); i++ ) + if( 
compressed2.contains(fop.getInput(i).getHopID()) ) + fCompressed.add(fArgs.get(i).getName()); + node.addChild(createWorkloadTree(fsb, prog, fCompressed, fStack)); + fStack.remove(fop.getFunctionKey()); + } + } + else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE) ) { + //handle propagation of compressed intermediates across blocks + if( compressed.contains(hop.getName()) && !compressed2.contains(hop.getHopID()) ) + compressed.remove(hop.getName()); + if( !compressed.contains(hop.getName()) && compressed2.contains(hop.getHopID()) ) + compressed.add(hop.getName()); + } + } + Hop.resetVisitStatus(sb.getHops()); + } + } + node.setLineNumbers(sb.getBeginLine(), sb.getEndLine()); + return node; + } + + private static void createWorkloadTree(Hop hop, DMLProgram prog, WTreeNode parent, Set<String> compressed, Set<String> fStack) { + if( hop == null ) + return; + hop.resetVisitStatus(); + createWorkloadTree(hop, prog, parent, compressed, new HashSet<>(), fStack); //see below + hop.resetVisitStatus(); + } + + private static void createWorkloadTree(Hop hop, DMLProgram prog, WTreeNode parent, Set<String> compressed, Set<Long> compressed2, Set<String> fStack) { + if( hop == null || hop.isVisited() ) + return; + + //recursively process children (inputs first for propagation of compression status) + for( Hop c : hop.getInput() ) + createWorkloadTree(c, prog, parent, compressed, compressed2, fStack); + + //map statement block propagation to hop propagation + if( HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD) + && compressed.contains(hop.getName()) ) { + compressed2.add(hop.getHopID()); + } + + //collect operations on compressed intermediates or inputs + //if any input is compressed we collect this hop as a compressed operation + if( hop.getInput().stream().anyMatch(h -> compressed2.contains(h.getHopID())) ) { + if(!HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, //all, but data ops + OpOpData.TRANSIENTREAD, OpOpData.TRANSIENTWRITE) ) + 
{ + parent.addCompressedOp(hop); + } + + //if the output size also qualifies for compression, we propagate this status + if( RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop) + && hop.getDataType().isMatrix() ) + { + compressed2.add(hop.getHopID()); + } + } + + hop.setVisited(); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/compress/WorkloadAnalysisTest.java b/src/test/java/org/apache/sysds/test/functions/compress/WorkloadAnalysisTest.java new file mode 100644 index 0000000..5585617 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/compress/WorkloadAnalysisTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.compress; + +import java.io.File; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.hops.ipa.InterProceduralAnalysis; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class WorkloadAnalysisTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "WorkloadAnalysisMlogreg"; + private final static String TEST_NAME2 = "WorkloadAnalysisLm"; + private final static String TEST_DIR = "functions/compress/"; + private final static String TEST_CLASS_DIR = TEST_DIR + WorkloadAnalysisTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"B"})); + } + + @Test + public void testMlogregCP() { + runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID); + } + + @Test + public void testLmCP() { + runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID); + } + + private void runWorkloadAnalysisTest(String testname, ExecMode mode) + { + ExecMode oldPlatform = setExecMode(mode); + boolean oldFlag = InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS; + + try + { + loadTestConfiguration(getTestConfiguration(testname)); + + InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS = true; + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-args", input("X"), input("y"), output("B") }; + + double[][] X = getRandomMatrix(10000, 20, 0, 1, 1.0, 7); + writeInputMatrixWithMTD("X", X, false); + double[][] y = TestUtils.round(getRandomMatrix(10000, 1, 1, 2, 1.0, 3)); + writeInputMatrixWithMTD("y", y, false); + + runTest(true, false, null, -1); + //TODO check for compressed operations + 
//(right now test only checks that the workload analysis does not crash) + } + finally { + resetExecMode(oldPlatform); + InterProceduralAnalysis.CLA_WORKLOAD_ANALYSIS = oldFlag; + } + } + + @Override + protected File getConfigTemplateFile() { + return new File(SCRIPT_DIR + TEST_DIR + "force", "SystemDS-config-compress.xml"); + } +} diff --git a/src/test/scripts/functions/compress/WorkloadAnalysisLm.dml b/src/test/scripts/functions/compress/WorkloadAnalysisLm.dml new file mode 100644 index 0000000..ccaec92 --- /dev/null +++ b/src/test/scripts/functions/compress/WorkloadAnalysisLm.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = read($1); +y = read($2); + +#X = scale(X=X, scale=TRUE, center=TRUE); +X = X - colMeans(X); +X = X / sqrt(colVars(X)); + +while(FALSE){} +B = lm(X=X, y=y, verbose=TRUE); + +write(B, $3) diff --git a/src/test/scripts/functions/compress/WorkloadAnalysisMlogreg.dml b/src/test/scripts/functions/compress/WorkloadAnalysisMlogreg.dml new file mode 100644 index 0000000..7f0a76a --- /dev/null +++ b/src/test/scripts/functions/compress/WorkloadAnalysisMlogreg.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +y = read($2); + +#X = scale(X=X, scale=TRUE, center=TRUE); +X = X - colMeans(X); +X = X / sqrt(colVars(X)); + +while(FALSE){} +B = multiLogReg(X=X, Y=y, verbose=TRUE); + +write(B, $3)