This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c108640  [SYSTEMDS-2799] Lineage tracing and reuse of federated UDFs
c108640 is described below

commit c1086403f0f0d0133d51b90df9fb3b262eba64f5
Author: arnabp <[email protected]>
AuthorDate: Tue Jan 12 14:03:39 2021 +0100

    [SYSTEMDS-2799] Lineage tracing and reuse of federated UDFs
    
    This patch builds the foundation of lineage tracing and reuse
    of federated UDFs, and adds a test which calls lower.tri in a
    loop and reuses the outputs of the Tri UDF in the workers.
    The core idea is to treat a UDF as an instruction which always
    produces the same results for the same inputs.
---
 .../controlprogram/federated/FederatedUDF.java     |   8 +-
 .../federated/FederatedWorkerHandler.java          |  26 +-
 .../paramserv/FederatedPSControlThread.java        |  17 ++
 .../paramserv/dp/BalanceToAvgFederatedScheme.java  |   7 +
 .../dp/ReplicateToMaxFederatedScheme.java          |   7 +
 .../paramserv/dp/ShuffleFederatedScheme.java       |   7 +
 .../dp/SubsampleToMinFederatedScheme.java          |   7 +
 .../cp/ParameterizedBuiltinCPInstruction.java      |   8 +
 .../fed/CentralMomentFEDInstruction.java           | 297 +++++++++++----------
 ...tiReturnParameterizedBuiltinFEDInstruction.java |  12 +
 .../fed/ParameterizedBuiltinFEDInstruction.java    |  50 ++++
 .../fed/QuantilePickFEDInstruction.java            |  15 ++
 .../fed/QuantileSortFEDInstruction.java            |   6 +
 .../instructions/fed/ReorgFEDInstruction.java      |  12 +
 .../apache/sysds/runtime/lineage/LineageCache.java | 111 ++++++++
 .../sysds/runtime/lineage/LineageCacheConfig.java  |   2 +-
 .../sysds/runtime/lineage/LineageItemUtils.java    |  21 ++
 .../test/functions/lineage/FedUDFReuseTest.java    | 155 +++++++++++
 .../scripts/functions/lineage/FedUdfReuse1.dml     |  34 +++
 .../functions/lineage/FedUdfReuse1Reference.dml    |  28 ++
 20 files changed, 683 insertions(+), 147 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
index 5423ffa..f42b39b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
@@ -20,11 +20,13 @@
 package org.apache.sysds.runtime.controlprogram.federated;
 
 import java.io.Serializable;
+import java.util.List;
 
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 
-public abstract class FederatedUDF implements Serializable {
+public abstract class FederatedUDF implements Serializable, LineageTraceable {
        private static final long serialVersionUID = 799416525191257308L;
        
        private final long[] _inputIDs;
@@ -36,6 +38,10 @@ public abstract class FederatedUDF implements Serializable {
        public final long[] getInputIDs() {
                return _inputIDs;
        }
+
+       public List<Long> getOutputIds() {
+               return null;
+       }
        
        /**
         * Execute the user-defined function on a set of data objects
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 4a69a10..5b574c1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -48,8 +48,11 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.privacy.DMLPrivacyException;
@@ -271,7 +274,6 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                // set variable and construct empty response
                ec.setVariable(varname, data);
                if (DMLScript.LINEAGE)
-                       // TODO: Identify MO uniquely. Use Adler32 checksum.
                        ec.getLineage().set(varname, new 
LineageItem(String.valueOf(request.getChecksum(0))));
 
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
@@ -340,10 +342,24 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                FederatedUDF udf = (FederatedUDF) request.getParam(0);
                Data[] inputs = Arrays.stream(udf.getInputIDs()).mapToObj(id -> 
ec.getVariable(String.valueOf(id)))
                        
.map(PrivacyMonitor::handlePrivacy).toArray(Data[]::new);
-
-               // execute user-defined function
+               
+               // trace lineage
+               if (DMLScript.LINEAGE)
+                       LineageItemUtils.traceFedUDF(ec, udf);
+               
+               // reuse or execute user-defined function
                try {
-                       return udf.execute(ec, inputs);
+                       // reuse UDF outputs if available in lineage cache
+                       if (LineageCache.reuse(udf, ec))
+                               return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
+
+                       // else execute the UDF
+                       long t0 = !ReuseCacheType.isNone() ? System.nanoTime() 
: 0;
+                       FederatedResponse res = udf.execute(ec, inputs);
+                       long t1 = !ReuseCacheType.isNone() ? System.nanoTime() 
: 0;
+                       //cacheUDFOutputs(udf, inputs, t1-t0, ec);
+                       LineageCache.putValue(udf, ec, t1-t0);
+                       return res;
                }
                catch(DMLPrivacyException | FederatedWorkerHandlerException ex){
                        throw ex;
@@ -354,7 +370,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        return new FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(msg));
                }
        }
-
+       
        private FederatedResponse execClear() {
                try {
                        _ecm.clear();
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 48249db..98bc91a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.controlprogram.paramserv;
 
 import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.parser.DataIdentifier;
@@ -50,6 +51,7 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.util.ProgramConverter;
 
 import java.util.ArrayList;
@@ -224,6 +226,11 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
        }
 
        /**
@@ -271,6 +278,11 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
        }
 
        /**
@@ -531,6 +543,11 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
                }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
        }
 
        // Statistics methods
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index 34e94f0..e3daf60 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
@@ -115,5 +117,10 @@ public class BalanceToAvgFederatedScheme extends 
DataPartitionFederatedScheme {
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index a1b8f6c..e9c1b50 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
@@ -113,5 +115,10 @@ public class ReplicateToMaxFederatedScheme extends 
DataPartitionFederatedScheme
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 1920593..365554d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 import java.util.List;
@@ -95,5 +97,10 @@ public class ShuffleFederatedScheme extends 
DataPartitionFederatedScheme {
                        shuffle(labels, permutationMatrixBlock);
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
        }
 }
\ No newline at end of file
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index 937c37e..e55b92e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
@@ -112,5 +114,10 @@ public class SubsampleToMinFederatedScheme extends 
DataPartitionFederatedScheme
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index e87ee1e..58bc7b1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -444,6 +444,14 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        return Pair.of(output.getName(), new 
LineageItem(getOpcode(),
                                LineageItemUtils.getLineage(ec, target, max, 
dir, cast, ignore)));
                }
+               else if (opcode.equalsIgnoreCase("lowertri") || 
opcode.equalsIgnoreCase("uppertri")) {
+                       CPOperand target = getTargetOperand();
+                       CPOperand lower = getBoolLiteral("lowertri");
+                       CPOperand diag = getBoolLiteral("diag");
+                       CPOperand values = getBoolLiteral("values");
+                       return Pair.of(output.getName(), new 
LineageItem(getOpcode(),
+                               LineageItemUtils.getLineage(ec, target, lower, 
diag, values)));
+               }
                else if (opcode.equalsIgnoreCase("transformdecode") ||
                                opcode.equalsIgnoreCase("transformapply")) {
                        CPOperand target = new CPOperand(params.get("target"), 
ValueType.FP64, DataType.FRAME);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
index cc8e683..4bd522f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -39,149 +40,165 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.CMOperator;
 
 public class CentralMomentFEDInstruction extends AggregateUnaryFEDInstruction {
 
-    private CentralMomentFEDInstruction(CMOperator cm, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out,
-                                       String opcode, String str) {
-        super(cm, in1, in2, in3, out, opcode, str);
-    }
-
-    public static CentralMomentFEDInstruction parseInstruction(String str) {
-        CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
-        CPOperand in2 = null;
-        CPOperand in3 = null;
-        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
-
-        String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
-        String opcode = parts[0];
-
-        //check supported opcode
-        if( !opcode.equalsIgnoreCase("cm") ) {
-            throw new DMLRuntimeException("Unsupported opcode "+opcode);
-        }
-
-        if ( parts.length == 4 ) {
-            // Example: CP.cm.mVar0.Var1.mVar2; (without weights)
-            in2 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
-            parseUnaryInstruction(str, in1, in2, out);
-        }
-        else if ( parts.length == 5) {
-            // CP.cm.mVar0.mVar1.Var2.mVar3; (with weights)
-            in2 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
-            in3 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
-            parseUnaryInstruction(str, in1, in2, in3, out);
-        }
-
-        /*
-         * Exact order of the central moment MAY NOT be known at compilation 
time.
-         * We first try to parse the second argument as an integer, and if we 
fail,
-         * we simply pass -1 so that getCMAggOpType() picks up 
AggregateOperationTypes.INVALID.
-         * It must be updated at run time in processInstruction() method.
-         */
-
-        int cmOrder;
-        try {
-            if ( in3 == null ) {
-                cmOrder = Integer.parseInt(in2.getName());
-            }
-            else {
-                cmOrder = Integer.parseInt(in3.getName());
-            }
-        } catch(NumberFormatException e) {
-            cmOrder = -1; // unknown at compilation time
-        }
-
-        CMOperator.AggregateOperationTypes opType = 
CMOperator.getCMAggOpType(cmOrder);
-        CMOperator cm = new CMOperator(CM.getCMFnObject(opType), opType);
-        return new CentralMomentFEDInstruction(cm, in1, in2, in3, out, opcode, 
str);
-    }
-
-    @Override
-    public void processInstruction( ExecutionContext ec ) {
-        MatrixObject mo = ec.getMatrixObject(input1.getName());
-        ScalarObject order = ec.getScalarInput(input3==null ? input2 : input3);
-
-        CMOperator cm_op = ((CMOperator) _optr);
-        if(cm_op.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID)
-            cm_op = cm_op.setCMAggOp((int) order.getLongValue());
-
-        FederationMap fedMapping = mo.getFedMapping();
-        List<CM_COV_Object> globalCmobj = new ArrayList<>();
-
-        long varID = FederationUtils.getNextFedDataID();
-        CMOperator finalCm_op = cm_op;
-        fedMapping.mapParallel(varID, (range, data) -> {
-
-            FederatedResponse response;
-            try {
-                if (input3 == null ) {
-                    response = data.executeFederatedOperation(
-                        new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                            new 
CentralMomentFEDInstruction.CMFunction(data.getVarID(), finalCm_op))).get();
-                } else {
-                    MatrixBlock wtBlock = ec.getMatrixInput(input2.getName());
-
-                    response = data.executeFederatedOperation(
-                        new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                            new 
CentralMomentFEDInstruction.CMWeightsFunction(data.getVarID(), finalCm_op, 
wtBlock))).get();
-                }
-
-                if(!response.isSuccessful())
-                    response.throwExceptionFromResponse();
-               synchronized(globalCmobj) {
-                   globalCmobj.add((CM_COV_Object) response.getData()[0]);
-               }
-            }
-            catch(Exception e) {
-                throw new DMLRuntimeException(e);
-            }
-            return null;
-        });
-
-        Optional<CM_COV_Object> res = globalCmobj.stream().reduce((arg0, arg1) 
-> (CM_COV_Object) finalCm_op.fn.execute(arg0, arg1));
-        try {
-            ec.setScalarOutput(output.getName(), new 
DoubleObject(res.get().getRequiredResult(finalCm_op)));
-        }
-        catch(Exception e) {
-            throw new DMLRuntimeException(e);
-        }
-    }
-
-    private static class CMFunction extends FederatedUDF {
-        private static final long serialVersionUID = 7460149207607220994L;
-        private final CMOperator _op;
-
-        public CMFunction (long input, CMOperator op) {
-            super(new long[] {input});
-            _op = op;
-        }
-
-        @Override
-        public FederatedResponse execute(ExecutionContext ec, Data... data) {
-            MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
-            return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op));
-        }
-    }
-
-
-    private static class CMWeightsFunction extends FederatedUDF {
-        private static final long serialVersionUID = -3685746246551622021L;
-        private final CMOperator _op;
-        private final MatrixBlock _weights;
-
-        protected CMWeightsFunction(long input, CMOperator op, MatrixBlock 
weights) {
-            super(new long[] {input});
-            _op = op;
-            _weights = weights;
-        }
-
-        @Override
-        public FederatedResponse execute(ExecutionContext ec, Data... data) {
-            MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
-            return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op, 
_weights));
-        }
-    }
+       private CentralMomentFEDInstruction(CMOperator cm, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out,
+                       String opcode, String str) {
+               super(cm, in1, in2, in3, out, opcode, str);
+       }
+
+       public static CentralMomentFEDInstruction parseInstruction(String str) {
+               CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
+               CPOperand in2 = null;
+               CPOperand in3 = null;
+               CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
+
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+
+               // check supported opcode
+               if (!opcode.equalsIgnoreCase("cm")) {
+                       throw new DMLRuntimeException("Unsupported opcode " + 
opcode);
+               }
+
+               if (parts.length == 4) {
+                       // Example: CP.cm.mVar0.Var1.mVar2; (without weights)
+                       in2 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
+                       parseUnaryInstruction(str, in1, in2, out);
+               }
+               else if (parts.length == 5) {
+                       // CP.cm.mVar0.mVar1.Var2.mVar3; (with weights)
+                       in2 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
+                       in3 = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
+                       parseUnaryInstruction(str, in1, in2, in3, out);
+               }
+
+               /*
+                * Exact order of the central moment MAY NOT be known at 
compilation time. We
+                * first try to parse the second argument as an integer, and if 
we fail, we
+                * simply pass -1 so that getCMAggOpType() picks up
+                * AggregateOperationTypes.INVALID. It must be updated at run 
time in
+                * processInstruction() method.
+                */
+
+               int cmOrder;
+               try {
+                       if (in3 == null) {
+                               cmOrder = Integer.parseInt(in2.getName());
+                       }
+                       else {
+                               cmOrder = Integer.parseInt(in3.getName());
+                       }
+               }
+               catch (NumberFormatException e) {
+                       cmOrder = -1; // unknown at compilation time
+               }
+
+               CMOperator.AggregateOperationTypes opType = 
CMOperator.getCMAggOpType(cmOrder);
+               CMOperator cm = new CMOperator(CM.getCMFnObject(opType), 
opType);
+               return new CentralMomentFEDInstruction(cm, in1, in2, in3, out, 
opcode, str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo = ec.getMatrixObject(input1.getName());
+               ScalarObject order = ec.getScalarInput(input3 == null ? input2 
: input3);
+
+               CMOperator cm_op = ((CMOperator) _optr);
+               if (cm_op.getAggOpType() == 
CMOperator.AggregateOperationTypes.INVALID)
+                       cm_op = cm_op.setCMAggOp((int) order.getLongValue());
+
+               FederationMap fedMapping = mo.getFedMapping();
+               List<CM_COV_Object> globalCmobj = new ArrayList<>();
+
+               long varID = FederationUtils.getNextFedDataID();
+               CMOperator finalCm_op = cm_op;
+               fedMapping.mapParallel(varID, (range, data) -> {
+
+                       FederatedResponse response;
+                       try {
+                               if (input3 == null) {
+                                       response = data
+                                                       
.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+                                                                       new 
CentralMomentFEDInstruction.CMFunction(data.getVarID(), finalCm_op)))
+                                                       .get();
+                               }
+                               else {
+                                       MatrixBlock wtBlock = 
ec.getMatrixInput(input2.getName());
+
+                                       response = 
data.executeFederatedOperation(new FederatedRequest(
+                                                       
FederatedRequest.RequestType.EXEC_UDF, -1,
+                                                       new 
CentralMomentFEDInstruction.CMWeightsFunction(data.getVarID(), finalCm_op, 
wtBlock)))
+                                                       .get();
+                               }
+
+                               if (!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               synchronized (globalCmobj) {
+                                       globalCmobj.add((CM_COV_Object) 
response.getData()[0]);
+                               }
+                       }
+                       catch (Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+                       return null;
+               });
+
+               Optional<CM_COV_Object> res = globalCmobj.stream()
+                               .reduce((arg0, arg1) -> (CM_COV_Object) 
finalCm_op.fn.execute(arg0, arg1));
+               try {
+                       ec.setScalarOutput(output.getName(), new 
DoubleObject(res.get().getRequiredResult(finalCm_op)));
+               }
+               catch (Exception e) {
+                       throw new DMLRuntimeException(e);
+               }
+       }
+
+       private static class CMFunction extends FederatedUDF {
+               private static final long serialVersionUID = 
7460149207607220994L;
+               private final CMOperator _op;
+
+               public CMFunction(long input, CMOperator op) {
+                       super(new long[] {input});
+                       _op = op;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op));
+               }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+       private static class CMWeightsFunction extends FederatedUDF {
+               private static final long serialVersionUID = 
-3685746246551622021L;
+               private final CMOperator _op;
+               private final MatrixBlock _weights;
+
+               protected CMWeightsFunction(long input, CMOperator op, 
MatrixBlock weights) {
+                       super(new long[] {input});
+                       _op = op;
+                       _weights = weights;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op, 
_weights));
+               }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index d02d0f5..834d11d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.concurrent.Future;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -38,6 +39,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -212,6 +214,11 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
                         // create federated response
                         return new FederatedResponse(ResponseType.SUCCESS, new Object[] {encoder, fb.getColumnNames()});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         public static class ExecuteFrameEncoder extends FederatedUDF {
@@ -242,5 +249,10 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
                         // return id handle
                         return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index a2b63e9..cff8bb0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -25,7 +25,9 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Stream;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
@@ -48,6 +50,8 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -242,6 +246,27 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 
                         return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
                 }
+
+               @Override
+               public List<Long> getOutputIds() {
+                       return new ArrayList<>(Arrays.asList(_outputID)); //non-null output IDs make this UDF eligible for reuse
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { //output ID -> lineage over the UDF inputs and all parameters as literals
+                       LineageItem[] liUdfInputs = Arrays.stream(getInputIDs())
+                                       .mapToObj(id -> ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
+                       CPOperand slice = new CPOperand(Arrays.toString(_slice), ValueType.STRING, DataType.SCALAR, true);
+                       CPOperand rowFed = new CPOperand(String.valueOf(_rowFed), ValueType.BOOLEAN, DataType.SCALAR, true);
+                       CPOperand lower = new CPOperand(String.valueOf(_lower), ValueType.BOOLEAN, DataType.SCALAR, true);
+                       CPOperand diag = new CPOperand(String.valueOf(_diag), ValueType.BOOLEAN, DataType.SCALAR, true);
+                       CPOperand values = new CPOperand(String.valueOf(_values), ValueType.BOOLEAN, DataType.SCALAR, true);
+                       LineageItem[] otherInputs = LineageItemUtils.getLineage(ec, slice, rowFed, lower, diag, values);
+                       LineageItem[] liInputs = Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs))
+                                       .toArray(LineageItem[]::new);
+                       return Pair.of(String.valueOf(_outputID), 
+                                       new LineageItem(getClass().getSimpleName(), liInputs));
+               }
         }
 
         private void rmempty(ExecutionContext ec) {
@@ -561,6 +586,11 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                         // return schema
                         return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fo.getSchema()});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         private static class GetColumnNames extends FederatedUDF {
@@ -576,6 +606,11 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                         // return column names
                         return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fb.getColumnNames()});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         private static class InitRowsToRemoveOmit extends FederatedUDF {
@@ -594,6 +629,11 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                         _encoder.build(fb);
                         return new FederatedResponse(ResponseType.SUCCESS, new Object[] {_encoder});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         private static class GetDataCharacteristics extends FederatedUDF {
@@ -611,6 +651,11 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                         int c = mb.getDenseBlockValues() != null ? mb.getNumColumns(): 0;
                         return new FederatedResponse(ResponseType.SUCCESS, new int[] {r, c});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         private static class GetVector extends FederatedUDF {
@@ -640,5 +685,10 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                         tmp1 = tmp1.binaryOperationsInPlace(greater, new MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
                         return new FederatedResponse(ResponseType.SUCCESS, tmp1);
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index f2052b5..1d9cbdd 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.ArrayList;
 import java.util.List;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.lops.PickByCount.OperationTypes;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -36,6 +37,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
@@ -176,6 +178,11 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
                                         new Object[] {picked});
                         }
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         private static class IQM extends FederatedUDF {
@@ -191,6 +198,11 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
                         return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
                                 new Object[] {mb.interQuartileMean()});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         private static class Median extends FederatedUDF {
@@ -206,5 +218,10 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
                         return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
                                 new Object[] {mb.median()});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index 0e994d9..c1aa37d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.lops.SortKeys;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -32,6 +33,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
@@ -159,5 +161,10 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
                         // return schema
                         return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index a033769..65e6e97 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.HashMap;
 import java.util.Map;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -38,6 +39,7 @@ import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
@@ -263,6 +265,11 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 
                         return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{res.getNumRows(), res.getNumColumns()});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 
         private static class DiagMatrix extends FederatedUDF {
@@ -302,5 +309,10 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 
                         return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{res.getNumRows(), res.getNumColumns()});
                 }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null; //not lineage-traced yet, thus excluded from federated reuse
+               }
         }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 287f9aa..0f9b6ed 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -32,6 +32,7 @@ import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.instructions.CPInstructionParser;
 import org.apache.sysds.runtime.instructions.Instruction;
@@ -234,6 +235,73 @@ public class LineageCache
                 return reuse;
         }
         
+       //Reuse federated UDFs: on a hit, bind the cached value to the UDF's output ID and
+       //return true; on a miss, reserve a placeholder (filled by putValue) and return false
+       public static boolean reuse(FederatedUDF udf, ExecutionContext ec) 
+       {
+               if (ReuseCacheType.isNone() || udf.getOutputIds() == null)
+                       return false;
+               //TODO: reuse only those UDFs which are part of reusable instructions
+               
+               List<Long> outIds = udf.getOutputIds();
+               HashMap<String, Data> udfOutputs = new HashMap<>();
+
+               //TODO: support multi-return UDFs
+               //compute the lineage only once, as every call builds new lineage items
+               LineageItem li = getFedUDFLineage(udf, ec);
+               if (li == null)
+                       //TODO: trace all UDFs
+                       return false;
+               li.setDistLeaf2Node(1); //to save from early eviction
+               LineageCacheEntry e = null;
+               synchronized(_cache) {
+                       if (probe(li))
+                               e = LineageCache.getIntern(li);
+                       else
+                               //for now allow only matrix blocks
+                               putIntern(li, DataType.MATRIX, null, null, 0);
+               }
+               
+               //cache miss: the reserved placeholder is filled later by putValue
+               if (e == null)
+                       return false;
+               
+               String outName = String.valueOf(outIds.get(0));
+               Data outValue = null;
+               //convert to matrix object
+               if (e.isMatrixValue()) {
+                       MetaDataFormat md = new MetaDataFormat(
+                               e.getMBValue().getDataCharacteristics(),FileFormat.BINARY);
+                       outValue = new MatrixObject(ValueType.FP64, outName, md);
+                       ((MatrixObject)outValue).acquireModify(e.getMBValue());
+                       ((MatrixObject)outValue).release();
+               }
+               else {
+                       outValue = e.getSOValue();
+               }
+               udfOutputs.put(outName, outValue);
+               
+               udfOutputs.forEach((var, val) -> {
+                       //cleanup existing data bound to output name
+                       Data exdata = ec.removeVariable(var);
+                       if (exdata != val)
+                               ec.cleanupDataObject(exdata);
+                       //add or replace data in the symbol table
+                       ec.setVariable(var, val);
+               });
+
+               if (DMLScript.STATISTICS)
+                       //TODO: dedicated stats for federated reuse
+                       LineageCacheStatistics.incrementInstHits();
+               return true;
+       }
+       
+       //Returns the lineage item of a federated UDF's (single) output, or null if the UDF is not traced
+       private static LineageItem getFedUDFLineage(FederatedUDF udf, ExecutionContext ec) {
+               org.apache.commons.lang3.tuple.Pair<String, LineageItem> udfLineage = udf.getLineageItem(ec);
+               return (udfLineage != null) ? udfLineage.getValue() : null;
+       }
+       
         public static boolean probe(LineageItem key) {
                 //TODO problematic as after probe the matrix might be kicked out of cache
                 boolean p = _cache.containsKey(key);  // in cache or in disk
@@ -384,6 +452,64 @@ public class LineageCache
                 return;
         }
         
+       //Fills the placeholder reserved by reuse(FederatedUDF, ...) with the UDF's output
+       public static void putValue(FederatedUDF udf, ExecutionContext ec, long computetime) 
+       {
+               if (ReuseCacheType.isNone() || udf.getOutputIds() == null)
+                       return;
+
+               List<Long> outIds = udf.getOutputIds();
+               LineageItem item = getFedUDFLineage(udf, ec);
+               if (item == null)
+                       //TODO: trace all UDFs
+                       return;
+               Data data = ec.getVariable(String.valueOf(outIds.get(0)));
+               if (!(data instanceof MatrixObject) && !(data instanceof ScalarObject)) {
+                       // Don't cache if the udf outputs frames
+                       synchronized(_cache) {
+                               _cache.remove(item);
+                       }
+                       return;
+               }
+               
+               //acquire the output before taking the cache lock
+               MatrixBlock mb = (data instanceof MatrixObject) ? 
+                               ((MatrixObject)data).acquireReadAndRelease() : null;
+               long size = mb != null ? mb.getInMemorySize() : ((ScalarObject)data).getSize();
+
+               //reuse and resetCache access _cache only under its monitor; do the same here
+               synchronized(_cache) {
+                       LineageCacheEntry entry = _cache.get(item);
+                       //the placeholder may be gone, e.g., after a concurrent resetCache
+                       if (entry == null)
+                               return;
+
+                       //remove the placeholder if the entry is bigger than the cache.
+                       //FIXME: the resumed threads will enter into infinite wait as the entry
+                       //is removed. Need to add support for graceful remove (placeholder) and resume.
+                       if (size > LineageCacheEviction.getCacheLimit()) {
+                               _cache.remove(item);
+                               return;
+                       }
+
+                       //make space for the data
+                       if (!LineageCacheEviction.isBelowThreshold(size))
+                               LineageCacheEviction.makeSpace(_cache, size);
+                       LineageCacheEviction.updateSize(size, true);
+
+                       //place the data
+                       if (data instanceof MatrixObject)
+                               entry.setValue(mb, computetime);
+                       else
+                               entry.setValue((ScalarObject)data, computetime);
+
+                       //TODO: maintain statistics, lineage estimate
+
+                       //maintain order for eviction
+                       LineageCacheEviction.addEntry(entry);
+               }
+       }
+       
         public static void resetCache() {
                 synchronized (_cache) {
                         _cache.clear();
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 61dfc8e..2532d14 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -43,7 +43,7 @@ public class LineageCacheConfig
                 "uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+", "<=",
                 "^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", "replace",
                 "^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", "qsort", 
-               "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm"
+               "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", "lowertri"
                 //TODO: Reuse everything. 
         };
         private static String[] REUSE_OPCODES  = new String[] {};
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 977b12c..49ed459 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.lineage;
 
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.LocalFileSystem;
 import org.apache.hadoop.fs.Path;
@@ -47,6 +48,7 @@ import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionParser;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -138,6 +140,27 @@ public class LineageItemUtils {
                         .map(c -> ec.getLineage().getOrCreate(c)).toArray(LineageItem[]::new);
         }
         
+       //Traces the lineage of a federated UDF's output(s) into the lineage map of the
+       //given context; UDFs without lineage (getLineageItem returns null) are skipped
+       public static void traceFedUDF(ExecutionContext ec, FederatedUDF udf) {
+               //compute the lineage only once, as every call builds new lineage items
+               Pair<String, LineageItem> item = udf.getLineageItem(ec);
+               if (item == null)
+                       //TODO: trace all UDFs
+                       return;
+
+               if (!(udf instanceof LineageTraceable))
+                       throw new DMLRuntimeException("Unknown Federated UDF (" + udf.getClass().getSimpleName() + ") traced.");
+               LineageTraceable ludf = (LineageTraceable) udf;
+               if (ludf.hasSingleLineage())
+                       ec.getLineage().set(item.getKey(), item.getValue());
+               else {
+                       Pair<String, LineageItem>[] items = udf.getLineageItems(ec);
+                       for (Pair<String, LineageItem> it : items)
+                               ec.getLineage().set(it.getKey(), it.getValue());
+               }
+       }
+       
         public static void constructLineageFromHops(Hop[] roots, String claName, Hop[] inputs, HashMap<Long, Hop> spoofmap) {
                 //probe existence and only generate lineage if non-existing
                 //(a fused operator might be used in multiple places of a program)
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
new file mode 100644
index 0000000..f7f01f6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FedUDFReuseTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "FedUdfReuse1";
+
+       private final static String TEST_DIR = "functions/lineage/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FedUDFReuseTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {20, 20, true}
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+       }
+
+       @Test
+       public void testTriUDFReuseCP() {
+               runTriUDFReuse(ExecMode.SINGLE_NODE);
+       }
+       
+       // runs the lower.tri loop on a federated matrix (workers started with lineage reuse) and compares with a reference run on a normal matrix
+       private void runTriUDFReuse(ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+               String TEST_NAME = TEST_NAME1;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               String[] otherargs = new String[] {"-lineage", "reuse_full"};
+               Lineage.resetInternalState();
+               Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, otherargs, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, otherargs, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4, otherargs);
+
+               // run both scripts in try-finally, such that the federated workers are shut down
+               // and the global test state is restored even if a run or an assertion fails
+               try {
+                       rtplatform = execMode;
+                       if(rtplatform == ExecMode.SPARK)
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-args", 
+                               input("X1"), input("X2"), input("X3"), input("X4"),
+                               Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+                       runTest(null);
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+                               "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+                       runTest(null);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       // check if lowertri is federated
+                       Assert.assertTrue(heavyHittersContainsString("fed_lowertri"));
+                       // assert reuse count
+                       Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);
+               }
+               finally {
+                       TestUtils.shutdownThreads(t1, t2, t3, t4);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}
+
diff --git a/src/test/scripts/functions/lineage/FedUdfReuse1.dml 
b/src/test/scripts/functions/lineage/FedUdfReuse1.dml
new file mode 100644
index 0000000..1867db9
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedUdfReuse1.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {  # row-partitioned: A is assembled from four row blocks X1..X4
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+      list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {  # column-partitioned: A is assembled from four column blocks X1..X4
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols/4), list(0, $cols/4), list($rows, $cols/2),
+      list(0, $cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+for (i in 1:10)  # identical repeated calls so the workers can reuse the lower.tri UDF output
+  s = lower.tri(target=A, diag=FALSE, values=TRUE);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/lineage/FedUdfReuse1Reference.dml b/src/test/scripts/functions/lineage/FedUdfReuse1Reference.dml
new file mode 100644
index 0000000..8175334
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedUdfReuse1Reference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($5) { A = rbind(read($1), read($2), read($3), read($4)); }  # $5 TRUE: inputs are row blocks
+else     { A = cbind(read($1), read($2), read($3), read($4)); }  # $5 FALSE: inputs are column blocks
+
+for (i in 1:10)  # same repeated lower.tri call, evaluated on the locally combined matrix
+  s = lower.tri(target=A, diag=FALSE, values=TRUE);
+
+write(s, $6);

Reply via email to