This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e5a96d8731 [MINOR] BWARE Update Workload Analyzer
e5a96d8731 is described below
commit e5a96d87310dcdc9d8b7989f7aeee36adf0d7989
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Wed Aug 28 14:12:36 2024 +0200
[MINOR] BWARE Update Workload Analyzer
This commit adds extensions to the workload analysis
to include more operations in the workload vector.
Closes #2075
---
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 6 +
.../compress/workload/WorkloadAnalyzer.java | 198 +++++++++++++++------
.../instructions/cp/CompressionCPInstruction.java | 7 +-
.../component/compress/workload/WorkloadTest.java | 6 +-
4 files changed, 162 insertions(+), 55 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 144b331327..2b84318e35 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1207,6 +1207,11 @@ public class HopRewriteUtils {
public static boolean isParameterizedBuiltinOp(Hop hop, ParamBuiltinOp
type) {
return hop instanceof ParameterizedBuiltinOp &&
((ParameterizedBuiltinOp) hop).getOp().equals(type);
}
+
+ public static boolean isParameterizedBuiltinOp(Hop hop,
ParamBuiltinOp... types) {
+ return hop instanceof ParameterizedBuiltinOp &&
+ ArrayUtils.contains(types, ((ParameterizedBuiltinOp)
hop).getOp());
+ }
public static boolean isRemoveEmpty(Hop hop, boolean rows) {
return isParameterizedBuiltinOp(hop, ParamBuiltinOp.RMEMPTY)
@@ -1380,6 +1385,7 @@ public class HopRewriteUtils {
return ret;
}
+
public static Hop getBasic1NSequenceMax(Hop hop) {
if( isDataGenOp(hop, OpOpDG.SEQ) ) {
DataGenOp dgop = (DataGenOp) hop;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index bec6ac18aa..2092440ed1 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -27,6 +27,7 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -35,6 +36,7 @@ import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
@@ -78,6 +80,7 @@ public class WorkloadAnalyzer {
private final Set<Long> overlapping;
private final DMLProgram prog;
private final Map<Long, Op> treeLookup;
+ private final Stack<Hop> stack;
public static Map<Long, WTreeRoot> getAllCandidateWorkloads(DMLProgram
prog) {
// extract all compression candidates from program (in program
order)
@@ -94,7 +97,7 @@ public class WorkloadAnalyzer {
// construct workload tree for candidate
WorkloadAnalyzer wa = new WorkloadAnalyzer(prog);
- WTreeRoot tree = wa.createWorkloadTree(cand);
+ WTreeRoot tree = wa.createWorkloadTreeRoot(cand);
map.put(cand.getHopID(), tree);
allWAs.add(wa);
@@ -111,6 +114,7 @@ public class WorkloadAnalyzer {
this.transientCompressed = new HashMap<>();
this.overlapping = new HashSet<>();
this.treeLookup = new HashMap<>();
+ this.stack = new Stack<>();
}
private WorkloadAnalyzer(DMLProgram prog, Set<Long> compressed,
HashMap<String, Long> transientCompressed,
@@ -122,13 +126,20 @@ public class WorkloadAnalyzer {
this.transientCompressed = transientCompressed;
this.overlapping = overlapping;
this.treeLookup = treeLookup;
+ this.stack = new Stack<>();
}
- private WTreeRoot createWorkloadTree(Hop candidate) {
+ private WTreeRoot createWorkloadTreeRoot(Hop candidate) {
WTreeRoot main = new WTreeRoot(candidate);
compressed.add(candidate.getHopID());
+ if(HopRewriteUtils.isTransformEncode(candidate)) {
+ Hop matrix = ((FunctionOp)
candidate).getOutputs().get(0);
+ compressed.add(matrix.getHopID());
+ transientCompressed.put(matrix.getName(),
matrix.getHopID());
+ }
for(StatementBlock sb : prog.getStatementBlocks())
- createWorkloadTree(main, sb, prog, new HashSet<>());
+ createWorkloadTreeNodes(main, sb, prog, new
HashSet<>());
+
pruneWorkloadTree(main);
return main;
}
@@ -222,23 +233,23 @@ public class WorkloadAnalyzer {
hop.setVisited();
}
- private void createWorkloadTree(AWTreeNode n, StatementBlock sb,
DMLProgram prog, Set<String> fStack) {
+ private void createWorkloadTreeNodes(AWTreeNode n, StatementBlock sb,
DMLProgram prog, Set<String> fStack) {
WTreeNode node;
if(sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock)
sb;
FunctionStatement fstmt = (FunctionStatement)
fsb.getStatement(0);
node = new WTreeNode(WTNodeType.FCALL, 1);
for(StatementBlock csb : fstmt.getBody())
- createWorkloadTree(node, csb, prog, fStack);
+ createWorkloadTreeNodes(node, csb, prog,
fStack);
}
else if(sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement)
wsb.getStatement(0);
- node = new WTreeNode(WTNodeType.WHILE, 10);
+ node = new WTreeNode(WTNodeType.WHILE, 100);
createWorkloadTree(wsb.getPredicateHops(), prog, node,
fStack);
for(StatementBlock csb : wstmt.getBody())
- createWorkloadTree(node, csb, prog, fStack);
+ createWorkloadTreeNodes(node, csb, prog,
fStack);
}
else if(sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
@@ -247,9 +258,9 @@ public class WorkloadAnalyzer {
createWorkloadTree(isb.getPredicateHops(), prog, node,
fStack);
for(StatementBlock csb : istmt.getIfBody())
- createWorkloadTree(node, csb, prog, fStack);
+ createWorkloadTreeNodes(node, csb, prog,
fStack);
for(StatementBlock csb : istmt.getElseBody())
- createWorkloadTree(node, csb, prog, fStack);
+ createWorkloadTreeNodes(node, csb, prog,
fStack);
}
else if(sb instanceof ForStatementBlock) { // incl parfor
ForStatementBlock fsb = (ForStatementBlock) sb;
@@ -260,7 +271,7 @@ public class WorkloadAnalyzer {
createWorkloadTree(fsb.getToHops(), prog, node, fStack);
createWorkloadTree(fsb.getIncrementHops(), prog, node,
fStack);
for(StatementBlock csb : fstmt.getBody())
- createWorkloadTree(node, csb, prog, fStack);
+ createWorkloadTreeNodes(node, csb, prog,
fStack);
}
else { // generic (last-level)
@@ -269,14 +280,19 @@ public class WorkloadAnalyzer {
if(hops != null) {
// process hop DAG to collect operations that
are compressed.
- for(Hop hop : hops)
+ for(Hop hop : hops) {
createWorkloadTree(hop, prog, n,
fStack);
+ // createStack(hop);
+ // processStack(prog, n, fStack);
+ }
// maintain hop DAG outputs (compressed or not
compressed)
for(Hop hop : hops) {
if(hop instanceof FunctionOp) {
FunctionOp fop = (FunctionOp)
hop;
-
if(!fStack.contains(fop.getFunctionKey())) {
+
if(HopRewriteUtils.isTransformEncode(fop))
+ return;
+ else
if(!fStack.contains(fop.getFunctionKey())) {
fStack.add(fop.getFunctionKey());
FunctionStatementBlock
fsb = prog.getFunctionStatementBlock(fop.getFunctionKey());
if(fsb == null)
@@ -295,7 +311,7 @@ public class WorkloadAnalyzer {
WorkloadAnalyzer fa =
new WorkloadAnalyzer(prog, compressed, fCompressed, transposed, overlapping,
treeLookup);
-
fa.createWorkloadTree(n, fsb, prog, fStack);
+
fa.createWorkloadTreeNodes(n, fsb, prog, fStack);
String[] outs =
fop.getOutputVariableNames();
for(int i = 0; i <
outs.length; i++) {
Long id =
fCompressed.get(outs[i]);
@@ -305,7 +321,6 @@ public class WorkloadAnalyzer {
fStack.remove(fop.getFunctionKey());
}
}
-
}
}
return;
@@ -313,27 +328,42 @@ public class WorkloadAnalyzer {
n.addChild(node);
}
- private void createWorkloadTree(Hop hop, DMLProgram prog, AWTreeNode
parent, Set<String> fStack) {
+ private void createStack(Hop hop) {
if(hop == null || visited.contains(hop) || isNoOp(hop))
return;
-
- // DFS: recursively process children (inputs first for
propagation of compression status)
+ stack.add(hop);
for(Hop c : hop.getInput())
- createWorkloadTree(c, prog, parent, fStack);
+ createStack(c);
- // map statement block propagation to hop propagation
- if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD,
OpOpData.TRANSIENTREAD) &&
- transientCompressed.containsKey(hop.getName())) {
- compressed.add(hop.getHopID());
- treeLookup.put(hop.getHopID(),
treeLookup.get(transientCompressed.get(hop.getName())));
- }
+ visited.add(hop);
+ }
+
+ private void createWorkloadTree(Hop hop, DMLProgram prog, AWTreeNode
parent, Set<String> fStack) {
+ createStack(hop);
+ processStack(prog, parent, fStack);
+ }
- // collect operations on compressed intermediates or inputs
- // if any input is compressed we collect this hop as a
compressed operation
- if(hop.getInput().stream().anyMatch(h ->
compressed.contains(h.getHopID())))
- createOp(hop, parent);
+ private void processStack(DMLProgram prog, AWTreeNode parent,
Set<String> fStack) {
+
+ while(!stack.isEmpty()) {
+ Hop hop = stack.pop();
+
+ // map statement block propagation to hop propagation
+ if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD,
OpOpData.TRANSIENTREAD) &&
+ transientCompressed.containsKey(hop.getName()))
{
+ compressed.add(hop.getHopID());
+ treeLookup.put(hop.getHopID(),
treeLookup.get(transientCompressed.get(hop.getName())));
+ }
+ else {
+
+ // collect operations on compressed
intermediates or inputs
+ // if any input is compressed we collect this
hop as a compressed operation
+ if(hop.getInput().stream().anyMatch(h ->
compressed.contains(h.getHopID())))
+ createOp(hop, parent);
+
+ }
+ }
- visited.add(hop);
}
private void createOp(Hop hop, AWTreeNode parent) {
@@ -369,11 +399,16 @@ public class WorkloadAnalyzer {
o = new OpNormal(hop, false);
}
}
- else if(hop instanceof UnaryOp &&
- !HopRewriteUtils.isUnary(hop, OpOp1.MULT2,
OpOp1.MINUS1_MULT, OpOp1.MINUS_RIGHT, OpOp1.CAST_AS_MATRIX)) {
- if(isOverlapping(hop.getInput(0))) {
-
treeLookup.get(hop.getInput(0).getHopID()).setDecompressing();
- return;
+ else if(hop instanceof UnaryOp) {
+ if(!HopRewriteUtils.isUnary(hop, OpOp1.MULT2,
OpOp1.MINUS1_MULT, OpOp1.MINUS_RIGHT, OpOp1.CAST_AS_MATRIX)) {
+ if(isOverlapping(hop.getInput(0))) {
+
treeLookup.get(hop.getInput(0).getHopID()).setDecompressing();
+ return;
+ }
+
+ }
+ else if(HopRewriteUtils.isUnary(hop,
OpOp1.DETECTSCHEMA)) {
+ o = new OpNormal(hop, false);
}
}
else if(hop instanceof AggBinaryOp) {
@@ -411,6 +446,9 @@ public class WorkloadAnalyzer {
setDecompressionOnAllInputs(hop,
parent);
return;
}
+ else if(HopRewriteUtils.isBinary(hop,
OpOp2.APPLY_SCHEMA)) {
+ o = new OpNormal(hop, true);
+ }
else {
List<Hop> in = hop.getInput();
final boolean ol0 =
isOverlapping(in.get(0));
@@ -461,22 +499,11 @@ public class WorkloadAnalyzer {
}
else if(hop instanceof IndexingOp) {
- IndexingOp idx = (IndexingOp) hop;
final boolean isOverlapping =
isOverlapping(hop.getInput(0));
- final boolean fullColumn =
HopRewriteUtils.isFullColumnIndexing(idx);
-
- if(fullColumn) {
- o = new OpNormal(hop,
RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop));
- if(isOverlapping) {
- overlapping.add(hop.getHopID());
- o.setOverlapping();
- }
- }
- else {
- // This decompression is a little
different, since it does not decompress the entire matrix
- // but only a sub part. therefore
create a new op node and set it to decompressing.
- o = new OpNormal(hop, false);
- o.setDecompressing();
+ o = new OpNormal(hop, true);
+ if(isOverlapping) {
+ overlapping.add(hop.getHopID());
+ o.setOverlapping();
}
}
else if(HopRewriteUtils.isTernary(hop,
OpOp3.MINUS_MULT, OpOp3.PLUS_MULT, OpOp3.QUANTILE, OpOp3.CTABLE)) {
@@ -505,12 +532,26 @@ public class WorkloadAnalyzer {
setDecompressionOnAllInputs(hop,
parent);
}
}
- else if(hop instanceof ParameterizedBuiltinOp || hop
instanceof NaryOp) {
+ else if(hop instanceof ParameterizedBuiltinOp) {
+
if(HopRewriteUtils.isParameterizedBuiltinOp(hop, ParamBuiltinOp.REPLACE,
ParamBuiltinOp.TRANSFORMAPPLY)) {
+ o = new OpNormal(hop, true);
+ }
+ else {
+ LOG.warn("Unknown
ParameterizedBuiltinOp Hop:" + hop.getClass().getSimpleName() + "\n" +
Explain.explain(hop));
+ setDecompressionOnAllInputs(hop,
parent);
+ return;
+ }
+ }
+ else if(hop instanceof NaryOp) {
+ setDecompressionOnAllInputs(hop, parent);
+ return;
+ }
+ else if(hop instanceof ReorgOp){
setDecompressionOnAllInputs(hop, parent);
return;
}
else {
- LOG.warn("Unknown Hop:" +
hop.getClass().getSimpleName() + "\n" + Explain.explain(hop));
+ LOG.warn("Unknown Matrix Hop:" +
hop.getClass().getSimpleName() + "\n" + Explain.explain(hop));
setDecompressionOnAllInputs(hop, parent);
return;
}
@@ -522,7 +563,62 @@ public class WorkloadAnalyzer {
if(o.isCompressedOutput())
compressed.add(hop.getHopID());
}
+ else if(hop.getDataType().isFrame()) {
+ Op o = null;
+ if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD,
OpOpData.TRANSIENTREAD))
+ return;
+ else if(HopRewriteUtils.isData(hop,
OpOpData.TRANSIENTWRITE, OpOpData.PERSISTENTWRITE)) {
+ transientCompressed.put(hop.getName(),
hop.getInput(0).getHopID());
+ compressed.add(hop.getHopID());
+ o = new OpMetadata(hop, hop.getInput(0));
+ if(isOverlapping(hop.getInput(0)))
+ o.setOverlapping();
+ }
+ else if(HopRewriteUtils.isUnary(hop,
OpOp1.DETECTSCHEMA)) {
+ o = new OpNormal(hop, false);
+ }
+ else if(HopRewriteUtils.isBinary(hop,
OpOp2.APPLY_SCHEMA)) {
+ o = new OpNormal(hop, true);
+ }
+ else if(hop instanceof AggUnaryOp) {
+ o = new OpNormal(hop, false);
+ }
+ else {
+ LOG.warn("Unknown Frame Hop:" +
hop.getClass().getSimpleName() + "\n" + Explain.explain(hop));
+ setDecompressionOnAllInputs(hop, parent);
+ return;
+ }
+
+ o = o != null ? o : new OpNormal(hop,
RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop));
+ treeLookup.put(hop.getHopID(), o);
+ parent.addOp(o);
+ if(o.isCompressedOutput())
+ compressed.add(hop.getHopID());
+ }
+ else if(HopRewriteUtils.isTransformEncode(hop)) {
+ Hop matrix = ((FunctionOp) hop).getOutputs().get(0);
+ compressed.add(matrix.getHopID());
+ transientCompressed.put(matrix.getName(),
matrix.getHopID());
+ parent.addOp(new OpNormal(hop, true));
+ }
+ else if(hop instanceof FunctionOp && ((FunctionOp)
hop).getFunctionNamespace().equals(".builtinNS")) {
+ parent.addOp(new OpNormal(hop, false));
+ }
+ else if(hop instanceof AggUnaryOp) {
+ if((isOverlapping(hop.getInput().get(0)) &&
!HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, AggOp.MEAN)) ||
+ HopRewriteUtils.isAggUnaryOp(hop,
AggOp.TRACE)) {
+ setDecompressionOnAllInputs(hop,
parent);
+ return;
+ }
+ else {
+ Op o = new OpNormal(hop, false);
+ treeLookup.put(hop.getHopID(), o);
+ parent.addOp(o);
+ }
+ }
else {
+ LOG.warn(
+ "Unknown Matrix or Frame Hop:" +
hop.getClass().getSimpleName() + "\n" + hop.getDataType() + "\n" +
Explain.explain(hop));
parent.addOp(new OpNormal(hop, false));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
index 2e87d9eb65..7d0d9f7870 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
@@ -128,7 +128,10 @@ public class CompressionCPInstruction extends
ComputationCPInstruction {
// Get and clear workload tree entry for this compression
instruction.
final WTreeRoot root = (_singletonLookupID != 0) ? (WTreeRoot)
m.get(_singletonLookupID) : null;
- m.removeKey(_singletonLookupID);
+ // We used to remove the key from the hash map,
+ // however this is not correct since the compression statement
+ // can be reused in multiple for loops.
+
if(ec.isFrameObject(input1.getName()))
processFrameBlockCompression(ec,
ec.getFrameInput(input1.getName()), _numThreads, root);
@@ -144,6 +147,8 @@ public class CompressionCPInstruction extends
ComputationCPInstruction {
if(LOG.isTraceEnabled())
LOG.trace(compResult.getRight());
MatrixBlock out = compResult.getLeft();
+ if(LOG.isInfoEnabled())
+ LOG.info("Compression output class: " +
out.getClass().getSimpleName());
// Set output and release input
ec.releaseMatrixInput(input1.getName());
ec.setMatrixOutput(output.getName(), out);
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
index cf084e9ccc..5184556817 100644
---
a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
+++
b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
@@ -86,7 +86,7 @@ public class WorkloadTest {
// Simple tests no loops verifying basic behavior
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 0, false, false,
"sum.dml", args});
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 0, false, false,
"mean.dml", args});
- tests.add(new Object[] {0, 0, 0, 0, 0, 0, 1, 1, false, false,
"plus.dml", args});
+ tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 1, false, false,
"plus.dml", args});
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, false,
"sliceCols.dml", args});
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, false,
"sliceIndex.dml", args});
// tests.add(new Object[] {0, 0, 0, 1, 0, 0, 0, 0, false,
false, "leftMult.dml", args});
@@ -105,9 +105,9 @@ public class WorkloadTest {
// Builtins:
// nr 11:
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false,
"functions/scale.dml", args});
- tests.add(new Object[] {0, 0, 0, 0, 0, 0, 5, 0, true, true,
"functions/scale.dml", args});
+ tests.add(new Object[] {0, 0, 0, 0, 0, 0, 4, 0, false, true,
"functions/scale.dml", args});
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false,
"functions/scale_continued.dml", args});
- tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, true,
"functions/scale_continued.dml", args});
+ tests.add(new Object[] {0, 0, 0, 0, 0, 0, 5, 0, true, true,
"functions/scale_continued.dml", args});
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 2, 0, false, true,
"functions/scale_onlySide.dml", args});
tests.add(new Object[] {0, 0, 0, 0, 0, 0, 6, 0, true, false,
"functions/scale_onlySide.dml", args});