This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new f29ae42  [SYSTEMDS-339] Fix robustness lineage tracing/parsing, part II
f29ae42 is described below

commit f29ae426be1722fba9468609976709068e6e5d7d
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Fri May 22 22:38:54 2020 +0200

    [SYSTEMDS-339] Fix robustness lineage tracing/parsing, part II
    
    This patch fixes many additional issues in lineage tracing and parsing
    in order to support the round-trip for steplm and kmeans.
    
    1) Lineage tracing with default arguments of function call parameters
    (so far, missing arguments were traced as literal variable names)
    
    2) Lineage Tracing: rshape with parameters, ctable w/ dimensions,
    rand/seq w/ variable rows/cols, from/to/incr inputs
    
    3) Lineage Parsing: rshape, rdiag, nrow, ncol, all cast ops, ifelse
    with scalar/matrix inputs (so far with a wrong block size), ctable w/
    dimensions, gappend spark ops
    
    4) New lineage parfor algorithm tests: steplm, kmeans
---
 scripts/builtin/kmeans.dml                         |  6 +--
 src/main/java/org/apache/sysds/common/Types.java   | 35 ++++++++++---
 .../apache/sysds/hops/recompile/Recompiler.java    |  8 +--
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 37 ++++++++------
 .../RewriteAlgebraicSimplificationDynamic.java     |  6 +--
 .../runtime/instructions/InstructionUtils.java     | 41 ++-------------
 .../instructions/cp/CtableCPInstruction.java       | 27 ++++++----
 .../instructions/cp/DataGenCPInstruction.java      | 56 ++++++++++++++------
 .../instructions/cp/FunctionCallCPInstruction.java |  7 ++-
 .../instructions/cp/ReshapeCPInstruction.java      |  9 ++++
 .../instructions/spark/RandSPInstruction.java      | 10 +++-
 .../sysds/runtime/lineage/LineageItemUtils.java    | 59 +++++++++++++++++-----
 .../functions/lineage/LineageTraceParforTest.java  | 34 ++++++++-----
 ...aceParfor4.dml => LineageTraceParforKmeans.dml} |  3 +-
 ...aceParfor4.dml => LineageTraceParforSteplm.dml} |  0
 15 files changed, 211 insertions(+), 127 deletions(-)

diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index 23482da..96591c6 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -60,8 +60,8 @@ m_kmeans = function(Matrix[Double] X, Integer k = 0, Integer 
runs = 10, Integer
 
   print ("Taking data samples for initialization...");
 
-  [sample_maps, samples_vs_runs_map, sample_block_size] =
-  get_sample_maps (num_records, num_runs, num_centroids * 
avg_sample_size_per_centroid);
+  [sample_maps, samples_vs_runs_map, sample_block_size] = get_sample_maps(
+    num_records, num_runs, num_centroids * avg_sample_size_per_centroid);
 
   is_row_in_samples = rowSums (sample_maps);
   X_samples = sample_maps %*% X;
@@ -230,7 +230,7 @@ get_sample_maps = function (int num_records, int 
num_samples, int approx_sample_
     # Replace all sample record ids over "num_records" (i.e. out of range) by 
"num_records + 1":
     is_sample_rec_id_within_range = (sample_rec_ids <= num_records);
     sample_rec_ids = sample_rec_ids * is_sample_rec_id_within_range
-    + (num_records + 1) * (1 - is_sample_rec_id_within_range);
+      + (num_records + 1) * (1 - is_sample_rec_id_within_range);
 
     # Rearrange all samples (and their out-of-range indicators) into one 
column-vector:
     sample_rec_ids = matrix (sample_rec_ids, rows = num_rows, cols = 1, byrow 
= FALSE);
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index d693b7f..2d66e81 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -206,6 +206,15 @@ public class Types
                MULT2, MINUS1_MULT, MINUS_RIGHT, 
                POW2, SUBTRACT_NZ;
                
+
+               public boolean isScalarOutput() {
+                       return this == CAST_AS_SCALAR
+                               || this == NROW || this == NCOL
+                               || this == LENGTH || this == EXISTS
+                               || this == IQM || this == LINEAGE
+                               || this == MEDIAN;
+               }
+               
                @Override
                public String toString() {
                        switch(this) {
@@ -244,7 +253,7 @@ public class Types
                                case "ucumk+":  return CUMSUM;
                                case "ucumk+*": return CUMSUMPROD;
                                case "*2":      return MULT2;
-                               case "!":       return OpOp1.NOT;
+                               case "!":       return NOT;
                                case "^2":      return POW2;
                                default:        return 
valueOf(opcode.toUpperCase());
                        }
@@ -354,12 +363,12 @@ public class Types
                        }
                }
                
-               public static OpOp3 valueOfCode(String code) {
-                       switch(code) {
-                               case "cm": return OpOp3.MOMENT;
-                               case "+*": return OpOp3.PLUS_MULT;
-                               case "-*": return OpOp3.MINUS_MULT;
-                               default:   return 
OpOp3.valueOf(code.toUpperCase());
+               public static OpOp3 valueOfByOpcode(String opcode) {
+                       switch(opcode) {
+                               case "cm": return MOMENT;
+                               case "+*": return PLUS_MULT;
+                               case "-*": return MINUS_MULT;
+                               default:   return valueOf(opcode.toUpperCase());
                        }
                }
        }
@@ -394,11 +403,21 @@ public class Types
                @Override
                public String toString() {
                        switch(this) {
-                               case TRANS:   return "t";
+                               case DIAG:    return "rdiag";
+                               case TRANS:   return "r'";
                                case RESHAPE: return "rshape";
                                default:      return name().toLowerCase();
                        }
                }
+               
+               public static ReOrgOp valueOfByOpcode(String opcode) {
+                       switch(opcode) {
+                               case "rdiag":  return DIAG;
+                               case "r'":     return TRANS;
+                               case "rshape": return RESHAPE;
+                               default:       return 
valueOf(opcode.toUpperCase());
+                       }
+               }
        }
        
        public enum ParamBuiltinOp {
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 29e0c0e..4334384 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -1448,15 +1448,11 @@ public class Recompiler
         * @param scalarsOnly if true, replace only scalar variables but no 
matrix operations;
         *            if false, apply full literal replacement
         */
-       public static void rReplaceLiterals( Hop hop, ExecutionContext ec, 
boolean scalarsOnly )
-       {
-               //public interface 
+       public static void rReplaceLiterals( Hop hop, ExecutionContext ec, 
boolean scalarsOnly ) {
                LiteralReplacement.rReplaceLiterals(hop, ec, scalarsOnly);
        }
 
-       public static void rReplaceLiterals( Hop hop, LocalVariableMap vars, 
boolean scalarsOnly )
-       {
-               //public interface 
+       public static void rReplaceLiterals( Hop hop, LocalVariableMap vars, 
boolean scalarsOnly ) {
                LiteralReplacement.rReplaceLiterals(hop, new 
ExecutionContext(vars), scalarsOnly);
        }
        
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 9e73fcc..30f66f4 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -583,6 +583,10 @@ public class HopRewriteUtils
                return createReorg(input, ReOrgOp.TRANS);
        }
        
+       public static ReorgOp createReorg(Hop input, String rop) {
+               return createReorg(input, ReOrgOp.valueOfByOpcode(rop));
+       }
+       
        public static ReorgOp createReorg(Hop input, ReOrgOp rop) {
                ReorgOp reorg = new ReorgOp(input.getName(), 
input.getDataType(), input.getValueType(), rop, input);
                reorg.setBlocksize(input.getBlocksize());
@@ -604,22 +608,19 @@ public class HopRewriteUtils
                return createUnary(input, OpOp1.valueOfByOpcode(type));
        }
        
-       public static UnaryOp createUnary(Hop input, OpOp1 type) 
-       {
-               DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : 
+       public static UnaryOp createUnary(Hop input, OpOp1 type)  {
+               DataType dt = type.isScalarOutput() ? DataType.SCALAR :
                        (type==OpOp1.CAST_AS_MATRIX) ? DataType.MATRIX : 
input.getDataType();
                ValueType vt = (type==OpOp1.CAST_AS_MATRIX) ? ValueType.FP64 : 
input.getValueType();
                UnaryOp unary = new UnaryOp(input.getName(), dt, vt, type, 
input);
                unary.setBlocksize(input.getBlocksize());
-               if( type == OpOp1.CAST_AS_SCALAR || type == 
OpOp1.CAST_AS_MATRIX ) {
-                       int dim = (type==OpOp1.CAST_AS_SCALAR) ? 0 : 1;
+               if( type.isScalarOutput() || type == OpOp1.CAST_AS_MATRIX ) {
+                       int dim = type.isScalarOutput() ? 0 : 1;
                        int blksz = (type==OpOp1.CAST_AS_SCALAR) ? 0 : 
ConfigurationManager.getBlocksize();
                        setOutputParameters(unary, dim, dim, blksz, -1);
                }
-               
                copyLineNumbers(input, unary);
-               unary.refreshSizeInformation(); 
-               
+               unary.refreshSizeInformation();
                return unary;
        }
        
@@ -681,7 +682,6 @@ public class HopRewriteUtils
                mmult.setBlocksize(left.getBlocksize());
                copyLineNumbers(left, mmult);
                mmult.refreshSizeInformation();
-               
                return mmult;
        }
        
@@ -690,7 +690,6 @@ public class HopRewriteUtils
                pbop.setBlocksize(input.getBlocksize());
                copyLineNumbers(input, pbop);
                pbop.refreshSizeInformation();
-               
                return pbop;
        }
        
@@ -774,23 +773,29 @@ public class HopRewriteUtils
                return datagen;
        }
        
-       public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop 
mright, String opcode) {
-               return createTernaryOp(mleft, smid, mright, 
OpOp3.valueOfCode(opcode));
+       public static TernaryOp createTernary(Hop mleft, Hop smid, Hop mright, 
String opcode) {
+               return createTernary(mleft, smid, mright, 
OpOp3.valueOfByOpcode(opcode));
        }
        
-       public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop 
mright, OpOp3 op) {
+       public static TernaryOp createTernary(Hop mleft, Hop smid, Hop mright, 
OpOp3 op) {
                //NOTe: for ifelse it's sufficient to check mright as 
smid==mright
-               System.out.println(mleft.getDataType()+" "+smid.getDataType()+" 
"+mright.getDataType());
                DataType dt = (op == OpOp3.IFELSE) ? mright.getDataType() : 
DataType.MATRIX;
                ValueType vt = (op == OpOp3.IFELSE) ? mright.getValueType() : 
ValueType.FP64;
                TernaryOp ternOp = new TernaryOp("tmp", dt, vt, op, mleft, 
smid, mright);
-               if( dt == DataType.MATRIX )
-                       ternOp.setBlocksize(mleft.getBlocksize());
+               ternOp.setBlocksize(Math.max(mleft.getBlocksize(), 
mright.getBlocksize()));
                copyLineNumbers(mleft, ternOp);
                ternOp.refreshSizeInformation();
                return ternOp;
        }
        
+       public static TernaryOp createTernary(Hop in1, Hop in2, Hop in3, Hop 
in4, Hop in5, OpOp3 op) {
+               TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, 
ValueType.FP64, op, in1, in2, in3, in4, in5);
+               ternOp.setBlocksize(Math.max(in1.getBlocksize(), 
in2.getBlocksize()));
+               copyLineNumbers(in1, ternOp);
+               ternOp.refreshSizeInformation();
+               return ternOp;
+       }
+       
        public static Hop createComputeNnz(Hop input) {
                //nnz = sum(A != 0) -> later rewritten to meta-data operation
                return createSum(createBinary(input, new LiteralOp(0), 
OpOp2.NOTEQUAL));
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index e929e0a..1929315 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2285,7 +2285,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                Hop smid = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
                                Hop mright = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
                                ternop = (smid instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? 
-                                               left : 
HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.PLUS_MULT);
+                                               left : 
HopRewriteUtils.createTernary(left, smid, mright, OpOp3.PLUS_MULT);
                                LOG.debug("Applied 
fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")");
                        }
                        //pattern (b) s*Y + X -> X +* sY
@@ -2297,7 +2297,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                Hop smid = left.getInput().get( 
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
                                Hop mright = left.getInput().get( 
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
                                ternop = (smid instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? 
-                                               right : 
HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT);
+                                               right : 
HopRewriteUtils.createTernary(right, smid, mright, OpOp3.PLUS_MULT);
                                LOG.debug("Applied 
fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")");
                        }
                        //pattern (c) X - s*Y -> X -* sY
@@ -2309,7 +2309,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                Hop smid = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
                                Hop mright = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
                                ternop = (smid instanceof LiteralOp && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? 
-                                               left : 
HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.MINUS_MULT);
+                                               left : 
HopRewriteUtils.createTernary(left, smid, mright, OpOp3.MINUS_MULT);
                                LOG.debug("Applied 
fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")");
                        }
                        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index f1c8dc6..740d821 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -24,14 +24,7 @@ import java.util.StringTokenizer;
 import org.apache.sysds.common.Types.AggOp;
 import org.apache.sysds.common.Types.CorrectionLocationType;
 import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.lops.AppendM;
-import org.apache.sysds.lops.BinaryM;
-import org.apache.sysds.lops.GroupedAggregateM;
 import org.apache.sysds.lops.Lop;
-import org.apache.sysds.lops.MapMult;
-import org.apache.sysds.lops.MapMultChain;
-import org.apache.sysds.lops.PMMJ;
-import org.apache.sysds.lops.UAggOuterChain;
 import org.apache.sysds.lops.WeightedCrossEntropy;
 import org.apache.sysds.lops.WeightedCrossEntropyR;
 import org.apache.sysds.lops.WeightedDivMM;
@@ -239,36 +232,12 @@ public class InstructionUtils
                Builtin.BuiltinCode bfc = 
Builtin.String2BuiltinCode.get(opcode);
                return (bfc != null);
        }
-
-       /**
-        * Evaluates if at least one instruction of the given instruction set
-        * used the distributed cache; this call can also be used for individual
-        * instructions. 
-        * 
-        * @param str instruction set
-        * @return true if at least one instruction uses distributed cache
-        */
-       public static boolean isDistributedCacheUsed(String str) 
-       {       
-               String[] parts = str.split(Instruction.INSTRUCTION_DELIM);
-               for(String inst : parts) 
-               {
-                       String opcode = getOpCode(inst);
-                       if(  opcode.equalsIgnoreCase(AppendM.OPCODE)  
-                          || opcode.equalsIgnoreCase(MapMult.OPCODE)
-                          || opcode.equalsIgnoreCase(MapMultChain.OPCODE)
-                          || opcode.equalsIgnoreCase(PMMJ.OPCODE)
-                          || opcode.equalsIgnoreCase(UAggOuterChain.OPCODE)
-                          || opcode.equalsIgnoreCase(GroupedAggregateM.OPCODE)
-                          || isDistQuaternaryOpcode( opcode ) //multiple 
quaternary opcodes
-                          || BinaryM.isOpcode( opcode ) ) //multiple binary 
opcodes    
-                       {
-                               return true;
-                       }
-               }
-               return false;
+       
+       public static boolean isUnaryMetadata(String opcode) {
+               return opcode != null 
+                       && (opcode.equals("nrow") || opcode.equals("ncol"));
        }
-
+       
        public static AggregateUnaryOperator 
parseBasicAggregateUnaryOperator(String opcode) {
                return parseBasicAggregateUnaryOperator(opcode, 1);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
index 77625f4..4869e3e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
@@ -20,22 +20,23 @@
 package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.lops.Ctable;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.matrix.data.CTableMap;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.LongLongDoubleHashMap.EntryType;
 
 public class CtableCPInstruction extends ComputationCPInstruction {
-       private final String _outDim1;
-       private final String _outDim2;
-       private final boolean _dim1Literal;
-       private final boolean _dim2Literal;
+       private final CPOperand _outDim1;
+       private final CPOperand _outDim2;
        private final boolean _isExpand;
        private final boolean _ignoreZeros;
 
@@ -43,10 +44,8 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                        String outputDim1, boolean dim1Literal, String 
outputDim2, boolean dim2Literal, boolean isExpand,
                        boolean ignoreZeros, String opcode, String istr) {
                super(CPType.Ctable, null, in1, in2, in3, out, opcode, istr);
-               _outDim1 = outputDim1;
-               _dim1Literal = dim1Literal;
-               _outDim2 = outputDim2;
-               _dim2Literal = dim2Literal;
+               _outDim1 = new CPOperand(outputDim1, ValueType.FP64, 
DataType.SCALAR, dim1Literal);
+               _outDim2 = new CPOperand(outputDim2, ValueType.FP64, 
DataType.SCALAR, dim2Literal);
                _isExpand = isExpand;
                _ignoreZeros = ignoreZeros;
        }
@@ -98,8 +97,8 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                Ctable.OperationTypes ctableOp = findCtableOperation();
                ctableOp = _isExpand ? 
Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;
                
-               long outputDim1 = (_dim1Literal ? (long) 
Double.parseDouble(_outDim1) : (ec.getScalarInput(_outDim1, ValueType.FP64, 
false)).getLongValue());
-               long outputDim2 = (_dim2Literal ? (long) 
Double.parseDouble(_outDim2) : (ec.getScalarInput(_outDim2, ValueType.FP64, 
false)).getLongValue());
+               long outputDim1 = ec.getScalarInput(_outDim1).getLongValue();
+               long outputDim2 = ec.getScalarInput(_outDim2).getLongValue();
                
                boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != 
-1);
                if ( outputDimsKnown ) {
@@ -178,4 +177,12 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                
                ec.setMatrixOutput(output.getName(), resultBlock);
        }
+       
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               LineageItem[] linputs = !(_outDim1.getName().equals("-1") && 
_outDim2.getName().equals("-1")) ?
+                       LineageItemUtils.getLineage(ec, input1, input2, input3, 
_outDim1, _outDim2) :
+                       LineageItemUtils.getLineage(ec, input1, input2, input3);
+               return Pair.of(output.getName(), new LineageItem(getOpcode(), 
linputs));
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index 11f4e8e..4aa9660 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -130,14 +130,16 @@ public class DataGenCPInstruction extends 
UnaryCPInstruction {
        }
 
        public long getRows() {
-               return rows.isLiteral() ? Long.parseLong(rows.getName()) : -1;
+               return rows.isLiteral() ? 
UtilFunctions.parseToLong(rows.getName()) : -1;
        }
 
        public long getCols() {
-               return cols.isLiteral() ? Long.parseLong(cols.getName()) : -1;
+               return cols.isLiteral() ? 
UtilFunctions.parseToLong(cols.getName()) : -1;
        }
 
-       public String getDims() { return dims.getName(); }
+       public String getDims() { 
+               return dims.getName();
+       }
        
        public int getBlocksize() {
                return blocksize;
@@ -172,15 +174,15 @@ public class DataGenCPInstruction extends 
UnaryCPInstruction {
        }
        
        public long getFrom() {
-               return seq_from.isLiteral() ? 
Long.parseLong(seq_from.getName()) : -1;
+               return seq_from.isLiteral() ? 
UtilFunctions.parseToLong(seq_from.getName()) : -1;
        }
        
        public long getTo() {
-               return seq_to.isLiteral() ? Long.parseLong(seq_to.getName()) : 
-1;
+               return seq_to.isLiteral() ? 
UtilFunctions.parseToLong(seq_to.getName()) : -1;
        }
        
        public long getIncr() {
-               return seq_incr.isLiteral() ? 
Long.parseLong(seq_incr.getName()) : -1;
+               return seq_incr.isLiteral() ? 
UtilFunctions.parseToLong(seq_incr.getName()) : -1;
        }
 
        public static DataGenCPInstruction parseInstruction(String str)
@@ -385,16 +387,40 @@ public class DataGenCPInstruction extends 
UnaryCPInstruction {
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
                String tmpInstStr = instString;
-               if (getSeed() == DataGenOp.UNSPECIFIED_SEED) {
-                       //generate pseudo-random seed (because not specified)
-                       if (runtimeSeed == null)
-                               runtimeSeed = (minValue == maxValue && sparsity 
== 1) ? 
-                                       DataGenOp.UNSPECIFIED_SEED : 
DataGenOp.generateRandomSeed();
-                       int position = (method == OpOpDG.RAND) ? 
SEED_POSITION_RAND :
-                                       (method == OpOpDG.SAMPLE) ? 
SEED_POSITION_SAMPLE : 0;
-                       tmpInstStr = position != 0 ? 
InstructionUtils.replaceOperand(
-                                                       tmpInstStr, position, 
String.valueOf(runtimeSeed)) : tmpInstStr;
+               
+               switch(method) {
+                       case RAND:
+                       case SAMPLE: {
+                               if (getSeed() == DataGenOp.UNSPECIFIED_SEED) {
+                                       //generate pseudo-random seed (because 
not specified)
+                                       if (runtimeSeed == null)
+                                               runtimeSeed = (minValue == 
maxValue && sparsity == 1) ? 
+                                                       
DataGenOp.UNSPECIFIED_SEED : DataGenOp.generateRandomSeed();
+                                       int position = (method == OpOpDG.RAND) 
? SEED_POSITION_RAND :
+                                               (method == OpOpDG.SAMPLE) ? 
SEED_POSITION_SAMPLE : 0;
+                                       tmpInstStr = position != 0 ? 
InstructionUtils.replaceOperand(
+                                               tmpInstStr, position, 
String.valueOf(runtimeSeed)) : tmpInstStr;
+                               }
+                               tmpInstStr = replaceNonLiteral(tmpInstStr, 
rows, 2, ec);
+                               tmpInstStr = replaceNonLiteral(tmpInstStr, 
cols, 3, ec);
+                               break;
+                       }
+                       case SEQ: {
+                               tmpInstStr = replaceNonLiteral(tmpInstStr, 
seq_from, 5, ec);
+                               tmpInstStr = replaceNonLiteral(tmpInstStr, 
seq_to, 6, ec);
+                               tmpInstStr = replaceNonLiteral(tmpInstStr, 
seq_incr, 7, ec);
+                               break;
+                       }
+                       default:
+                               throw new DMLRuntimeException("Unsupported 
datagen op: "+method);
                }
                return Pair.of(output.getName(), new LineageItem(tmpInstStr, 
getOpcode()));
        }
+       
+       private static String replaceNonLiteral(String inst, CPOperand op, int 
pos, ExecutionContext ec) {
+               if( !op.isLiteral() )
+                       inst = InstructionUtils.replaceOperand(inst, pos,
+                               new 
CPOperand(ec.getScalarInput(op)).getLineageLiteral());
+               return inst;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 0f0951e..f00f42d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -155,8 +155,11 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                        functionVariables.put(currFormalParam.getName(), value);
                        
                        //map lineage to function arguments
-                       if( lineage != null )
-                               lineage.set(currFormalParam.getName(), 
ec.getLineageItem(input));
+                       if( lineage != null ) {
+                               LineageItem litem = ec.getLineageItem(input);
+                               lineage.set(currFormalParam.getName(), 
(litem!=null) ? 
+                                       litem : 
ec.getLineage().getOrCreate(input));
+                       }
                }
                
                // Pin the input variables so that they do not get deleted 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
index 6262c50..8a3c001 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReshapeCPInstruction.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -26,6 +27,8 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.data.LibTensorReorg;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -104,4 +107,10 @@ public class ReshapeCPInstruction extends 
UnaryCPInstruction {
                        ec.releaseMatrixInput(input1.getName());
                }
        }
+       
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               return Pair.of(output.getName(), new LineageItem(getOpcode(),
+                       LineageItemUtils.getLineage(ec, input1, _opRows, 
_opCols, _opDims, _opByRow)));
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
index a2058b7..ef40773 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
@@ -164,11 +164,11 @@ public class RandSPInstruction extends UnarySPInstruction 
{
        }
 
        public long getRows() {
-               return rows.isLiteral() ? Long.parseLong(rows.getName()) : -1;
+               return rows.isLiteral() ? 
UtilFunctions.parseToLong(rows.getName()) : -1;
        }
 
        public long getCols() {
-               return cols.isLiteral() ? Long.parseLong(cols.getName()) : -1;
+               return cols.isLiteral() ? 
UtilFunctions.parseToLong(cols.getName()) : -1;
        }
 
        public int getBlocksize() {
@@ -1011,6 +1011,12 @@ public class RandSPInstruction extends 
UnarySPInstruction {
                                (_method == OpOpDG.SAMPLE) ? 
SEED_POSITION_SAMPLE : 0;
                        tmpInstStr = InstructionUtils.replaceOperand(
                                tmpInstStr, position, 
String.valueOf(runtimeSeed));
+                       if( !rows.isLiteral() )
+                               tmpInstStr = 
InstructionUtils.replaceOperand(tmpInstStr, 2,
+                                       new 
CPOperand(ec.getScalarInput(rows)).getLineageLiteral());
+                       if( !cols.isLiteral() )
+                               tmpInstStr = 
InstructionUtils.replaceOperand(tmpInstStr, 3,
+                                       new 
CPOperand(ec.getScalarInput(cols)).getLineageLiteral());
                }
                return Pair.of(output.getName(), new LineageItem(tmpInstStr, 
getOpcode()));
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 39a8c2a..b75baee 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -31,6 +31,9 @@ import 
org.apache.sysds.runtime.lineage.LineageItem.LineageItemType;
 import org.apache.sysds.common.Types.AggOp;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
 import org.apache.sysds.common.Types.OpOpDG;
 import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.OpOpN;
@@ -267,7 +270,9 @@ public class LineageItemUtils {
                                        switch (ctype) {
                                                case AggregateUnary: {
                                                        Hop input = 
operands.get(item.getInputs()[0].getId());
-                                                       Hop aggunary = 
HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
+                                                       Hop aggunary = 
InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
+                                                               
HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
+                                                               
HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
                                                        
operands.put(item.getId(), aggunary);
                                                        break;
                                                }
@@ -279,9 +284,15 @@ public class LineageItemUtils {
                                                        break;
                                                }
                                                case Reorg: {
-                                                       Hop input = 
operands.get(item.getInputs()[0].getId());
-                                                       Hop reorg = 
HopRewriteUtils.createReorg(input, ReOrgOp.TRANS);
-                                                       
operands.put(item.getId(), reorg);
+                                                       
operands.put(item.getId(), HopRewriteUtils.createReorg(
+                                                               
operands.get(item.getInputs()[0].getId()), item.getOpcode()));
+                                                       break;
+                                               }
+                                               case Reshape: {
+                                                       ArrayList<Hop> inputs = 
new ArrayList<>();
+                                                       for(int i=0; i<5; i++)
+                                                               
inputs.add(operands.get(item.getInputs()[i].getId()));
+                                                       
operands.put(item.getId(), HopRewriteUtils.createReorg(inputs, 
ReOrgOp.RESHAPE));
                                                        break;
                                                }
                                                case Binary: {
@@ -303,12 +314,27 @@ public class LineageItemUtils {
                                                        break;
                                                }
                                                case Ternary: {
-                                                       
operands.put(item.getId(), HopRewriteUtils.createTernaryOp(
+                                                       
operands.put(item.getId(), HopRewriteUtils.createTernary(
                                                                
operands.get(item.getInputs()[0].getId()), 
                                                                
operands.get(item.getInputs()[1].getId()), 
                                                                
operands.get(item.getInputs()[2].getId()), item.getOpcode()));
                                                        break;
                                                }
+                                               case Ctable: { //e.g., ctable 
+                                                       if( 
item.getInputs().length==3 )
+                                                               
operands.put(item.getId(), HopRewriteUtils.createTernary(
+                                                                       
operands.get(item.getInputs()[0].getId()),
+                                                                       
operands.get(item.getInputs()[1].getId()),
+                                                                       
operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
+                                                       else if( 
item.getInputs().length==5 )
+                                                               
operands.put(item.getId(), HopRewriteUtils.createTernary(
+                                                                       
operands.get(item.getInputs()[0].getId()),
+                                                                       
operands.get(item.getInputs()[1].getId()),
+                                                                       
operands.get(item.getInputs()[2].getId()),
+                                                                       
operands.get(item.getInputs()[3].getId()),
+                                                                       
operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
+                                                       break;
+                                               }
                                                case BuiltinNary: {
                                                        String opcode = 
item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
                                                        
operands.put(item.getId(), HopRewriteUtils.createNary(
@@ -331,8 +357,13 @@ public class LineageItemUtils {
                                                        
operands.put(item.getId(), aggunary);
                                                        break;
                                                }
-                                               case Variable: { //cpvar, write
-                                                       
operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
+                                               case Variable: {
+                                                       if( 
item.getOpcode().startsWith("cast") )
+                                                               
operands.put(item.getId(), HopRewriteUtils.createUnary(
+                                                                       
operands.get(item.getInputs()[0].getId()),
+                                                                       
OpOp1.valueOfByOpcode(item.getOpcode())));
+                                                       else //cpvar, write
+                                                               
operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
                                                        break;
                                                }
                                                default:
@@ -358,6 +389,12 @@ public class LineageItemUtils {
                                                        
operands.put(item.getId(), constructIndexingOp(item, operands));
                                                        break;
                                                }
+                                               case GAppend: {
+                                                       
operands.put(item.getId(), HopRewriteUtils.createBinary(
+                                                               
operands.get(item.getInputs()[0].getId()),
+                                                               
operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
+                                                       break;
+                                               }
                                                default:
                                                        throw new 
DMLRuntimeException("Unsupported instruction "
                                                                + "type: " + 
stype.name() + " (" + item.getOpcode() + ").");
@@ -482,18 +519,16 @@ public class LineageItemUtils {
        }
        
        private static Hop constructIndexingOp(LineageItem item, Map<Long, Hop> 
operands) {
-               //TODO fix 
+               Hop input = operands.get(item.getInputs()[0].getId());
                if( "rightIndex".equals(item.getOpcode()) )
-                       return HopRewriteUtils.createIndexingOp(
-                               operands.get(item.getInputs()[0].getId()), 
//input
+                       return HopRewriteUtils.createIndexingOp(input,
                                operands.get(item.getInputs()[1].getId()), //rl
                                operands.get(item.getInputs()[2].getId()), //ru
                                operands.get(item.getInputs()[3].getId()), //cl
                                operands.get(item.getInputs()[4].getId())); //cu
                else if( "leftIndex".equals(item.getOpcode()) 
                                || "mapLeftIndex".equals(item.getOpcode()) )
-                       return HopRewriteUtils.createLeftIndexingOp(
-                               operands.get(item.getInputs()[0].getId()), 
//input
+                       return HopRewriteUtils.createLeftIndexingOp(input,
                                operands.get(item.getInputs()[1].getId()), //rhs
                                operands.get(item.getInputs()[2].getId()), //rl
                                operands.get(item.getInputs()[3].getId()), //ru
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
index 50443c1..d100a4d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceParforTest.java
@@ -44,7 +44,8 @@ public class LineageTraceParforTest extends AutomatedTestBase 
{
        protected static final String TEST_NAME1 = "LineageTraceParfor1"; 
//rand - matrix result - local parfor
        protected static final String TEST_NAME2 = "LineageTraceParfor2"; 
//rand - matrix result - remote spark parfor
        protected static final String TEST_NAME3 = "LineageTraceParfor3"; 
//rand - matrix result - remote spark parfor
-       protected static final String TEST_NAME4 = "LineageTraceParfor4"; 
//rand - steplm (stackoverflow error)
+       protected static final String TEST_NAME4 = "LineageTraceParforSteplm"; 
//rand - steplm
+       protected static final String TEST_NAME5 = "LineageTraceParforKmeans"; 
//rand - kmeans
        
        protected String TEST_CLASS_DIR = TEST_DIR + 
LineageTraceParforTest.class.getSimpleName() + "/";
        
@@ -61,6 +62,7 @@ public class LineageTraceParforTest extends AutomatedTestBase 
{
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}) );
                addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"R"}) );
+               addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"R"}) );
        }
        
        @Test
@@ -113,16 +115,25 @@ public class LineageTraceParforTest extends 
AutomatedTestBase {
                testLineageTraceParFor(32, TEST_NAME3);
        }
        
-// TODO additional fixes needed for steplm
-//     @Test
-//     public void testLineageTraceParFor4_8() {
-//             testLineageTraceParFor(8, TEST_NAME4);
-//     }
-//     
-//     @Test
-//     public void testLineageTraceParFor4_32() {
-//             testLineageTraceParFor(32, TEST_NAME4);
-//     }
+       @Test
+       public void testLineageTraceSteplm_8() {
+               testLineageTraceParFor(8, TEST_NAME4);
+       }
+       
+       @Test
+       public void testLineageTraceSteplm_32() {
+               testLineageTraceParFor(32, TEST_NAME4);
+       }
+       
+       @Test
+       public void testLineageTraceKMeans_8() {
+               testLineageTraceParFor(8, TEST_NAME5);
+       }
+       
+       @Test
+       public void testLineageTraceKmeans_32() {
+               testLineageTraceParFor(32, TEST_NAME5);
+       }
        
        private void testLineageTraceParFor(int ncol, String testname) {
                try {
@@ -146,7 +157,6 @@ public class LineageTraceParforTest extends 
AutomatedTestBase {
                        
                        //get lineage and generate program
                        String Rtrace = readDMLLineageFromHDFS("R");
-                       System.out.println(Rtrace);
                        LineageItem R = LineageParser.parseLineageTrace(Rtrace);
                        Data ret = LineageItemUtils.computeByLineage(R);
 
diff --git a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml 
b/src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml
similarity index 94%
copy from src/test/scripts/functions/lineage/LineageTraceParfor4.dml
copy to src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml
index 576182b..215cb5b 100644
--- a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml
+++ b/src/test/scripts/functions/lineage/LineageTraceParforKmeans.dml
@@ -20,7 +20,6 @@
 #-------------------------------------------------------------
 
 X = rand(rows=$2, cols=$3, seed=7);
-Y = rand(rows=nrow(X), cols=1, seed=2)
-X = steplm(X=X, y=Y)
+X = kmeans(X=X, k=4)
 
 write(X, $1);
diff --git a/src/test/scripts/functions/lineage/LineageTraceParfor4.dml 
b/src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml
similarity index 100%
rename from src/test/scripts/functions/lineage/LineageTraceParfor4.dml
rename to src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml

Reply via email to