This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new c108640 [SYSTEMDS-2799] Lineage tracing and reuse of federated UDFs
c108640 is described below
commit c1086403f0f0d0133d51b90df9fb3b262eba64f5
Author: arnabp <[email protected]>
AuthorDate: Tue Jan 12 14:03:39 2021 +0100
[SYSTEMDS-2799] Lineage tracing and reuse of federated UDFs
This patch builds the foundation for lineage tracing and reuse
of federated UDFs, and adds a test that calls lower.tri in a
loop and reuses the outputs of the Tri UDF in the workers.
The core idea is to treat a UDF as an instruction that always
produces the same results for the same inputs.
---
.../controlprogram/federated/FederatedUDF.java | 8 +-
.../federated/FederatedWorkerHandler.java | 26 +-
.../paramserv/FederatedPSControlThread.java | 17 ++
.../paramserv/dp/BalanceToAvgFederatedScheme.java | 7 +
.../dp/ReplicateToMaxFederatedScheme.java | 7 +
.../paramserv/dp/ShuffleFederatedScheme.java | 7 +
.../dp/SubsampleToMinFederatedScheme.java | 7 +
.../cp/ParameterizedBuiltinCPInstruction.java | 8 +
.../fed/CentralMomentFEDInstruction.java | 297 +++++++++++----------
...tiReturnParameterizedBuiltinFEDInstruction.java | 12 +
.../fed/ParameterizedBuiltinFEDInstruction.java | 50 ++++
.../fed/QuantilePickFEDInstruction.java | 15 ++
.../fed/QuantileSortFEDInstruction.java | 6 +
.../instructions/fed/ReorgFEDInstruction.java | 12 +
.../apache/sysds/runtime/lineage/LineageCache.java | 111 ++++++++
.../sysds/runtime/lineage/LineageCacheConfig.java | 2 +-
.../sysds/runtime/lineage/LineageItemUtils.java | 21 ++
.../test/functions/lineage/FedUDFReuseTest.java | 155 +++++++++++
.../scripts/functions/lineage/FedUdfReuse1.dml | 34 +++
.../functions/lineage/FedUdfReuse1Reference.dml | 28 ++
20 files changed, 683 insertions(+), 147 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
index 5423ffa..f42b39b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
@@ -20,11 +20,13 @@
package org.apache.sysds.runtime.controlprogram.federated;
import java.io.Serializable;
+import java.util.List;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
-public abstract class FederatedUDF implements Serializable {
+public abstract class FederatedUDF implements Serializable, LineageTraceable {
private static final long serialVersionUID = 799416525191257308L;
private final long[] _inputIDs;
@@ -36,6 +38,10 @@ public abstract class FederatedUDF implements Serializable {
public final long[] getInputIDs() {
return _inputIDs;
}
+
+ public List<Long> getOutputIds() {
+ return null;
+ }
/**
* Execute the user-defined function on a set of data objects
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 4a69a10..5b574c1 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -48,8 +48,11 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
@@ -271,7 +274,6 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
// set variable and construct empty response
ec.setVariable(varname, data);
if (DMLScript.LINEAGE)
- // TODO: Identify MO uniquely. Use Adler32 checksum.
ec.getLineage().set(varname, new
LineageItem(String.valueOf(request.getChecksum(0))));
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
@@ -340,10 +342,24 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
FederatedUDF udf = (FederatedUDF) request.getParam(0);
Data[] inputs = Arrays.stream(udf.getInputIDs()).mapToObj(id ->
ec.getVariable(String.valueOf(id)))
.map(PrivacyMonitor::handlePrivacy).toArray(Data[]::new);
-
- // execute user-defined function
+
+ // trace lineage
+ if (DMLScript.LINEAGE)
+ LineageItemUtils.traceFedUDF(ec, udf);
+
+ // reuse or execute user-defined function
try {
- return udf.execute(ec, inputs);
+ // reuse UDF outputs if available in lineage cache
+ if (LineageCache.reuse(udf, ec))
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
+
+ // else execute the UDF
+ long t0 = !ReuseCacheType.isNone() ? System.nanoTime()
: 0;
+ FederatedResponse res = udf.execute(ec, inputs);
+ long t1 = !ReuseCacheType.isNone() ? System.nanoTime()
: 0;
+ //cacheUDFOutputs(udf, inputs, t1-t0, ec);
+ LineageCache.putValue(udf, ec, t1-t0);
+ return res;
}
catch(DMLPrivacyException | FederatedWorkerHandlerException ex){
throw ex;
@@ -354,7 +370,7 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
return new FederatedResponse(ResponseType.ERROR, new
FederatedWorkerHandlerException(msg));
}
}
-
+
private FederatedResponse execClear() {
try {
_ecm.clear();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 48249db..98bc91a 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.controlprogram.paramserv;
import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataIdentifier;
@@ -50,6 +51,7 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.util.ProgramConverter;
import java.util.ArrayList;
@@ -224,6 +226,11 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
/**
@@ -271,6 +278,11 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
/**
@@ -531,6 +543,11 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
// Statistics methods
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index 34e94f0..e3daf60 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -115,5 +117,10 @@ public class BalanceToAvgFederatedScheme extends
DataPartitionFederatedScheme {
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index a1b8f6c..e9c1b50 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -113,5 +115,10 @@ public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 1920593..365554d 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import java.util.List;
@@ -95,5 +97,10 @@ public class ShuffleFederatedScheme extends
DataPartitionFederatedScheme {
shuffle(labels, permutationMatrixBlock);
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
}
\ No newline at end of file
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index 937c37e..e55b92e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -28,6 +29,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -112,5 +114,10 @@ public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index e87ee1e..58bc7b1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -444,6 +444,14 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
return Pair.of(output.getName(), new
LineageItem(getOpcode(),
LineageItemUtils.getLineage(ec, target, max,
dir, cast, ignore)));
}
+ else if (opcode.equalsIgnoreCase("lowertri") ||
opcode.equalsIgnoreCase("uppertri")) {
+ CPOperand target = getTargetOperand();
+ CPOperand lower = getBoolLiteral("lowertri");
+ CPOperand diag = getBoolLiteral("diag");
+ CPOperand values = getBoolLiteral("values");
+ return Pair.of(output.getName(), new
LineageItem(getOpcode(),
+ LineageItemUtils.getLineage(ec, target, lower,
diag, values)));
+ }
else if (opcode.equalsIgnoreCase("transformdecode") ||
opcode.equalsIgnoreCase("transformapply")) {
CPOperand target = new CPOperand(params.get("target"),
ValueType.FP64, DataType.FRAME);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
index cc8e683..4bd522f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -39,149 +40,165 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
public class CentralMomentFEDInstruction extends AggregateUnaryFEDInstruction {
- private CentralMomentFEDInstruction(CMOperator cm, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
- String opcode, String str) {
- super(cm, in1, in2, in3, out, opcode, str);
- }
-
- public static CentralMomentFEDInstruction parseInstruction(String str) {
- CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
- CPOperand in2 = null;
- CPOperand in3 = null;
- CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
-
- String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- String opcode = parts[0];
-
- //check supported opcode
- if( !opcode.equalsIgnoreCase("cm") ) {
- throw new DMLRuntimeException("Unsupported opcode "+opcode);
- }
-
- if ( parts.length == 4 ) {
- // Example: CP.cm.mVar0.Var1.mVar2; (without weights)
- in2 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
- parseUnaryInstruction(str, in1, in2, out);
- }
- else if ( parts.length == 5) {
- // CP.cm.mVar0.mVar1.Var2.mVar3; (with weights)
- in2 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
- in3 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
- parseUnaryInstruction(str, in1, in2, in3, out);
- }
-
- /*
- * Exact order of the central moment MAY NOT be known at compilation
time.
- * We first try to parse the second argument as an integer, and if we
fail,
- * we simply pass -1 so that getCMAggOpType() picks up
AggregateOperationTypes.INVALID.
- * It must be updated at run time in processInstruction() method.
- */
-
- int cmOrder;
- try {
- if ( in3 == null ) {
- cmOrder = Integer.parseInt(in2.getName());
- }
- else {
- cmOrder = Integer.parseInt(in3.getName());
- }
- } catch(NumberFormatException e) {
- cmOrder = -1; // unknown at compilation time
- }
-
- CMOperator.AggregateOperationTypes opType =
CMOperator.getCMAggOpType(cmOrder);
- CMOperator cm = new CMOperator(CM.getCMFnObject(opType), opType);
- return new CentralMomentFEDInstruction(cm, in1, in2, in3, out, opcode,
str);
- }
-
- @Override
- public void processInstruction( ExecutionContext ec ) {
- MatrixObject mo = ec.getMatrixObject(input1.getName());
- ScalarObject order = ec.getScalarInput(input3==null ? input2 : input3);
-
- CMOperator cm_op = ((CMOperator) _optr);
- if(cm_op.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID)
- cm_op = cm_op.setCMAggOp((int) order.getLongValue());
-
- FederationMap fedMapping = mo.getFedMapping();
- List<CM_COV_Object> globalCmobj = new ArrayList<>();
-
- long varID = FederationUtils.getNextFedDataID();
- CMOperator finalCm_op = cm_op;
- fedMapping.mapParallel(varID, (range, data) -> {
-
- FederatedResponse response;
- try {
- if (input3 == null ) {
- response = data.executeFederatedOperation(
- new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new
CentralMomentFEDInstruction.CMFunction(data.getVarID(), finalCm_op))).get();
- } else {
- MatrixBlock wtBlock = ec.getMatrixInput(input2.getName());
-
- response = data.executeFederatedOperation(
- new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new
CentralMomentFEDInstruction.CMWeightsFunction(data.getVarID(), finalCm_op,
wtBlock))).get();
- }
-
- if(!response.isSuccessful())
- response.throwExceptionFromResponse();
- synchronized(globalCmobj) {
- globalCmobj.add((CM_COV_Object) response.getData()[0]);
- }
- }
- catch(Exception e) {
- throw new DMLRuntimeException(e);
- }
- return null;
- });
-
- Optional<CM_COV_Object> res = globalCmobj.stream().reduce((arg0, arg1)
-> (CM_COV_Object) finalCm_op.fn.execute(arg0, arg1));
- try {
- ec.setScalarOutput(output.getName(), new
DoubleObject(res.get().getRequiredResult(finalCm_op)));
- }
- catch(Exception e) {
- throw new DMLRuntimeException(e);
- }
- }
-
- private static class CMFunction extends FederatedUDF {
- private static final long serialVersionUID = 7460149207607220994L;
- private final CMOperator _op;
-
- public CMFunction (long input, CMOperator op) {
- super(new long[] {input});
- _op = op;
- }
-
- @Override
- public FederatedResponse execute(ExecutionContext ec, Data... data) {
- MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
- return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op));
- }
- }
-
-
- private static class CMWeightsFunction extends FederatedUDF {
- private static final long serialVersionUID = -3685746246551622021L;
- private final CMOperator _op;
- private final MatrixBlock _weights;
-
- protected CMWeightsFunction(long input, CMOperator op, MatrixBlock
weights) {
- super(new long[] {input});
- _op = op;
- _weights = weights;
- }
-
- @Override
- public FederatedResponse execute(ExecutionContext ec, Data... data) {
- MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
- return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op,
_weights));
- }
- }
+ private CentralMomentFEDInstruction(CMOperator cm, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
+ String opcode, String str) {
+ super(cm, in1, in2, in3, out, opcode, str);
+ }
+
+ public static CentralMomentFEDInstruction parseInstruction(String str) {
+ CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
+ CPOperand in2 = null;
+ CPOperand in3 = null;
+ CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
+
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ // check supported opcode
+ if (!opcode.equalsIgnoreCase("cm")) {
+ throw new DMLRuntimeException("Unsupported opcode " +
opcode);
+ }
+
+ if (parts.length == 4) {
+ // Example: CP.cm.mVar0.Var1.mVar2; (without weights)
+ in2 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
+ parseUnaryInstruction(str, in1, in2, out);
+ }
+ else if (parts.length == 5) {
+ // CP.cm.mVar0.mVar1.Var2.mVar3; (with weights)
+ in2 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
+ in3 = new CPOperand("", Types.ValueType.UNKNOWN,
Types.DataType.UNKNOWN);
+ parseUnaryInstruction(str, in1, in2, in3, out);
+ }
+
+ /*
+ * Exact order of the central moment MAY NOT be known at
compilation time. We
+ * first try to parse the second argument as an integer, and if
we fail, we
+ * simply pass -1 so that getCMAggOpType() picks up
+ * AggregateOperationTypes.INVALID. It must be updated at run
time in
+ * processInstruction() method.
+ */
+
+ int cmOrder;
+ try {
+ if (in3 == null) {
+ cmOrder = Integer.parseInt(in2.getName());
+ }
+ else {
+ cmOrder = Integer.parseInt(in3.getName());
+ }
+ }
+ catch (NumberFormatException e) {
+ cmOrder = -1; // unknown at compilation time
+ }
+
+ CMOperator.AggregateOperationTypes opType =
CMOperator.getCMAggOpType(cmOrder);
+ CMOperator cm = new CMOperator(CM.getCMFnObject(opType),
opType);
+ return new CentralMomentFEDInstruction(cm, in1, in2, in3, out,
opcode, str);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject mo = ec.getMatrixObject(input1.getName());
+ ScalarObject order = ec.getScalarInput(input3 == null ? input2
: input3);
+
+ CMOperator cm_op = ((CMOperator) _optr);
+ if (cm_op.getAggOpType() ==
CMOperator.AggregateOperationTypes.INVALID)
+ cm_op = cm_op.setCMAggOp((int) order.getLongValue());
+
+ FederationMap fedMapping = mo.getFedMapping();
+ List<CM_COV_Object> globalCmobj = new ArrayList<>();
+
+ long varID = FederationUtils.getNextFedDataID();
+ CMOperator finalCm_op = cm_op;
+ fedMapping.mapParallel(varID, (range, data) -> {
+
+ FederatedResponse response;
+ try {
+ if (input3 == null) {
+ response = data
+
.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new
CentralMomentFEDInstruction.CMFunction(data.getVarID(), finalCm_op)))
+ .get();
+ }
+ else {
+ MatrixBlock wtBlock =
ec.getMatrixInput(input2.getName());
+
+ response =
data.executeFederatedOperation(new FederatedRequest(
+
FederatedRequest.RequestType.EXEC_UDF, -1,
+ new
CentralMomentFEDInstruction.CMWeightsFunction(data.getVarID(), finalCm_op,
wtBlock)))
+ .get();
+ }
+
+ if (!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ synchronized (globalCmobj) {
+ globalCmobj.add((CM_COV_Object)
response.getData()[0]);
+ }
+ }
+ catch (Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+ Optional<CM_COV_Object> res = globalCmobj.stream()
+ .reduce((arg0, arg1) -> (CM_COV_Object)
finalCm_op.fn.execute(arg0, arg1));
+ try {
+ ec.setScalarOutput(output.getName(), new
DoubleObject(res.get().getRequiredResult(finalCm_op)));
+ }
+ catch (Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ }
+
+ private static class CMFunction extends FederatedUDF {
+ private static final long serialVersionUID =
7460149207607220994L;
+ private final CMOperator _op;
+
+ public CMFunction(long input, CMOperator op) {
+ super(new long[] {input});
+ _op = op;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op));
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
+ private static class CMWeightsFunction extends FederatedUDF {
+ private static final long serialVersionUID =
-3685746246551622021L;
+ private final CMOperator _op;
+ private final MatrixBlock _weights;
+
+ protected CMWeightsFunction(long input, CMOperator op,
MatrixBlock weights) {
+ super(new long[] {input});
+ _op = op;
+ _weights = weights;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op,
_weights));
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index d02d0f5..834d11d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Future;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -38,6 +39,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -212,6 +214,11 @@ public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFE
// create federated response
return new FederatedResponse(ResponseType.SUCCESS, new
Object[] {encoder, fb.getColumnNames()});
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
public static class ExecuteFrameEncoder extends FederatedUDF {
@@ -242,5 +249,10 @@ public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFE
// return id handle
return new
FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index a2b63e9..cff8bb0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -25,7 +25,9 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.stream.Stream;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
@@ -48,6 +50,8 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -242,6 +246,27 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
return new
FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
+
+ @Override
+ public List<Long> getOutputIds() {
+ return new ArrayList<>(Arrays.asList(_outputID));
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ LineageItem[] liUdfInputs = Arrays.stream(getInputIDs())
+ .mapToObj(id ->
ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
+ CPOperand slice = new
CPOperand(Arrays.toString(_slice), ValueType.STRING, DataType.SCALAR, true);
+ CPOperand rowFed = new
CPOperand(String.valueOf(_rowFed), ValueType.BOOLEAN, DataType.SCALAR, true);
+ CPOperand lower = new CPOperand(String.valueOf(_lower),
ValueType.BOOLEAN, DataType.SCALAR, true);
+ CPOperand diag = new CPOperand(String.valueOf(_diag),
ValueType.BOOLEAN, DataType.SCALAR, true);
+ CPOperand values = new
CPOperand(String.valueOf(_values), ValueType.BOOLEAN, DataType.SCALAR, true);
+ LineageItem[] otherInputs =
LineageItemUtils.getLineage(ec, slice, rowFed, lower, diag, values);
+ LineageItem[] liInputs =
Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs))
+ .toArray(LineageItem[]::new);
+ return Pair.of(String.valueOf(_outputID),
+ new
LineageItem(getClass().getSimpleName(), liInputs));
+ }
}
private void rmempty(ExecutionContext ec) {
@@ -561,6 +586,11 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
// return schema
return new FederatedResponse(ResponseType.SUCCESS, new
Object[] {fo.getSchema()});
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
private static class GetColumnNames extends FederatedUDF {
@@ -576,6 +606,11 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
// return column names
return new FederatedResponse(ResponseType.SUCCESS, new
Object[] {fb.getColumnNames()});
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
}
private static class InitRowsToRemoveOmit extends FederatedUDF {
@@ -594,6 +629,11 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
_encoder.build(fb);
return new FederatedResponse(ResponseType.SUCCESS, new
Object[] {_encoder});
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
private static class GetDataCharacteristics extends FederatedUDF {
@@ -611,6 +651,11 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
int c = mb.getDenseBlockValues() != null ?
mb.getNumColumns(): 0;
return new FederatedResponse(ResponseType.SUCCESS, new
int[] {r, c});
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
private static class GetVector extends FederatedUDF {
@@ -640,5 +685,10 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
tmp1 = tmp1.binaryOperationsInPlace(greater, new
MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
return new FederatedResponse(ResponseType.SUCCESS,
tmp1);
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index f2052b5..1d9cbdd 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
import java.util.List;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.lops.PickByCount.OperationTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -36,6 +37,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -176,6 +178,11 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
new Object[] {picked});
}
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
private static class IQM extends FederatedUDF {
@@ -191,6 +198,10 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
new Object[] {mb.interQuartileMean()});
}
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
private static class Median extends FederatedUDF {
@@ -206,5 +217,9 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
new Object[] {mb.median()});
}
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index 0e994d9..c1aa37d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -32,6 +33,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
@@ -159,5 +161,9 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction{
// return schema
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
}
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index a033769..65e6e97 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.HashMap;
import java.util.Map;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -38,6 +39,7 @@ import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
@@ -263,6 +265,11 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new
int[]{res.getNumRows(), res.getNumColumns()});
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
private static class DiagMatrix extends FederatedUDF {
@@ -302,5 +309,10 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new
int[]{res.getNumRows(), res.getNumColumns()});
}
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ // Not traced yet (TODO: trace all UDFs): a null lineage item makes the
+ // federated lineage tracing and reuse paths skip this UDF.
+ return null;
+ }
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 287f9aa..0f9b6ed 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -32,6 +32,7 @@ import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.instructions.CPInstructionParser;
import org.apache.sysds.runtime.instructions.Instruction;
@@ -234,6 +235,69 @@ public class LineageCache
return reuse;
}
+	/**
+	 * Reuse of federated UDFs: probes the lineage cache for the output of the
+	 * given UDF.
+	 * On a hit, the cached value is bound to the UDF's first output ID in the
+	 * symbol table. On a miss, a placeholder entry is reserved for the UDF's
+	 * lineage item.
+	 *
+	 * @param udf federated UDF whose output is looked up
+	 * @param ec  execution context providing the lineage and the symbol table
+	 * @return true if a cached value was bound to the output variable
+	 */
+	public static boolean reuse(FederatedUDF udf, ExecutionContext ec)
+	{
+		if (ReuseCacheType.isNone())
+			return false;
+		List<Long> outIds = udf.getOutputIds();
+		//TODO: support multi-return UDFs (only the first output ID is bound)
+		if (outIds == null || outIds.isEmpty())
+			return false;
+		//TODO: reuse only those UDFs which are part of reusable instructions
+
+		//build the lineage key once; every call constructs a new LineageItem
+		org.apache.commons.lang3.tuple.Pair<String, LineageItem> lineage = udf.getLineageItem(ec);
+		if (lineage == null)
+			//TODO: trace all UDFs
+			return false;
+		LineageItem li = lineage.getValue();
+		li.setDistLeaf2Node(1); //to save from early eviction
+
+		LineageCacheEntry e = null;
+		synchronized(_cache) {
+			if (probe(li))
+				e = LineageCache.getIntern(li);
+			else
+				//reserve a placeholder; for now allow only matrix blocks
+				putIntern(li, DataType.MATRIX, null, null, 0);
+		}
+		if (e == null)
+			return false;
+
+		String outName = String.valueOf(outIds.get(0));
+		Data outValue;
+		if (e.isMatrixValue()) {
+			//convert the cached matrix block into a matrix object
+			MatrixBlock mb = e.getMBValue();
+			MetaDataFormat md = new MetaDataFormat(mb.getDataCharacteristics(), FileFormat.BINARY);
+			MatrixObject mo = new MatrixObject(ValueType.FP64, outName, md);
+			mo.acquireModify(mb);
+			mo.release();
+			outValue = mo;
+		}
+		else
+			outValue = e.getSOValue();
+
+		//cleanup existing data bound to output name
+		Data exdata = ec.removeVariable(outName);
+		if (exdata != outValue)
+			ec.cleanupDataObject(exdata);
+		//add or replace data in the symbol table
+		ec.setVariable(outName, outValue);
+
+		if (DMLScript.STATISTICS)
+			//TODO: dedicated stats for federated reuse
+			LineageCacheStatistics.incrementInstHits();
+		return true;
+	}
+
public static boolean probe(LineageItem key) {
//TODO problematic as after probe the matrix might be kicked
out of cache
boolean p = _cache.containsKey(key); // in cache or in disk
@@ -384,6 +448,53 @@ public class LineageCache
return;
}
+	/**
+	 * Fills the lineage cache entry of a federated UDF with its computed output.
+	 * The output is read from the symbol table under the UDF's first output ID.
+	 * Only matrix and scalar outputs are cached. Otherwise, or if the value
+	 * exceeds the cache limit, the entry is removed.
+	 *
+	 * @param udf         federated UDF that produced the output
+	 * @param ec          execution context holding the output variable
+	 * @param computetime compute time stored with the cache entry
+	 */
+	public static void putValue(FederatedUDF udf, ExecutionContext ec, long computetime)
+	{
+		if (ReuseCacheType.isNone())
+			return;
+		List<Long> outIds = udf.getOutputIds();
+		if (outIds == null || outIds.isEmpty())
+			return;
+		//build the lineage key once; every call constructs a new LineageItem
+		org.apache.commons.lang3.tuple.Pair<String, LineageItem> lineage = udf.getLineageItem(ec);
+		if (lineage == null)
+			//TODO: trace all UDFs
+			return;
+		LineageItem item = lineage.getValue();
+		Data data = ec.getVariable(String.valueOf(outIds.get(0)));
+
+		//reuse() probes and reserves entries under this lock; update under it too
+		synchronized (_cache) {
+			LineageCacheEntry entry = _cache.get(item);
+			if (entry == null)
+				//no entry reserved for this item (e.g. no preceding reuse()); nothing to fill
+				return;
+			if (!(data instanceof MatrixObject) && !(data instanceof ScalarObject)) {
+				// Don't cache if the udf outputs frames
+				_cache.remove(item);
+				return;
+			}
+
+			MatrixBlock mb = (data instanceof MatrixObject) ?
+				((MatrixObject)data).acquireReadAndRelease() : null;
+			long size = mb != null ? mb.getInMemorySize() : ((ScalarObject)data).getSize();
+
+			//remove the placeholder if the entry is bigger than the cache.
+			//FIXME: the resumed threads will enter into infinite wait as the entry
+			//is removed. Need to add support for graceful remove (placeholder) and resume.
+			if (size > LineageCacheEviction.getCacheLimit()) {
+				_cache.remove(item);
+				return;
+			}
+
+			//make space for the data
+			if (!LineageCacheEviction.isBelowThreshold(size))
+				LineageCacheEviction.makeSpace(_cache, size);
+			LineageCacheEviction.updateSize(size, true);
+
+			//place the data
+			if (data instanceof MatrixObject)
+				entry.setValue(mb, computetime);
+			else
+				entry.setValue((ScalarObject)data, computetime);
+
+			//TODO: maintain statistics, lineage estimate
+
+			//maintain order for eviction
+			LineageCacheEviction.addEntry(entry);
+		}
+	}
+
public static void resetCache() {
synchronized (_cache) {
_cache.clear();
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 61dfc8e..2532d14 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -43,7 +43,7 @@ public class LineageCacheConfig
"uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+",
"<=",
"^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand",
"replace",
"^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax",
"qsort",
- "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm"
+ "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm",
"lowertri"
//TODO: Reuse everything.
};
private static String[] REUSE_OPCODES = new String[] {};
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 977b12c..49ed459 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.lineage;
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.LocalFileSystem;
import org.apache.hadoop.fs.Path;
@@ -47,6 +48,7 @@ import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -138,6 +140,25 @@ public class LineageItemUtils {
.map(c ->
ec.getLineage().getOrCreate(c)).toArray(LineageItem[]::new);
}
+	/**
+	 * Records the lineage of a federated UDF's output(s) in the context's lineage map.
+	 * UDFs whose getLineageItem returns null are not traced yet and are skipped.
+	 *
+	 * @param ec  execution context whose lineage map is updated
+	 * @param udf federated UDF to trace
+	 * @throws DMLRuntimeException if the UDF is not LineageTraceable
+	 */
+	public static void traceFedUDF(ExecutionContext ec, FederatedUDF udf) {
+		//build the lineage once; every call constructs a new LineageItem
+		Pair<String, LineageItem> single = udf.getLineageItem(ec);
+		if (single == null)
+			//TODO: trace all UDFs
+			return;
+
+		if (!(udf instanceof LineageTraceable))
+			throw new DMLRuntimeException("Unknown Federated UDF (" + udf.getClass().getSimpleName() + ") traced.");
+		if (((LineageTraceable) udf).hasSingleLineage()) {
+			ec.getLineage().set(single.getKey(), single.getValue());
+		}
+		else {
+			for (Pair<String, LineageItem> item : udf.getLineageItems(ec))
+				ec.getLineage().set(item.getKey(), item.getValue());
+		}
+	}
+
public static void constructLineageFromHops(Hop[] roots, String
claName, Hop[] inputs, HashMap<Long, Hop> spoofmap) {
//probe existence and only generate lineage if non-existing
//(a fused operator might be used in multiple places of a
program)
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
new file mode 100644
index 0000000..f7f01f6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FedUDFReuseTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "FedUdfReuse1";
+
+	private final static String TEST_DIR = "functions/lineage/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FedUDFReuseTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Parameterized.Parameter(2)
+	public boolean rowPartitioned;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][] {
+			{20, 20, true}
+		});
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+	}
+
+	@Test
+	public void testTriUDFReuseCP() {
+		runTriUDFReuse(ExecMode.SINGLE_NODE);
+	}
+
+	/**
+	 * Runs lower.tri in a loop twice: first on a locally combined reference
+	 * matrix, then on a federated matrix over four in-process workers started
+	 * with lineage reuse.
+	 * It compares both results, and asserts that fed_lowertri was used and
+	 * that lineage cache hits were recorded.
+	 * Workers and global execution settings are restored even on failure.
+	 *
+	 * @param execMode execution mode used for both script runs
+	 */
+	private void runTriUDFReuse(ExecMode execMode) {
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		ExecMode platformOld = rtplatform;
+		String TEST_NAME = TEST_NAME1;
+
+		try {
+			if(rtplatform == ExecMode.SPARK)
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			// write input matrices, one partition per worker
+			int r = rows;
+			int c = cols / 4;
+			if(rowPartitioned) {
+				r = rows / 4;
+				c = cols;
+			}
+
+			double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+			double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+			double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+			double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+			MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+			writeInputMatrixWithMTD("X1", X1, false, mc);
+			writeInputMatrixWithMTD("X2", X2, false, mc);
+			writeInputMatrixWithMTD("X3", X3, false, mc);
+			writeInputMatrixWithMTD("X4", X4, false, mc);
+
+			// empty script name because we don't execute any script, just start the worker
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			int port3 = getRandomAvailablePort();
+			int port4 = getRandomAvailablePort();
+			String[] otherargs = new String[] {"-lineage", "reuse_full"};
+			Lineage.resetInternalState();
+			Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
+			Thread t2 = startLocalFedWorkerThread(port2, otherargs, FED_WORKER_WAIT_S);
+			Thread t3 = startLocalFedWorkerThread(port3, otherargs, FED_WORKER_WAIT_S);
+			Thread t4 = startLocalFedWorkerThread(port4, otherargs);
+
+			try {
+				rtplatform = execMode;
+				if(rtplatform == ExecMode.SPARK)
+					DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+				TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+				loadTestConfiguration(config);
+
+				// Run reference dml script with normal matrix
+				fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+				programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-args",
+					input("X1"), input("X2"), input("X3"), input("X4"),
+					Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+				runTest(null);
+
+				// Run actual dml script with federated matrix
+				fullDMLScriptName = HOME + TEST_NAME + ".dml";
+				programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-nvargs",
+					"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+					"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+					"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+					"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+					"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+				runTest(null);
+
+				// compare via files
+				compareResults(1e-9);
+				// check if lowertri is federated
+				Assert.assertTrue(heavyHittersContainsString("fed_lowertri"));
+				// assert reuse count
+				// NOTE(review): lineage cache statistics are static and shared with the
+				// in-process workers and the reference run (which also uses reuse_full);
+				// confirm they are reset before the federated run so this checks UDF reuse.
+				Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);
+			}
+			finally {
+				// stop the workers even if a run or an assertion fails
+				TestUtils.shutdownThreads(t1, t2, t3, t4);
+			}
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
+
diff --git a/src/test/scripts/functions/lineage/FedUdfReuse1.dml
b/src/test/scripts/functions/lineage/FedUdfReuse1.dml
new file mode 100644
index 0000000..1867db9
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedUdfReuse1.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Build a federated matrix A from the four worker inputs. With $rP TRUE, each
+# worker holds a block of $rows/4 rows; otherwise each holds $cols/4 columns.
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ } else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ }
+
+# lower.tri is recomputed on the same input in every iteration on purpose, so
+# that the workers can reuse the cached outputs of the federated Tri UDF.
+for (i in 1:10)
+ s = lower.tri(target=A, diag=FALSE, values=TRUE);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/lineage/FedUdfReuse1Reference.dml
b/src/test/scripts/functions/lineage/FedUdfReuse1Reference.dml
new file mode 100644
index 0000000..8175334
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedUdfReuse1Reference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Local reference: combine the four inputs with rbind if $5 is TRUE
+# (row partitioning), otherwise with cbind.
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4));}
+
+# same loop as the federated script; the result is written to $6 for comparison
+for (i in 1:10)
+ s = lower.tri(target=A, diag=FALSE, values=TRUE);
+
+write(s, $6);