This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5596fcf [SYSTEMDS-2650] Non-recursive construction of HOPs from
Lineage
5596fcf is described below
commit 5596fcf0d77946cf11ff34085e8879552dd852be
Author: arnabp <[email protected]>
AuthorDate: Sat Sep 5 22:37:22 2020 +0200
[SYSTEMDS-2650] Non-recursive construction of HOPs from Lineage
This patch implements a non-recursive version of HOP DAG construction
from the lineage DAG, which fixes the stack overflow that occurred when
recomputing from lineage.
---
.../runtime/lineage/LineageRecomputeUtils.java | 325 ++++++++++++++++++++-
.../functions/lineage/LineageTraceDedupTest.java | 7 +-
2 files changed, 314 insertions(+), 18 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index fffc2dc..0df1651 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -25,8 +25,10 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Stack;
import java.util.stream.Collectors;
+import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
@@ -100,7 +102,7 @@ public class LineageRecomputeUtils {
root.resetVisitStatusNR();
Map<Long, Hop> operands = new HashMap<>();
Map<String, Hop> partDagRoots = new HashMap<>();
- rConstructHops(root, operands, partDagRoots, prog);
+ constructHopsNR(root, operands, partDagRoots, prog);
Hop out = HopRewriteUtils.createTransientWrite(
varname, operands.get(rootId));
@@ -134,17 +136,38 @@ public class LineageRecomputeUtils {
prog.addProgramBlock(pb);
}
-
- private static void rConstructHops(LineageItem item, Map<Long, Hop>
operands, Map<String, Hop> partDagRoots, Program prog)
+ private static void constructHopsNR(LineageItem item, Map<Long, Hop>
operands, Map<String, Hop> partDagRoots, Program prog)
+ {
+ //NOTE: This method follows the same non-recursive
+ //skeleton as explainLineageItemNR
+ Stack<LineageItem> stackItem = new Stack<>();
+ Stack<MutableInt> stackPos = new Stack<>();
+ stackItem.push(item); stackPos.push(new MutableInt(0));
+ while (!stackItem.empty()) {
+ LineageItem tmpItem = stackItem.peek();
+ MutableInt tmpPos = stackPos.peek();
+ //check ascent condition - no item processing
+ if (tmpItem.isVisited()) {
+ stackItem.pop(); stackPos.pop();
+ }
+ //check ascent condition - construct hop for item
+ else if( tmpItem.getInputs() == null
+ || tmpItem.getInputs().length <=
tmpPos.intValue() ) {
+ constructSingleHop(tmpItem, operands,
partDagRoots, prog);
+ stackItem.pop(); stackPos.pop();
+ tmpItem.setVisited();
+ }
+ //check descent condition
+ else if( tmpItem.getInputs() != null ) {
+
stackItem.push(tmpItem.getInputs()[tmpPos.intValue()]);
+ tmpPos.increment();
+ stackPos.push(new MutableInt(0));
+ }
+ }
+ }
+
+ private static void constructSingleHop(LineageItem item, Map<Long, Hop>
operands, Map<String, Hop> partDagRoots, Program prog)
{
- if (item.isVisited())
- return;
-
- //recursively process children (ordering by data dependencies)
- if (!item.isLeaf())
- for (LineageItem c : item.getInputs())
- rConstructHops(c, operands, partDagRoots, prog);
-
//process current lineage item
//NOTE: we generate instructions from hops (but without
rewrites) to automatically
//handle execution types, rmvar instructions, and rewiring of
inputs/outputs
@@ -406,8 +429,6 @@ public class LineageRecomputeUtils {
break;
}
}
-
- item.setVisited();
}
// Construct and compile the function body
@@ -428,7 +449,7 @@ public class LineageRecomputeUtils {
for (int i=0; i<inputs.length; i++)
operands.put((long)i,
HopRewriteUtils.createTransientRead(inputs[i], inpHops.get(i))); //order
preserving
// Construct the Hop dag.
- rConstructHops(patchRoot, operands, null, null);
+ constructHopsNR(patchRoot, operands, null, null);
// TWrite the func return (pass dag root to copy datatype)
Hop out = HopRewriteUtils.createTransientWrite(outname,
operands.get(patchRoot.getId()));
// Save the Hop dag
@@ -518,6 +539,282 @@ public class LineageRecomputeUtils {
throw new DMLRuntimeException("Unsupported opcode:
"+item.getOpcode());
}
+ @Deprecated
+ @SuppressWarnings("unused")
+ private static void rConstructHops(LineageItem item, Map<Long, Hop>
operands, Map<String, Hop> partDagRoots, Program prog)
+ {
+ if (item.isVisited())
+ return;
+
+ //recursively process children (ordering by data dependencies)
+ if (!item.isLeaf())
+ for (LineageItem c : item.getInputs())
+ rConstructHops(c, operands, partDagRoots, prog);
+
+ //process current lineage item
+ //NOTE: we generate instructions from hops (but without
rewrites) to automatically
+ //handle execution types, rmvar instructions, and rewiring of
inputs/outputs
+ switch (item.getType()) {
+ case Creation: {
+ if (item.getData().startsWith(LPLACEHOLDER)) {
+ long phId =
Long.parseLong(item.getData().substring(3));
+ Hop input = operands.get(phId);
+ operands.remove(phId);
+ // Replace the placeholders with TReads
+ operands.put(item.getId(), input); //
order preserving
+ break;
+ }
+ Instruction inst =
InstructionParser.parseSingleInstruction(item.getData());
+
+ if (inst instanceof DataGenCPInstruction) {
+ DataGenCPInstruction rand =
(DataGenCPInstruction) inst;
+ HashMap<String, Hop> params = new
HashMap<>();
+ if( rand.getOpcode().equals("rand") ) {
+ if( rand.output.getDataType()
== DataType.TENSOR)
+
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+ else {
+
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+ }
+
params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
+
params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
+
params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
+
params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
+
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+
params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
+ }
+ else if( rand.getOpcode().equals("seq")
) {
+ params.put(Statement.SEQ_FROM,
new LiteralOp(rand.getFrom()));
+ params.put(Statement.SEQ_TO,
new LiteralOp(rand.getTo()));
+ params.put(Statement.SEQ_INCR,
new LiteralOp(rand.getIncr()));
+ }
+ Hop datagen = new
DataGenOp(OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
+ new DataIdentifier("tmp"),
params);
+
datagen.setBlocksize(rand.getBlocksize());
+ operands.put(item.getId(), datagen);
+ } else if (inst instanceof VariableCPInstruction
+ && ((VariableCPInstruction)
inst).isCreateVariable()) {
+ String parts[] =
InstructionUtils.getInstructionPartsWithValueType(inst.toString());
+ DataType dt =
DataType.valueOf(parts[4]);
+ ValueType vt = dt == DataType.MATRIX ?
ValueType.FP64 : ValueType.STRING;
+ HashMap<String, Hop> params = new
HashMap<>();
+ params.put(DataExpression.IO_FILENAME,
new LiteralOp(parts[2]));
+ params.put(DataExpression.READROWPARAM,
new LiteralOp(Long.parseLong(parts[6])));
+ params.put(DataExpression.READCOLPARAM,
new LiteralOp(Long.parseLong(parts[7])));
+ params.put(DataExpression.READNNZPARAM,
new LiteralOp(Long.parseLong(parts[8])));
+ params.put(DataExpression.FORMAT_TYPE,
new LiteralOp(parts[5]));
+ DataOp pread = new
DataOp(parts[1].substring(5), dt, vt, OpOpData.PERSISTENTREAD, params);
+ pread.setFileName(parts[2]);
+ operands.put(item.getId(), pread);
+ }
+ else if (inst instanceof RandSPInstruction) {
+ RandSPInstruction rand =
(RandSPInstruction) inst;
+ HashMap<String, Hop> params = new
HashMap<>();
+ if (rand.output.getDataType() ==
DataType.TENSOR)
+
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
+ else {
+
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
+
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
+ }
+ params.put(DataExpression.RAND_MIN, new
LiteralOp(rand.getMinValue()));
+ params.put(DataExpression.RAND_MAX, new
LiteralOp(rand.getMaxValue()));
+ params.put(DataExpression.RAND_PDF, new
LiteralOp(rand.getPdf()));
+ params.put(DataExpression.RAND_LAMBDA,
new LiteralOp(rand.getPdfParams()));
+
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
+ params.put(DataExpression.RAND_SEED,
new LiteralOp(rand.getSeed()));
+ Hop datagen = new
DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params);
+
datagen.setBlocksize(rand.getBlocksize());
+ operands.put(item.getId(), datagen);
+ }
+ break;
+ }
+ case Dedup: {
+ // Create function call for each dedup entry
+ String[] parts =
item.getOpcode().split(LineageDedupUtils.DEDUP_DELIM); //e.g. dedup_R_SB13_0
+ String name = parts[2] + parts[1] + parts[3];
//loopId + outVar + pathId
+ List<Hop> finputs =
Arrays.stream(item.getInputs())
+ .map(inp ->
operands.get(inp.getId())).collect(Collectors.toList());
+ String[] inputNames = new
String[item.getInputs().length];
+ for (int i=0; i<item.getInputs().length; i++)
+ inputNames[i] = LPLACEHOLDER + i;
//e.g. IN#0, IN#1
+ Hop funcOp = new FunctionOp(FunctionType.DML,
DMLProgram.DEFAULT_NAMESPACE,
+ name, inputNames, finputs, new
String[] {parts[1]}, false);
+
+ // Cut the Hop dag after function calls
+ partDagRoots.put(parts[1], funcOp);
+ // Compile the dag and save
+ constructBasicBlock(partDagRoots, parts[1],
prog);
+
+ // Construct a Hop dag for the function body
from the dedup patch, and compile
+ Hop output = constructHopsDedupPatch(parts,
inputNames, finputs, prog);
+ // Create a TRead on the function o/p as a leaf
for the next Hop dag
+ // Use the function body root/return hop to
propagate right data type
+ operands.put(item.getId(),
HopRewriteUtils.createTransientRead(parts[1], output));
+ break;
+ }
+ case Instruction: {
+ CPType ctype =
InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+ SPType stype =
InstructionUtils.getSPTypeByOpcode(item.getOpcode());
+
+ if (ctype != null) {
+ switch (ctype) {
+ case AggregateUnary: {
+ Hop input =
operands.get(item.getInputs()[0].getId());
+ Hop aggunary =
InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
+
HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
+
HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
+
operands.put(item.getId(), aggunary);
+ break;
+ }
+ case AggregateBinary: {
+ Hop input1 =
operands.get(item.getInputs()[0].getId());
+ Hop input2 =
operands.get(item.getInputs()[1].getId());
+ Hop aggbinary =
HopRewriteUtils.createMatrixMultiply(input1, input2);
+
operands.put(item.getId(), aggbinary);
+ break;
+ }
+ case AggregateTernary: {
+ Hop input1 =
operands.get(item.getInputs()[0].getId());
+ Hop input2 =
operands.get(item.getInputs()[1].getId());
+ Hop input3 =
operands.get(item.getInputs()[2].getId());
+ Hop aggternary =
HopRewriteUtils.createSum(
+
HopRewriteUtils.createBinary(
+
HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
+ input3,
OpOp2.MULT));
+
operands.put(item.getId(), aggternary);
+ break;
+ }
+ case Unary:
+ case Builtin: {
+ Hop input =
operands.get(item.getInputs()[0].getId());
+ Hop unary =
HopRewriteUtils.createUnary(input, item.getOpcode());
+
operands.put(item.getId(), unary);
+ break;
+ }
+ case Reorg: {
+
operands.put(item.getId(), HopRewriteUtils.createReorg(
+
operands.get(item.getInputs()[0].getId()), item.getOpcode()));
+ break;
+ }
+ case Reshape: {
+ ArrayList<Hop> inputs =
new ArrayList<>();
+ for(int i=0; i<5; i++)
+
inputs.add(operands.get(item.getInputs()[i].getId()));
+
operands.put(item.getId(), HopRewriteUtils.createReorg(inputs,
ReOrgOp.RESHAPE));
+ break;
+ }
+ case Binary: {
+ //handle special cases
of binary operations
+ String opcode =
("^2".equals(item.getOpcode())
+ ||
"*2".equals(item.getOpcode())) ?
+
item.getOpcode().substring(0, 1) : item.getOpcode();
+ Hop input1 =
operands.get(item.getInputs()[0].getId());
+ Hop input2 =
operands.get(item.getInputs()[1].getId());
+ Hop binary =
HopRewriteUtils.createBinary(input1, input2, opcode);
+
operands.put(item.getId(), binary);
+ break;
+ }
+ case Ternary: {
+
operands.put(item.getId(), HopRewriteUtils.createTernary(
+
operands.get(item.getInputs()[0].getId()),
+
operands.get(item.getInputs()[1].getId()),
+
operands.get(item.getInputs()[2].getId()), item.getOpcode()));
+ break;
+ }
+ case Ctable: { //e.g., ctable
+ if(
item.getInputs().length==3 )
+
operands.put(item.getId(), HopRewriteUtils.createTernary(
+
operands.get(item.getInputs()[0].getId()),
+
operands.get(item.getInputs()[1].getId()),
+
operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
+ else if(
item.getInputs().length==5 )
+
operands.put(item.getId(), HopRewriteUtils.createTernary(
+
operands.get(item.getInputs()[0].getId()),
+
operands.get(item.getInputs()[1].getId()),
+
operands.get(item.getInputs()[2].getId()),
+
operands.get(item.getInputs()[3].getId()),
+
operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
+ break;
+ }
+ case BuiltinNary: {
+ String opcode =
item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
+
operands.put(item.getId(), HopRewriteUtils.createNary(
+
OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands)));
+ break;
+ }
+ case ParameterizedBuiltin: {
+
operands.put(item.getId(), constructParameterizedBuiltinOp(item, operands));
+ break;
+ }
+ case MatrixIndexing: {
+
operands.put(item.getId(), constructIndexingOp(item, operands));
+ break;
+ }
+ case MMTSJ: {
+ //TODO handling of tsmm
type left and right -> placement transpose
+ Hop input =
operands.get(item.getInputs()[0].getId());
+ Hop aggunary =
HopRewriteUtils.createMatrixMultiply(
+
HopRewriteUtils.createTranspose(input), input);
+
operands.put(item.getId(), aggunary);
+ break;
+ }
+ case Variable: {
+ if(
item.getOpcode().startsWith("cast") )
+
operands.put(item.getId(), HopRewriteUtils.createUnary(
+
operands.get(item.getInputs()[0].getId()),
+
OpOp1.valueOfByOpcode(item.getOpcode())));
+ else //cpvar, write
+
operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
+ break;
+ }
+ default:
+ throw new
DMLRuntimeException("Unsupported instruction "
+ + "type: " +
ctype.name() + " (" + item.getOpcode() + ").");
+ }
+ }
+ else if( stype != null ) {
+ switch(stype) {
+ case Reblock: {
+ Hop input =
operands.get(item.getInputs()[0].getId());
+
input.setBlocksize(ConfigurationManager.getBlocksize());
+
input.setRequiresReblock(true);
+
operands.put(item.getId(), input);
+ break;
+ }
+ case Checkpoint: {
+ Hop input =
operands.get(item.getInputs()[0].getId());
+
operands.put(item.getId(), input);
+ break;
+ }
+ case MatrixIndexing: {
+
operands.put(item.getId(), constructIndexingOp(item, operands));
+ break;
+ }
+ case GAppend: {
+
operands.put(item.getId(), HopRewriteUtils.createBinary(
+
operands.get(item.getInputs()[0].getId()),
+
operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
+ break;
+ }
+ default:
+ throw new
DMLRuntimeException("Unsupported instruction "
+ + "type: " +
stype.name() + " (" + item.getOpcode() + ").");
+ }
+ }
+ else
+ throw new
DMLRuntimeException("Unsupported instruction: " + item.getOpcode());
+ break;
+ }
+ case Literal: {
+ CPOperand op = new CPOperand(item.getData());
+ operands.put(item.getId(), ScalarObjectFactory
+ .createLiteralOp(op.getValueType(),
op.getName()));
+ break;
+ }
+ }
+
+ item.setVisited();
+ }
// Below class represents a single loop and contains related data
// that are needed for recomputation.
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
index 18da399..3b1ae65 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
@@ -95,12 +95,11 @@ public class LineageTraceDedupTest extends AutomatedTestBase
testLineageTrace(TEST_NAME5);
}
- /*@Test
+ @Test
public void testLineageTrace6() {
testLineageTrace(TEST_NAME6);
- }*/
- //FIXME: stack overflow only when ran the full package
-
+ }
+
@Test
public void testLineageTrace7() {
testLineageTrace(TEST_NAME7);