This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new e5a96d8731 [MINOR] BWARE Update Workload Analyzer e5a96d8731 is described below commit e5a96d87310dcdc9d8b7989f7aeee36adf0d7989 Author: Sebastian Baunsgaard <baunsga...@apache.org> AuthorDate: Wed Aug 28 14:12:36 2024 +0200 [MINOR] BWARE Update Workload Analyzer This commit adds extensions to the workload analysis to include more operations in the workload vector. Closes #2075 --- .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 6 + .../compress/workload/WorkloadAnalyzer.java | 198 +++++++++++++++------ .../instructions/cp/CompressionCPInstruction.java | 7 +- .../component/compress/workload/WorkloadTest.java | 6 +- 4 files changed, 162 insertions(+), 55 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index 144b331327..2b84318e35 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -1207,6 +1207,11 @@ public class HopRewriteUtils { public static boolean isParameterizedBuiltinOp(Hop hop, ParamBuiltinOp type) { return hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp().equals(type); } + + public static boolean isParameterizedBuiltinOp(Hop hop, ParamBuiltinOp... types) { + return hop instanceof ParameterizedBuiltinOp && + ArrayUtils.contains(types, ((ParameterizedBuiltinOp) hop).getOp()); + } public static boolean isRemoveEmpty(Hop hop, boolean rows) { return isParameterizedBuiltinOp(hop, ParamBuiltinOp.RMEMPTY) @@ -1380,6 +1385,7 @@ public class HopRewriteUtils { return ret; } + public static Hop getBasic1NSequenceMax(Hop hop) { if( isDataGenOp(hop, OpOpDG.SEQ) ) { DataGenOp dgop = (DataGenOp) hop; diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java index bec6ac18aa..2092440ed1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java @@ -27,6 +27,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.Stack; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -35,6 +36,7 @@ import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.common.Types.OpOp3; import org.apache.sysds.common.Types.OpOpData; +import org.apache.sysds.common.Types.ParamBuiltinOp; import org.apache.sysds.common.Types.ReOrgOp; import org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.AggUnaryOp; @@ -78,6 +80,7 @@ public class WorkloadAnalyzer { private final Set<Long> overlapping; private final DMLProgram prog; private final Map<Long, Op> treeLookup; + private final Stack<Hop> stack; public static Map<Long, WTreeRoot> getAllCandidateWorkloads(DMLProgram prog) { // extract all compression candidates from program (in program order) @@ -94,7 +97,7 @@ public class WorkloadAnalyzer { // construct workload tree for candidate WorkloadAnalyzer wa = new WorkloadAnalyzer(prog); - WTreeRoot tree = wa.createWorkloadTree(cand); + WTreeRoot tree = wa.createWorkloadTreeRoot(cand); map.put(cand.getHopID(), tree); allWAs.add(wa); @@ -111,6 +114,7 @@ public class WorkloadAnalyzer { this.transientCompressed = new HashMap<>(); this.overlapping = new HashSet<>(); this.treeLookup = new HashMap<>(); + this.stack = new Stack<>(); } private WorkloadAnalyzer(DMLProgram prog, Set<Long> compressed, HashMap<String, Long> transientCompressed, @@ -122,13 +126,20 @@ public class WorkloadAnalyzer { this.transientCompressed = transientCompressed; this.overlapping = overlapping; this.treeLookup = treeLookup; + this.stack = new Stack<>(); } - private WTreeRoot createWorkloadTree(Hop candidate) { + private WTreeRoot createWorkloadTreeRoot(Hop candidate) { WTreeRoot main = new WTreeRoot(candidate); compressed.add(candidate.getHopID()); + if(HopRewriteUtils.isTransformEncode(candidate)) { + Hop matrix = ((FunctionOp) candidate).getOutputs().get(0); + compressed.add(matrix.getHopID()); + transientCompressed.put(matrix.getName(), matrix.getHopID()); + } for(StatementBlock sb : prog.getStatementBlocks()) - createWorkloadTree(main, sb, prog, new HashSet<>()); + createWorkloadTreeNodes(main, sb, prog, new HashSet<>()); + pruneWorkloadTree(main); return main; } @@ -222,23 +233,23 @@ public class WorkloadAnalyzer { hop.setVisited(); } - private void createWorkloadTree(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set<String> fStack) { + private void createWorkloadTreeNodes(AWTreeNode n, StatementBlock sb, DMLProgram prog, Set<String> fStack) { WTreeNode node; if(sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) sb; FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); node = new WTreeNode(WTNodeType.FCALL, 1); for(StatementBlock csb : fstmt.getBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else if(sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); - node = new WTreeNode(WTNodeType.WHILE, 10); + node = new WTreeNode(WTNodeType.WHILE, 100); createWorkloadTree(wsb.getPredicateHops(), prog, node, fStack); for(StatementBlock csb : wstmt.getBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else if(sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; @@ -247,9 +258,9 @@ public class WorkloadAnalyzer { createWorkloadTree(isb.getPredicateHops(), prog, node, fStack); for(StatementBlock csb : istmt.getIfBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); for(StatementBlock csb : istmt.getElseBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else if(sb instanceof ForStatementBlock) { // incl parfor ForStatementBlock fsb = (ForStatementBlock) sb; @@ -260,7 +271,7 @@ public class WorkloadAnalyzer { createWorkloadTree(fsb.getToHops(), prog, node, fStack); createWorkloadTree(fsb.getIncrementHops(), prog, node, fStack); for(StatementBlock csb : fstmt.getBody()) - createWorkloadTree(node, csb, prog, fStack); + createWorkloadTreeNodes(node, csb, prog, fStack); } else { // generic (last-level) @@ -269,14 +280,19 @@ public class WorkloadAnalyzer { if(hops != null) { // process hop DAG to collect operations that are compressed. - for(Hop hop : hops) + for(Hop hop : hops) { createWorkloadTree(hop, prog, n, fStack); + // createStack(hop); + // processStack(prog, n, fStack); + } // maintain hop DAG outputs (compressed or not compressed) for(Hop hop : hops) { if(hop instanceof FunctionOp) { FunctionOp fop = (FunctionOp) hop; - if(!fStack.contains(fop.getFunctionKey())) { + if(HopRewriteUtils.isTransformEncode(fop)) + return; + else if(!fStack.contains(fop.getFunctionKey())) { fStack.add(fop.getFunctionKey()); FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionKey()); if(fsb == null) @@ -295,7 +311,7 @@ public class WorkloadAnalyzer { WorkloadAnalyzer fa = new WorkloadAnalyzer(prog, compressed, fCompressed, transposed, overlapping, treeLookup); - fa.createWorkloadTree(n, fsb, prog, fStack); + fa.createWorkloadTreeNodes(n, fsb, prog, fStack); String[] outs = fop.getOutputVariableNames(); for(int i = 0; i < outs.length; i++) { Long id = fCompressed.get(outs[i]); @@ -305,7 +321,6 @@ public class WorkloadAnalyzer { fStack.remove(fop.getFunctionKey()); } } - } } return; @@ -313,27 +328,42 @@ public class WorkloadAnalyzer { n.addChild(node); } - private void createWorkloadTree(Hop hop, DMLProgram prog, AWTreeNode parent, Set<String> fStack) { + private void createStack(Hop hop) { if(hop == null || visited.contains(hop) || isNoOp(hop)) return; - - // DFS: recursively process children (inputs first for propagation of compression status) + stack.add(hop); for(Hop c : hop.getInput()) - createWorkloadTree(c, prog, parent, fStack); + createStack(c); - // map statement block propagation to hop propagation - if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD) && - transientCompressed.containsKey(hop.getName())) { - compressed.add(hop.getHopID()); - treeLookup.put(hop.getHopID(), treeLookup.get(transientCompressed.get(hop.getName()))); - } + visited.add(hop); + } + + private void createWorkloadTree(Hop hop, DMLProgram prog, AWTreeNode parent, Set<String> fStack) { + createStack(hop); + processStack(prog, parent, fStack); + } - // collect operations on compressed intermediates or inputs - // if any input is compressed we collect this hop as a compressed operation - if(hop.getInput().stream().anyMatch(h -> compressed.contains(h.getHopID()))) - createOp(hop, parent); + private void processStack(DMLProgram prog, AWTreeNode parent, Set<String> fStack) { + + while(!stack.isEmpty()) { + Hop hop = stack.pop(); + + // map statement block propagation to hop propagation + if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD) && + transientCompressed.containsKey(hop.getName())) { + compressed.add(hop.getHopID()); + treeLookup.put(hop.getHopID(), treeLookup.get(transientCompressed.get(hop.getName()))); + } + else { + + // collect operations on compressed intermediates or inputs + // if any input is compressed we collect this hop as a compressed operation + if(hop.getInput().stream().anyMatch(h -> compressed.contains(h.getHopID()))) + createOp(hop, parent); + + } + } - visited.add(hop); } private void createOp(Hop hop, AWTreeNode parent) { @@ -369,11 +399,16 @@ public class WorkloadAnalyzer { o = new OpNormal(hop, false); } } - else if(hop instanceof UnaryOp && - !HopRewriteUtils.isUnary(hop, OpOp1.MULT2, OpOp1.MINUS1_MULT, OpOp1.MINUS_RIGHT, OpOp1.CAST_AS_MATRIX)) { - if(isOverlapping(hop.getInput(0))) { - treeLookup.get(hop.getInput(0).getHopID()).setDecompressing(); - return; + else if(hop instanceof UnaryOp) { + if(!HopRewriteUtils.isUnary(hop, OpOp1.MULT2, OpOp1.MINUS1_MULT, OpOp1.MINUS_RIGHT, OpOp1.CAST_AS_MATRIX)) { + if(isOverlapping(hop.getInput(0))) { + treeLookup.get(hop.getInput(0).getHopID()).setDecompressing(); + return; + } + + } + else if(HopRewriteUtils.isUnary(hop, OpOp1.DETECTSCHEMA)) { + o = new OpNormal(hop, false); } } else if(hop instanceof AggBinaryOp) { @@ -411,6 +446,9 @@ public class WorkloadAnalyzer { setDecompressionOnAllInputs(hop, parent); return; } + else if(HopRewriteUtils.isBinary(hop, OpOp2.APPLY_SCHEMA)) { + o = new OpNormal(hop, true); + } else { List<Hop> in = hop.getInput(); final boolean ol0 = isOverlapping(in.get(0)); @@ -461,22 +499,11 @@ public class WorkloadAnalyzer { } else if(hop instanceof IndexingOp) { - IndexingOp idx = (IndexingOp) hop; final boolean isOverlapping = isOverlapping(hop.getInput(0)); - final boolean fullColumn = HopRewriteUtils.isFullColumnIndexing(idx); - - if(fullColumn) { - o = new OpNormal(hop, RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop)); - if(isOverlapping) { - overlapping.add(hop.getHopID()); - o.setOverlapping(); - } - } - else { - // This decompression is a little different, since it does not decompress the entire matrix - // but only a sub part. therefore create a new op node and set it to decompressing. - o = new OpNormal(hop, false); - o.setDecompressing(); + o = new OpNormal(hop, true); + if(isOverlapping) { + overlapping.add(hop.getHopID()); + o.setOverlapping(); } } else if(HopRewriteUtils.isTernary(hop, OpOp3.MINUS_MULT, OpOp3.PLUS_MULT, OpOp3.QUANTILE, OpOp3.CTABLE)) { @@ -505,12 +532,26 @@ public class WorkloadAnalyzer { setDecompressionOnAllInputs(hop, parent); } } - else if(hop instanceof ParameterizedBuiltinOp || hop instanceof NaryOp) { + else if(hop instanceof ParameterizedBuiltinOp) { + if(HopRewriteUtils.isParameterizedBuiltinOp(hop, ParamBuiltinOp.REPLACE, ParamBuiltinOp.TRANSFORMAPPLY)) { + o = new OpNormal(hop, true); + } + else { + LOG.warn("Unknown ParameterizedBuiltinOp Hop:" + hop.getClass().getSimpleName() + "\n" + Explain.explain(hop)); + setDecompressionOnAllInputs(hop, parent); + return; + } + } + else if(hop instanceof NaryOp) { + setDecompressionOnAllInputs(hop, parent); + return; + } + else if(hop instanceof ReorgOp){ setDecompressionOnAllInputs(hop, parent); return; } else { - LOG.warn("Unknown Hop:" + hop.getClass().getSimpleName() + "\n" + Explain.explain(hop)); + LOG.warn("Unknown Matrix Hop:" + hop.getClass().getSimpleName() + "\n" + Explain.explain(hop)); setDecompressionOnAllInputs(hop, parent); return; } @@ -522,7 +563,62 @@ public class WorkloadAnalyzer { if(o.isCompressedOutput()) compressed.add(hop.getHopID()); } + else if(hop.getDataType().isFrame()) { + Op o = null; + if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD)) + return; + else if(HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE, OpOpData.PERSISTENTWRITE)) { + transientCompressed.put(hop.getName(), hop.getInput(0).getHopID()); + compressed.add(hop.getHopID()); + o = new OpMetadata(hop, hop.getInput(0)); + if(isOverlapping(hop.getInput(0))) + o.setOverlapping(); + } + else if(HopRewriteUtils.isUnary(hop, OpOp1.DETECTSCHEMA)) { + o = new OpNormal(hop, false); + } + else if(HopRewriteUtils.isBinary(hop, OpOp2.APPLY_SCHEMA)) { + o = new OpNormal(hop, true); + } + else if(hop instanceof AggUnaryOp) { + o = new OpNormal(hop, false); + } + else { + LOG.warn("Unknown Frame Hop:" + hop.getClass().getSimpleName() + "\n" + Explain.explain(hop)); + setDecompressionOnAllInputs(hop, parent); + return; + } + + o = o != null ? o : new OpNormal(hop, RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop)); + treeLookup.put(hop.getHopID(), o); + parent.addOp(o); + if(o.isCompressedOutput()) + compressed.add(hop.getHopID()); + } + else if(HopRewriteUtils.isTransformEncode(hop)) { + Hop matrix = ((FunctionOp) hop).getOutputs().get(0); + compressed.add(matrix.getHopID()); + transientCompressed.put(matrix.getName(), matrix.getHopID()); + parent.addOp(new OpNormal(hop, true)); + } + else if(hop instanceof FunctionOp && ((FunctionOp) hop).getFunctionNamespace().equals(".builtinNS")) { + parent.addOp(new OpNormal(hop, false)); + } + else if(hop instanceof AggUnaryOp) { + if((isOverlapping(hop.getInput().get(0)) && !HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, AggOp.MEAN)) || + HopRewriteUtils.isAggUnaryOp(hop, AggOp.TRACE)) { + setDecompressionOnAllInputs(hop, parent); + return; + } + else { + Op o = new OpNormal(hop, false); + treeLookup.put(hop.getHopID(), o); + parent.addOp(o); + } + } else { + LOG.warn( + "Unknown Matrix or Frame Hop:" + hop.getClass().getSimpleName() + "\n" + hop.getDataType() + "\n" + Explain.explain(hop)); parent.addOp(new OpNormal(hop, false)); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index 2e87d9eb65..7d0d9f7870 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -128,7 +128,10 @@ public class CompressionCPInstruction extends ComputationCPInstruction { // Get and clear workload tree entry for this compression instruction. final WTreeRoot root = (_singletonLookupID != 0) ? (WTreeRoot) m.get(_singletonLookupID) : null; - m.removeKey(_singletonLookupID); + // We used to remove the key from the hash map, + // however this is not correct since the compression statement + // can be reused in multiple for loops. + if(ec.isFrameObject(input1.getName())) processFrameBlockCompression(ec, ec.getFrameInput(input1.getName()), _numThreads, root); @@ -144,6 +147,8 @@ public class CompressionCPInstruction extends ComputationCPInstruction { if(LOG.isTraceEnabled()) LOG.trace(compResult.getRight()); MatrixBlock out = compResult.getLeft(); + if(LOG.isInfoEnabled()) + LOG.info("Compression output class: " + out.getClass().getSimpleName()); // Set output and release input ec.releaseMatrixInput(input1.getName()); ec.setMatrixOutput(output.getName(), out); diff --git a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java index cf084e9ccc..5184556817 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java @@ -86,7 +86,7 @@ public class WorkloadTest { // Simple tests no loops verifying basic behavior tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 0, false, false, "sum.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 0, false, false, "mean.dml", args}); - tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 1, false, false, "plus.dml", args}); + tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 1, false, false, "plus.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, false, "sliceCols.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, false, "sliceIndex.dml", args}); // tests.add(new Object[] {0, 0, 0, 1, 0, 0, 0, 0, false, false, "leftMult.dml", args}); @@ -105,9 +105,9 @@ public class WorkloadTest { // Builtins: // nr 11: tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false, "functions/scale.dml", args}); - tests.add(new Object[] {0, 0, 0, 0, 0, 0, 5, 0, true, true, "functions/scale.dml", args}); + tests.add(new Object[] {0, 0, 0, 0, 0, 0, 4, 0, false, true, "functions/scale.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false, "functions/scale_continued.dml", args}); - tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, true, "functions/scale_continued.dml", args}); + tests.add(new Object[] {0, 0, 0, 0, 0, 0, 5, 0, true, true, "functions/scale_continued.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, true, "functions/scale_onlySide.dml", args}); tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false, "functions/scale_onlySide.dml", args});