This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push: new f29ae42 [SYSTEMDS-339] Fix robustness lineage tracing/parsing, part II f29ae42 is described below commit f29ae426be1722fba9468609976709068e6e5d7d Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Fri May 22 22:38:54 2020 +0200 [SYSTEMDS-339] Fix robustness lineage tracing/parsing, part II This patch fixes many additional issues in lineage tracing and parsing in order to support the round-trip for steplm and kmeans. 1) Lineage tracing with default arguments of function call parameters (so far missing arguments were traced as literal variable names) 2) Lineage Tracing: rshape with parameters, ctable w/ dimensions, rand/seq w/ variable rows/cols, from/to/incr inputs 3) Lineage Parsing: rshape, rdiag, nrow, ncol, all cast ops, ifelse with scalar/matrix inputs (so far the block size was wrong), ctable w/ dimensions, gappend spark ops 4) New lineage parfor algorithm tests: steplm, kmeans --- scripts/builtin/kmeans.dml | 6 +-- src/main/java/org/apache/sysds/common/Types.java | 35 ++++++++++--- .../apache/sysds/hops/recompile/Recompiler.java | 8 +-- .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 37 ++++++++------ .../RewriteAlgebraicSimplificationDynamic.java | 6 +-- .../runtime/instructions/InstructionUtils.java | 41 ++------------- .../instructions/cp/CtableCPInstruction.java | 27 ++++++---- .../instructions/cp/DataGenCPInstruction.java | 56 ++++++++++++++------ .../instructions/cp/FunctionCallCPInstruction.java | 7 ++- .../instructions/cp/ReshapeCPInstruction.java | 9 ++++ .../instructions/spark/RandSPInstruction.java | 10 +++- .../sysds/runtime/lineage/LineageItemUtils.java | 59 +++++++++++++++++----- .../functions/lineage/LineageTraceParforTest.java | 34 ++++++++----- ...aceParfor4.dml => LineageTraceParforKmeans.dml} | 3 +- ...aceParfor4.dml => LineageTraceParforSteplm.dml} | 0 15 files changed, 211 insertions(+), 127 deletions(-) diff --git a/scripts/builtin/kmeans.dml 
b/scripts/builtin/kmeans.dml index 23482da..96591c6 100644 --- a/scripts/builtin/kmeans.dml +++ b/scripts/builtin/kmeans.dml @@ -60,8 +60,8 @@ m_kmeans = function(Matrix[Double] X, Integer k = 0, Integer runs = 10, Integer print ("Taking data samples for initialization..."); - [sample_maps, samples_vs_runs_map, sample_block_size] = - get_sample_maps (num_records, num_runs, num_centroids * avg_sample_size_per_centroid); + [sample_maps, samples_vs_runs_map, sample_block_size] = get_sample_maps( + num_records, num_runs, num_centroids * avg_sample_size_per_centroid); is_row_in_samples = rowSums (sample_maps); X_samples = sample_maps %*% X; @@ -230,7 +230,7 @@ get_sample_maps = function (int num_records, int num_samples, int approx_sample_ # Replace all sample record ids over "num_records" (i.e. out of range) by "num_records + 1": is_sample_rec_id_within_range = (sample_rec_ids <= num_records); sample_rec_ids = sample_rec_ids * is_sample_rec_id_within_range - + (num_records + 1) * (1 - is_sample_rec_id_within_range); + + (num_records + 1) * (1 - is_sample_rec_id_within_range); # Rearrange all samples (and their out-of-range indicators) into one column-vector: sample_rec_ids = matrix (sample_rec_ids, rows = num_rows, cols = 1, byrow = FALSE); diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index d693b7f..2d66e81 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -206,6 +206,15 @@ public class Types MULT2, MINUS1_MULT, MINUS_RIGHT, POW2, SUBTRACT_NZ; + + public boolean isScalarOutput() { + return this == CAST_AS_SCALAR + || this == NROW || this == NCOL + || this == LENGTH || this == EXISTS + || this == IQM || this == LINEAGE + || this == MEDIAN; + } + @Override public String toString() { switch(this) { @@ -244,7 +253,7 @@ public class Types case "ucumk+": return CUMSUM; case "ucumk+*": return CUMSUMPROD; case "*2": return MULT2; - case "!": 
return OpOp1.NOT; + case "!": return NOT; case "^2": return POW2; default: return valueOf(opcode.toUpperCase()); } @@ -354,12 +363,12 @@ public class Types } } - public static OpOp3 valueOfCode(String code) { - switch(code) { - case "cm": return OpOp3.MOMENT; - case "+*": return OpOp3.PLUS_MULT; - case "-*": return OpOp3.MINUS_MULT; - default: return OpOp3.valueOf(code.toUpperCase()); + public static OpOp3 valueOfByOpcode(String opcode) { + switch(opcode) { + case "cm": return MOMENT; + case "+*": return PLUS_MULT; + case "-*": return MINUS_MULT; + default: return valueOf(opcode.toUpperCase()); } } } @@ -394,11 +403,21 @@ public class Types @Override public String toString() { switch(this) { - case TRANS: return "t"; + case DIAG: return "rdiag"; + case TRANS: return "r'"; case RESHAPE: return "rshape"; default: return name().toLowerCase(); } } + + public static ReOrgOp valueOfByOpcode(String opcode) { + switch(opcode) { + case "rdiag": return DIAG; + case "r'": return TRANS; + case "rshape": return RESHAPE; + default: return valueOf(opcode.toUpperCase()); + } + } } public enum ParamBuiltinOp { diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java index 29e0c0e..4334384 100644 --- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java @@ -1448,15 +1448,11 @@ public class Recompiler * @param scalarsOnly if true, replace only scalar variables but no matrix operations; * if false, apply full literal replacement */ - public static void rReplaceLiterals( Hop hop, ExecutionContext ec, boolean scalarsOnly ) - { - //public interface + public static void rReplaceLiterals( Hop hop, ExecutionContext ec, boolean scalarsOnly ) { LiteralReplacement.rReplaceLiterals(hop, ec, scalarsOnly); } - public static void rReplaceLiterals( Hop hop, LocalVariableMap vars, boolean scalarsOnly ) - { - //public interface + public static 
void rReplaceLiterals( Hop hop, LocalVariableMap vars, boolean scalarsOnly ) { LiteralReplacement.rReplaceLiterals(hop, new ExecutionContext(vars), scalarsOnly); } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index 9e73fcc..30f66f4 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -583,6 +583,10 @@ public class HopRewriteUtils return createReorg(input, ReOrgOp.TRANS); } + public static ReorgOp createReorg(Hop input, String rop) { + return createReorg(input, ReOrgOp.valueOfByOpcode(rop)); + } + public static ReorgOp createReorg(Hop input, ReOrgOp rop) { ReorgOp reorg = new ReorgOp(input.getName(), input.getDataType(), input.getValueType(), rop, input); reorg.setBlocksize(input.getBlocksize()); @@ -604,22 +608,19 @@ public class HopRewriteUtils return createUnary(input, OpOp1.valueOfByOpcode(type)); } - public static UnaryOp createUnary(Hop input, OpOp1 type) - { - DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : + public static UnaryOp createUnary(Hop input, OpOp1 type) { + DataType dt = type.isScalarOutput() ? DataType.SCALAR : (type==OpOp1.CAST_AS_MATRIX) ? DataType.MATRIX : input.getDataType(); ValueType vt = (type==OpOp1.CAST_AS_MATRIX) ? ValueType.FP64 : input.getValueType(); UnaryOp unary = new UnaryOp(input.getName(), dt, vt, type, input); unary.setBlocksize(input.getBlocksize()); - if( type == OpOp1.CAST_AS_SCALAR || type == OpOp1.CAST_AS_MATRIX ) { - int dim = (type==OpOp1.CAST_AS_SCALAR) ? 0 : 1; + if( type.isScalarOutput() || type == OpOp1.CAST_AS_MATRIX ) { + int dim = type.isScalarOutput() ? 0 : 1; int blksz = (type==OpOp1.CAST_AS_SCALAR) ? 
0 : ConfigurationManager.getBlocksize(); setOutputParameters(unary, dim, dim, blksz, -1); } - copyLineNumbers(input, unary); - unary.refreshSizeInformation(); - + unary.refreshSizeInformation(); return unary; } @@ -681,7 +682,6 @@ public class HopRewriteUtils mmult.setBlocksize(left.getBlocksize()); copyLineNumbers(left, mmult); mmult.refreshSizeInformation(); - return mmult; } @@ -690,7 +690,6 @@ public class HopRewriteUtils pbop.setBlocksize(input.getBlocksize()); copyLineNumbers(input, pbop); pbop.refreshSizeInformation(); - return pbop; } @@ -774,23 +773,29 @@ public class HopRewriteUtils return datagen; } - public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, String opcode) { - return createTernaryOp(mleft, smid, mright, OpOp3.valueOfCode(opcode)); + public static TernaryOp createTernary(Hop mleft, Hop smid, Hop mright, String opcode) { + return createTernary(mleft, smid, mright, OpOp3.valueOfByOpcode(opcode)); } - public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) { + public static TernaryOp createTernary(Hop mleft, Hop smid, Hop mright, OpOp3 op) { //NOTe: for ifelse it's sufficient to check mright as smid==mright - System.out.println(mleft.getDataType()+" "+smid.getDataType()+" "+mright.getDataType()); DataType dt = (op == OpOp3.IFELSE) ? mright.getDataType() : DataType.MATRIX; ValueType vt = (op == OpOp3.IFELSE) ? 
mright.getValueType() : ValueType.FP64; TernaryOp ternOp = new TernaryOp("tmp", dt, vt, op, mleft, smid, mright); - if( dt == DataType.MATRIX ) - ternOp.setBlocksize(mleft.getBlocksize()); + ternOp.setBlocksize(Math.max(mleft.getBlocksize(), mright.getBlocksize())); copyLineNumbers(mleft, ternOp); ternOp.refreshSizeInformation(); return ternOp; } + public static TernaryOp createTernary(Hop in1, Hop in2, Hop in3, Hop in4, Hop in5, OpOp3 op) { + TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.FP64, op, in1, in2, in3, in4, in5); + ternOp.setBlocksize(Math.max(in1.getBlocksize(), in2.getBlocksize())); + copyLineNumbers(in1, ternOp); + ternOp.refreshSizeInformation(); + return ternOp; + } + public static Hop createComputeNnz(Hop input) { //nnz = sum(A != 0) -> later rewritten to meta-data operation return createSum(createBinary(input, new LiteralOp(0), OpOp2.NOTEQUAL)); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index e929e0a..1929315 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2285,7 +2285,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? - left : HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.PLUS_MULT); + left : HopRewriteUtils.createTernary(left, smid, mright, OpOp3.PLUS_MULT); LOG.debug("Applied fuseAxpyBinaryOperationChain1. 
(line " +hi.getBeginLine()+")"); } //pattern (b) s*Y + X -> X +* sY @@ -2297,7 +2297,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule Hop smid = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); Hop mright = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? - right : HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT); + right : HopRewriteUtils.createTernary(right, smid, mright, OpOp3.PLUS_MULT); LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")"); } //pattern (c) X - s*Y -> X -* sY @@ -2309,7 +2309,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? - left : HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.MINUS_MULT); + left : HopRewriteUtils.createTernary(left, smid, mright, OpOp3.MINUS_MULT); LOG.debug("Applied fuseAxpyBinaryOperationChain3. 
(line " +hi.getBeginLine()+")"); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index f1c8dc6..740d821 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -24,14 +24,7 @@ import java.util.StringTokenizer; import org.apache.sysds.common.Types.AggOp; import org.apache.sysds.common.Types.CorrectionLocationType; import org.apache.sysds.common.Types.Direction; -import org.apache.sysds.lops.AppendM; -import org.apache.sysds.lops.BinaryM; -import org.apache.sysds.lops.GroupedAggregateM; import org.apache.sysds.lops.Lop; -import org.apache.sysds.lops.MapMult; -import org.apache.sysds.lops.MapMultChain; -import org.apache.sysds.lops.PMMJ; -import org.apache.sysds.lops.UAggOuterChain; import org.apache.sysds.lops.WeightedCrossEntropy; import org.apache.sysds.lops.WeightedCrossEntropyR; import org.apache.sysds.lops.WeightedDivMM; @@ -239,36 +232,12 @@ public class InstructionUtils Builtin.BuiltinCode bfc = Builtin.String2BuiltinCode.get(opcode); return (bfc != null); } - - /** - * Evaluates if at least one instruction of the given instruction set - * used the distributed cache; this call can also be used for individual - * instructions. 
- * - * @param str instruction set - * @return true if at least one instruction uses distributed cache - */ - public static boolean isDistributedCacheUsed(String str) - { - String[] parts = str.split(Instruction.INSTRUCTION_DELIM); - for(String inst : parts) - { - String opcode = getOpCode(inst); - if( opcode.equalsIgnoreCase(AppendM.OPCODE) - || opcode.equalsIgnoreCase(MapMult.OPCODE) - || opcode.equalsIgnoreCase(MapMultChain.OPCODE) - || opcode.equalsIgnoreCase(PMMJ.OPCODE) - || opcode.equalsIgnoreCase(UAggOuterChain.OPCODE) - || opcode.equalsIgnoreCase(GroupedAggregateM.OPCODE) - || isDistQuaternaryOpcode( opcode ) //multiple quaternary opcodes - || BinaryM.isOpcode( opcode ) ) //multiple binary opcodes - { - return true; - } - } - return false; + + public static boolean isUnaryMetadata(String opcode) { + return opcode != null + && (opcode.equals("nrow") || opcode.equals("ncol")); } - + public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode) { return parseBasicAggregateUnaryOperator(opcode, 1); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java index 77625f4..4869e3e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java @@ -20,22 +20,23 @@ package org.apache.sysds.runtime.instructions.cp; import org.apache.sysds.lops.Ctable; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.lineage.LineageItem; +import 
org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.matrix.data.CTableMap; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.runtime.util.LongLongDoubleHashMap.EntryType; public class CtableCPInstruction extends ComputationCPInstruction { - private final String _outDim1; - private final String _outDim2; - private final boolean _dim1Literal; - private final boolean _dim2Literal; + private final CPOperand _outDim1; + private final CPOperand _outDim2; private final boolean _isExpand; private final boolean _ignoreZeros; @@ -43,10 +44,8 @@ public class CtableCPInstruction extends ComputationCPInstruction { String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, boolean ignoreZeros, String opcode, String istr) { super(CPType.Ctable, null, in1, in2, in3, out, opcode, istr); - _outDim1 = outputDim1; - _dim1Literal = dim1Literal; - _outDim2 = outputDim2; - _dim2Literal = dim2Literal; + _outDim1 = new CPOperand(outputDim1, ValueType.FP64, DataType.SCALAR, dim1Literal); + _outDim2 = new CPOperand(outputDim2, ValueType.FP64, DataType.SCALAR, dim2Literal); _isExpand = isExpand; _ignoreZeros = ignoreZeros; } @@ -98,8 +97,8 @@ public class CtableCPInstruction extends ComputationCPInstruction { Ctable.OperationTypes ctableOp = findCtableOperation(); ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp; - long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (ec.getScalarInput(_outDim1, ValueType.FP64, false)).getLongValue()); - long outputDim2 = (_dim2Literal ? 
(long) Double.parseDouble(_outDim2) : (ec.getScalarInput(_outDim2, ValueType.FP64, false)).getLongValue()); + long outputDim1 = ec.getScalarInput(_outDim1).getLongValue(); + long outputDim2 = ec.getScalarInput(_outDim2).getLongValue(); boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1); if ( outputDimsKnown ) { @@ -178,4 +177,12 @@ public class CtableCPInstruction extends ComputationCPInstruction { ec.setMatrixOutput(output.getName(), resultBlock); } + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + LineageItem[] linputs = !(_outDim1.getName().equals("-1") && _outDim2.getName().equals("-1")) ? + LineageItemUtils.getLineage(ec, input1, input2, input3, _outDim1, _outDim2) : + LineageItemUtils.getLineage(ec, input1, input2, input3); + return Pair.of(output.getName(), new LineageItem(getOpcode(), linputs)); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java index 11f4e8e..4aa9660 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java @@ -130,14 +130,16 @@ public class DataGenCPInstruction extends UnaryCPInstruction { } public long getRows() { - return rows.isLiteral() ? Long.parseLong(rows.getName()) : -1; + return rows.isLiteral() ? UtilFunctions.parseToLong(rows.getName()) : -1; } public long getCols() { - return cols.isLiteral() ? Long.parseLong(cols.getName()) : -1; + return cols.isLiteral() ? UtilFunctions.parseToLong(cols.getName()) : -1; } - public String getDims() { return dims.getName(); } + public String getDims() { + return dims.getName(); + } public int getBlocksize() { return blocksize; @@ -172,15 +174,15 @@ public class DataGenCPInstruction extends UnaryCPInstruction { } public long getFrom() { - return seq_from.isLiteral() ? 
Long.parseLong(seq_from.getName()) : -1; + return seq_from.isLiteral() ? UtilFunctions.parseToLong(seq_from.getName()) : -1; } public long getTo() { - return seq_to.isLiteral() ? Long.parseLong(seq_to.getName()) : -1; + return seq_to.isLiteral() ? UtilFunctions.parseToLong(seq_to.getName()) : -1; } public long getIncr() { - return seq_incr.isLiteral() ? Long.parseLong(seq_incr.getName()) : -1; + return seq_incr.isLiteral() ? UtilFunctions.parseToLong(seq_incr.getName()) : -1; } public static DataGenCPInstruction parseInstruction(String str) @@ -385,16 +387,40 @@ public class DataGenCPInstruction extends UnaryCPInstruction { @Override public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { String tmpInstStr = instString; - if (getSeed() == DataGenOp.UNSPECIFIED_SEED) { - //generate pseudo-random seed (because not specified) - if (runtimeSeed == null) - runtimeSeed = (minValue == maxValue && sparsity == 1) ? - DataGenOp.UNSPECIFIED_SEED : DataGenOp.generateRandomSeed(); - int position = (method == OpOpDG.RAND) ? SEED_POSITION_RAND : - (method == OpOpDG.SAMPLE) ? SEED_POSITION_SAMPLE : 0; - tmpInstStr = position != 0 ? InstructionUtils.replaceOperand( - tmpInstStr, position, String.valueOf(runtimeSeed)) : tmpInstStr; + + switch(method) { + case RAND: + case SAMPLE: { + if (getSeed() == DataGenOp.UNSPECIFIED_SEED) { + //generate pseudo-random seed (because not specified) + if (runtimeSeed == null) + runtimeSeed = (minValue == maxValue && sparsity == 1) ? + DataGenOp.UNSPECIFIED_SEED : DataGenOp.generateRandomSeed(); + int position = (method == OpOpDG.RAND) ? SEED_POSITION_RAND : + (method == OpOpDG.SAMPLE) ? SEED_POSITION_SAMPLE : 0; + tmpInstStr = position != 0 ? 
InstructionUtils.replaceOperand( + tmpInstStr, position, String.valueOf(runtimeSeed)) : tmpInstStr; + } + tmpInstStr = replaceNonLiteral(tmpInstStr, rows, 2, ec); + tmpInstStr = replaceNonLiteral(tmpInstStr, cols, 3, ec); + break; + } + case SEQ: { + tmpInstStr = replaceNonLiteral(tmpInstStr, seq_from, 5, ec); + tmpInstStr = replaceNonLiteral(tmpInstStr, seq_to, 6, ec); + tmpInstStr = replaceNonLiteral(tmpInstStr, seq_incr, 7, ec); + break; + } + default: + throw new DMLRuntimeException("Unsupported datagen op: "+method); } return Pair.of(output.getName(), new LineageItem(tmpInstStr, getOpcode())); } + + private static String replaceNonLiteral(String inst, CPOperand op, int pos, ExecutionContext ec) { + if( !op.isLiteral() ) + inst = InstructionUtils.replaceOperand(inst, pos, + new CPOperand(ec.getScalarInput(op)).getLineageLiteral()); + return inst; + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java index 0f0951e..f00f42d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java @@ -155,8 +155,11 @@ public class FunctionCallCPInstruction extends CPInstruction { functionVariables.put(currFormalParam.getName(), value); //map lineage to function arguments - if( lineage != null ) - lineage.set(currFormalParam.getName(), ec.getLineageItem(input)); + if( lineage != null ) { + LineageItem litem = ec.getLineageItem(input); + lineage.set(currFormalParam.getName(), (litem!=null) ? 
+ litem : ec.getLineage().getOrCreate(input)); + } } // Pin the input variables so that they do not get deleted diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java index 6262c50..8a3c001 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.instructions.cp; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -26,6 +27,8 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.data.LibTensorReorg; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -104,4 +107,10 @@ public class ReshapeCPInstruction extends UnaryCPInstruction { ec.releaseMatrixInput(input1.getName()); } } + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + return Pair.of(output.getName(), new LineageItem(getOpcode(), + LineageItemUtils.getLineage(ec, input1, _opRows, _opCols, _opDims, _opByRow))); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java index a2058b7..ef40773 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java @@ -164,11 +164,11 @@ public class RandSPInstruction extends UnarySPInstruction { } public long getRows() { - return rows.isLiteral() ? Long.parseLong(rows.getName()) : -1; + return rows.isLiteral() ? UtilFunctions.parseToLong(rows.getName()) : -1; } public long getCols() { - return cols.isLiteral() ? Long.parseLong(cols.getName()) : -1; + return cols.isLiteral() ? UtilFunctions.parseToLong(cols.getName()) : -1; } public int getBlocksize() { @@ -1011,6 +1011,12 @@ public class RandSPInstruction extends UnarySPInstruction { (_method == OpOpDG.SAMPLE) ? SEED_POSITION_SAMPLE : 0; tmpInstStr = InstructionUtils.replaceOperand( tmpInstStr, position, String.valueOf(runtimeSeed)); + if( !rows.isLiteral() ) + tmpInstStr = InstructionUtils.replaceOperand(tmpInstStr, 2, + new CPOperand(ec.getScalarInput(rows)).getLineageLiteral()); + if( !cols.isLiteral() ) + tmpInstStr = InstructionUtils.replaceOperand(tmpInstStr, 3, + new CPOperand(ec.getScalarInput(cols)).getLineageLiteral()); } return Pair.of(output.getName(), new LineageItem(tmpInstStr, getOpcode())); } diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java index 39a8c2a..b75baee 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java @@ -31,6 +31,9 @@ import org.apache.sysds.runtime.lineage.LineageItem.LineageItemType; import org.apache.sysds.common.Types.AggOp; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.Direction; +import org.apache.sysds.common.Types.OpOp1; +import org.apache.sysds.common.Types.OpOp2; +import org.apache.sysds.common.Types.OpOp3; import org.apache.sysds.common.Types.OpOpDG; import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.common.Types.OpOpN; @@ -267,7 +270,9 @@ 
public class LineageItemUtils { switch (ctype) { case AggregateUnary: { Hop input = operands.get(item.getInputs()[0].getId()); - Hop aggunary = HopRewriteUtils.createAggUnaryOp(input, item.getOpcode()); + Hop aggunary = InstructionUtils.isUnaryMetadata(item.getOpcode()) ? + HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) : + HopRewriteUtils.createAggUnaryOp(input, item.getOpcode()); operands.put(item.getId(), aggunary); break; } @@ -279,9 +284,15 @@ public class LineageItemUtils { break; } case Reorg: { - Hop input = operands.get(item.getInputs()[0].getId()); - Hop reorg = HopRewriteUtils.createReorg(input, ReOrgOp.TRANS); - operands.put(item.getId(), reorg); + operands.put(item.getId(), HopRewriteUtils.createReorg( + operands.get(item.getInputs()[0].getId()), item.getOpcode())); + break; + } + case Reshape: { + ArrayList<Hop> inputs = new ArrayList<>(); + for(int i=0; i<5; i++) + inputs.add(operands.get(item.getInputs()[i].getId())); + operands.put(item.getId(), HopRewriteUtils.createReorg(inputs, ReOrgOp.RESHAPE)); break; } case Binary: { @@ -303,12 +314,27 @@ public class LineageItemUtils { break; } case Ternary: { - operands.put(item.getId(), HopRewriteUtils.createTernaryOp( + operands.put(item.getId(), HopRewriteUtils.createTernary( operands.get(item.getInputs()[0].getId()), operands.get(item.getInputs()[1].getId()), operands.get(item.getInputs()[2].getId()), item.getOpcode())); break; } + case Ctable: { //e.g., ctable + if( item.getInputs().length==3 ) + operands.put(item.getId(), HopRewriteUtils.createTernary( + operands.get(item.getInputs()[0].getId()), + operands.get(item.getInputs()[1].getId()), + operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE)); + else if( item.getInputs().length==5 ) + operands.put(item.getId(), HopRewriteUtils.createTernary( + operands.get(item.getInputs()[0].getId()), + operands.get(item.getInputs()[1].getId()), + operands.get(item.getInputs()[2].getId()), + 
operands.get(item.getInputs()[3].getId()), + operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE)); + break; + } case BuiltinNary: { String opcode = item.getOpcode().equals("n+") ? "plus" : item.getOpcode(); operands.put(item.getId(), HopRewriteUtils.createNary( @@ -331,8 +357,13 @@ public class LineageItemUtils { operands.put(item.getId(), aggunary); break; } - case Variable: { //cpvar, write - operands.put(item.getId(), operands.get(item.getInputs()[0].getId())); + case Variable: { + if( item.getOpcode().startsWith("cast") ) + operands.put(item.getId(), HopRewriteUtils.createUnary( + operands.get(item.getInputs()[0].getId()), + OpOp1.valueOfByOpcode(item.getOpcode()))); + else //cpvar, write + operands.put(item.getId(), operands.get(item.getInputs()[0].getId())); break; } default: @@ -358,6 +389,12 @@ public class LineageItemUtils { operands.put(item.getId(), constructIndexingOp(item, operands)); break; } + case GAppend: { + operands.put(item.getId(), HopRewriteUtils.createBinary( + operands.get(item.getInputs()[0].getId()), + operands.get(item.getInputs()[1].getId()), OpOp2.CBIND)); + break; + } default: throw new DMLRuntimeException("Unsupported instruction " + "type: " + stype.name() + " (" + item.getOpcode() + ")."); @@ -482,18 +519,16 @@ public class LineageItemUtils { } private static Hop constructIndexingOp(LineageItem item, Map<Long, Hop> operands) { - //TODO fix + Hop input = operands.get(item.getInputs()[0].getId()); if( "rightIndex".equals(item.getOpcode()) ) - return HopRewriteUtils.createIndexingOp( - operands.get(item.getInputs()[0].getId()), //input + return HopRewriteUtils.createIndexingOp(input, operands.get(item.getInputs()[1].getId()), //rl operands.get(item.getInputs()[2].getId()), //ru operands.get(item.getInputs()[3].getId()), //cl operands.get(item.getInputs()[4].getId())); //cu else if( "leftIndex".equals(item.getOpcode()) || "mapLeftIndex".equals(item.getOpcode()) ) - return HopRewriteUtils.createLeftIndexingOp( - 
operands.get(item.getInputs()[0].getId()), //input + return HopRewriteUtils.createLeftIndexingOp(input, operands.get(item.getInputs()[1].getId()), //rhs operands.get(item.getInputs()[2].getId()), //rl operands.get(item.getInputs()[3].getId()), //ru diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java index 50443c1..d100a4d 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java @@ -44,7 +44,8 @@ public class LineageTraceParforTest extends AutomatedTestBase { protected static final String TEST_NAME1 = "LineageTraceParfor1"; //rand - matrix result - local parfor protected static final String TEST_NAME2 = "LineageTraceParfor2"; //rand - matrix result - remote spark parfor protected static final String TEST_NAME3 = "LineageTraceParfor3"; //rand - matrix result - remote spark parfor - protected static final String TEST_NAME4 = "LineageTraceParfor4"; //rand - steplm (stackoverflow error) + protected static final String TEST_NAME4 = "LineageTraceParforSteplm"; //rand - steplm + protected static final String TEST_NAME5 = "LineageTraceParforKmeans"; //rand - kmeans protected String TEST_CLASS_DIR = TEST_DIR + LineageTraceParforTest.class.getSimpleName() + "/"; @@ -61,6 +62,7 @@ public class LineageTraceParforTest extends AutomatedTestBase { addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) ); addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}) ); addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"R"}) ); + addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"R"}) ); } @Test @@ -113,16 +115,25 @@ public class 
LineageTraceParforTest extends AutomatedTestBase { testLineageTraceParFor(32, TEST_NAME3); } -// TODO additional fixes needed for steplm -// @Test -// public void testLineageTraceParFor4_8() { -// testLineageTraceParFor(8, TEST_NAME4); -// } -// -// @Test -// public void testLineageTraceParFor4_32() { -// testLineageTraceParFor(32, TEST_NAME4); -// } + @Test + public void testLineageTraceSteplm_8() { + testLineageTraceParFor(8, TEST_NAME4); + } + + @Test + public void testLineageTraceSteplm_32() { + testLineageTraceParFor(32, TEST_NAME4); + } + + @Test + public void testLineageTraceKMeans_8() { + testLineageTraceParFor(8, TEST_NAME5); + } + + @Test + public void testLineageTraceKmeans_32() { + testLineageTraceParFor(32, TEST_NAME5); + } private void testLineageTraceParFor(int ncol, String testname) { try { @@ -146,7 +157,6 @@ public class LineageTraceParforTest extends AutomatedTestBase { //get lineage and generate program String Rtrace = readDMLLineageFromHDFS("R"); - System.out.println(Rtrace); LineageItem R = LineageParser.parseLineageTrace(Rtrace); Data ret = LineageItemUtils.computeByLineage(R); diff --git a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml b/src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml similarity index 94% copy from src/test/scripts/functions/lineage/LineageTraceParfor4.dml copy to src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml index 576182b..215cb5b 100644 --- a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml +++ b/src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml @@ -20,7 +20,6 @@ #------------------------------------------------------------- X = rand(rows=$2, cols=$3, seed=7); -Y = rand(rows=nrow(X), cols=1, seed=2) -X = steplm(X=X, y=Y) +X = kmeans(X=X, k=4) write(X, $1); diff --git a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml b/src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml similarity index 100% rename from 
src/test/scripts/functions/lineage/LineageTraceParfor4.dml rename to src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml