This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 71385e1 [SYSTEMDS-2799] Lineage support for Federated UDFs
71385e1 is described below
commit 71385e1556a97e268acdb238709fe86bed5fe8af
Author: arnabp <[email protected]>
AuthorDate: Thu Jan 21 14:05:44 2021 +0100
[SYSTEMDS-2799] Lineage support for Federated UDFs
This patch extends the lineage-tracing foundation to support all the UDFs
that are used by federated instructions. A few multi-phase
instructions use UDFs to generate intermediate (per-phase) results that
are not saved in the symbol table. We do not trace those
UDFs, as they are not part of the lineage DAGs.
---
.../controlprogram/federated/FederatedRequest.java | 4 +-
.../federated/FederatedWorkerHandler.java | 1 +
...tiReturnParameterizedBuiltinFEDInstruction.java | 27 ++++++++-
.../fed/ParameterizedBuiltinFEDInstruction.java | 39 +++++++++++-
.../instructions/fed/ReorgFEDInstruction.java | 40 ++++++++++++-
.../TransformFederatedEncodeApplyTest.java | 69 ++++++++++++++--------
.../transform/TransformFederatedEncodeApply.dml | 3 +-
7 files changed, 153 insertions(+), 30 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 05d7e57..00d0ac5 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -75,7 +75,7 @@ public class FederatedRequest implements Serializable {
_id = id;
_data = data;
setCheckPrivacy();
- if (DMLScript.LINEAGE)
+ if (DMLScript.LINEAGE && method == RequestType.PUT_VAR)
setChecksum();
}
@@ -161,7 +161,7 @@ public class FederatedRequest implements Serializable {
if (ob instanceof CacheBlock) {
try {
CacheBlock cb = (CacheBlock)ob;
- long cbsize =
LazyWriteBuffer.getCacheBlockSize((CacheBlock)ob);
+ long cbsize =
LazyWriteBuffer.getCacheBlockSize(cb);
DataOutput dout = new
CacheDataOutput(new byte[(int)cbsize]);
cb.write(dout);
byte bytes[] = ((CacheDataOutput)
dout).getBytes();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 5b574c1..d536830 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -352,6 +352,7 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
// reuse UDF outputs if available in lineage cache
if (LineageCache.reuse(udf, ec))
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
+ //FIXME: few UDFs (e.g. Rdiag, DiagMatrix)
return additional data with response
// else execute the UDF
long t0 = !ReuseCacheType.isNone() ? System.nanoTime()
: 0;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 834d11d..ca0b066 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -21,10 +21,17 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.Future;
+import java.util.stream.Stream;
+import java.util.zip.Adler32;
+import java.util.zip.Checksum;
+import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -40,6 +47,7 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -251,8 +259,25 @@ public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFE
}
@Override
+ public List<Long> getOutputIds() {
+ return new ArrayList<>(Arrays.asList(_outputID));
+ }
+
+ @Override
public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
- return null;
+ LineageItem[] liUdfInputs = Arrays.stream(getInputIDs())
+ .mapToObj(id ->
ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
+ // calculate checksum for the encoder
+ Checksum checksum = new Adler32();
+ byte bytes[] = SerializationUtils.serialize(_encoder);
+ checksum.update(bytes, 0, bytes.length);
+ CPOperand encoder = new
CPOperand(String.valueOf(checksum.getValue()),
+ ValueType.INT64, DataType.SCALAR, true);
+ LineageItem[] otherInputs =
LineageItemUtils.getLineage(ec, encoder);
+ LineageItem[] liInputs =
Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs))
+ .toArray(LineageItem[]::new);
+ return Pair.of(String.valueOf(_outputID),
+ new
LineageItem(getClass().getSimpleName(), liInputs));
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index cff8bb0..f870041 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.io.DataOutput;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -26,7 +28,10 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
+import java.util.zip.Adler32;
+import java.util.zip.Checksum;
+import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
@@ -34,8 +39,10 @@ import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheDataOutput;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
@@ -588,8 +595,38 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
}
@Override
+ public List<Long> getOutputIds() {
+ return new ArrayList<>(Arrays.asList(_outputID));
+ }
+
+ @Override
public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
- return null;
+ LineageItem[] liUdfInputs = Arrays.stream(getInputIDs())
+ .mapToObj(id ->
ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
+ // calculate checksums for meta and decoder
+ Checksum checksum = new Adler32();
+ try {
+ long cbsize =
LazyWriteBuffer.getCacheBlockSize(_meta);
+ DataOutput fout = new CacheDataOutput(new
byte[(int)cbsize]);
+ _meta.write(fout);
+ byte bytes[] = ((CacheDataOutput)
fout).getBytes();
+ checksum.update(bytes, 0, bytes.length);
+ }
+ catch (IOException e) {
+ throw new DMLRuntimeException("Failed to
serialize cache block.");
+ }
+ CPOperand meta = new
CPOperand(String.valueOf(checksum.getValue()),
+ ValueType.INT64, DataType.SCALAR, true);
+ checksum.reset();
+ byte bytes[] = SerializationUtils.serialize(_decoder);
+ checksum.update(bytes, 0, bytes.length);
+ CPOperand decoder = new
CPOperand(String.valueOf(checksum.getValue()),
+ ValueType.INT64, DataType.SCALAR, true);
+ LineageItem[] otherInputs =
LineageItemUtils.getLineage(ec, meta, decoder);
+ LineageItem[] liInputs =
Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs))
+ .toArray(LineageItem[]::new);
+ return Pair.of(String.valueOf(_outputID),
+ new
LineageItem(getClass().getSimpleName(), liInputs));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index 65e6e97..d0b06e0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -19,11 +19,17 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -40,6 +46,7 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
@@ -267,8 +274,22 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
}
@Override
+ public List<Long> getOutputIds() {
+ return new ArrayList<>(Arrays.asList(_outputID));
+ }
+
+ @Override
public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
- return null;
+ LineageItem[] liUdfInputs = Arrays.stream(getInputIDs())
+ .mapToObj(id ->
ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
+ CPOperand r_op = new
CPOperand(_r_op.fn.getClass().getSimpleName(), ValueType.STRING,
DataType.SCALAR, true);
+ CPOperand slice = new
CPOperand(Arrays.toString(_slice), ValueType.STRING, DataType.SCALAR, true);
+ CPOperand rowFed = new
CPOperand(String.valueOf(_rowFed), ValueType.BOOLEAN, DataType.SCALAR, true);
+ LineageItem[] otherInputs =
LineageItemUtils.getLineage(ec, r_op, slice, rowFed);
+ LineageItem[] liInputs =
Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs))
+ .toArray(LineageItem[]::new);
+ return Pair.of(String.valueOf(_outputID),
+ new
LineageItem(getClass().getSimpleName(), liInputs));
}
}
@@ -311,8 +332,23 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
}
@Override
+ public List<Long> getOutputIds() {
+ return new ArrayList<>(Arrays.asList(_outputID));
+ }
+
+ @Override
public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
- return null;
+ LineageItem[] liUdfInputs = Arrays.stream(getInputIDs())
+ .mapToObj(id ->
ec.getLineage().get(String.valueOf(id))).toArray(LineageItem[]::new);
+ CPOperand r_op = new
CPOperand(_r_op.fn.getClass().getSimpleName(), ValueType.STRING,
DataType.SCALAR, true);
+ CPOperand len = new CPOperand(String.valueOf(_len),
ValueType.INT32, DataType.SCALAR, true);
+ CPOperand slice = new
CPOperand(Arrays.toString(_slice), ValueType.STRING, DataType.SCALAR, true);
+ CPOperand rowFed = new
CPOperand(String.valueOf(_rowFed), ValueType.BOOLEAN, DataType.SCALAR, true);
+ LineageItem[] otherInputs =
LineageItemUtils.getLineage(ec, r_op, len, slice, rowFed);
+ LineageItem[] liInputs =
Stream.concat(Arrays.stream(liUdfInputs), Arrays.stream(otherInputs))
+ .toArray(LineageItem[]::new);
+ return Pair.of(String.valueOf(_outputID),
+ new
LineageItem(getClass().getSimpleName(), liInputs));
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index 6c9b034..73e0532 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -20,6 +20,8 @@
package org.apache.sysds.test.functions.federated.transform;
import java.io.IOException;
+
+import org.apache.commons.lang.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
@@ -29,6 +31,8 @@ import org.apache.sysds.runtime.io.FrameReaderFactory;
import org.apache.sysds.runtime.io.FrameWriter;
import org.apache.sysds.runtime.io.FrameWriterFactory;
import org.apache.sysds.runtime.io.MatrixReaderFactory;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
@@ -85,95 +89,105 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
@Test
public void testHomesRecodeIDsCSV() {
- runTransformTest(TransformType.RECODE, false);
+ runTransformTest(TransformType.RECODE, false, false);
}
@Test
public void testHomesDummycodeIDsCSV() {
- runTransformTest(TransformType.DUMMY, false);
+ runTransformTest(TransformType.DUMMY, false, false);
}
@Test
public void testHomesRecodeDummycodeIDsCSV() {
- runTransformTest(TransformType.RECODE_DUMMY, false);
+ runTransformTest(TransformType.RECODE_DUMMY, false, false);
}
@Test
public void testHomesBinningIDsCSV() {
- runTransformTest(TransformType.BIN, false);
+ runTransformTest(TransformType.BIN, false, false);
}
@Test
public void testHomesBinningDummyIDsCSV() {
- runTransformTest(TransformType.BIN_DUMMY, false);
+ runTransformTest(TransformType.BIN_DUMMY, false, false);
}
@Test
public void testHomesOmitIDsCSV() {
- runTransformTest(TransformType.OMIT, false);
+ runTransformTest(TransformType.OMIT, false, false);
}
@Test
public void testHomesImputeIDsCSV() {
- runTransformTest(TransformType.IMPUTE, false);
+ runTransformTest(TransformType.IMPUTE, false, false);
}
@Test
public void testHomesRecodeColnamesCSV() {
- runTransformTest(TransformType.RECODE, true);
+ runTransformTest(TransformType.RECODE, true, false);
}
@Test
public void testHomesDummycodeColnamesCSV() {
- runTransformTest(TransformType.DUMMY, true);
+ runTransformTest(TransformType.DUMMY, true, false);
}
@Test
public void testHomesRecodeDummycodeColnamesCSV() {
- runTransformTest(TransformType.RECODE_DUMMY, true);
+ runTransformTest(TransformType.RECODE_DUMMY, true, false);
}
@Test
public void testHomesBinningColnamesCSV() {
- runTransformTest(TransformType.BIN, true);
+ runTransformTest(TransformType.BIN, true, false);
}
@Test
public void testHomesBinningDummyColnamesCSV() {
- runTransformTest(TransformType.BIN_DUMMY, true);
+ runTransformTest(TransformType.BIN_DUMMY, true, false);
}
@Test
public void testHomesOmitColnamesCSV() {
- runTransformTest(TransformType.OMIT, true);
+ runTransformTest(TransformType.OMIT, true, false);
}
@Test
public void testHomesImputeColnamesCSV() {
- runTransformTest(TransformType.IMPUTE, true);
+ runTransformTest(TransformType.IMPUTE, true, false);
}
@Test
public void testHomesHashColnamesCSV() {
- runTransformTest(TransformType.HASH, true);
+ runTransformTest(TransformType.HASH, true, false);
}
@Test
public void testHomesHashIDsCSV() {
- runTransformTest(TransformType.HASH, false);
+ runTransformTest(TransformType.HASH, false, false);
}
@Test
public void testHomesHashRecodeColnamesCSV() {
- runTransformTest(TransformType.HASH_RECODE, true);
+ runTransformTest(TransformType.HASH_RECODE, true, false);
}
@Test
public void testHomesHashRecodeIDsCSV() {
- runTransformTest(TransformType.HASH_RECODE, false);
+ runTransformTest(TransformType.HASH_RECODE, false, false);
+ }
+
+ @Test
+ public void testHomesDummycodeIDsCSVLineage() {
+ runTransformTest(TransformType.DUMMY, false, true);
+ }
+
+ @Test
+ public void testHomesRecodeDummycodeIDsCSVLineage() {
+ runTransformTest(TransformType.RECODE_DUMMY, false, true);
}
- private void runTransformTest(TransformType type, boolean colnames) {
+ private void runTransformTest(TransformType type, boolean colnames,
boolean lineage) {
ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE);
// set transform specification
@@ -199,10 +213,12 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2,
FED_WORKER_WAIT_S);
- t3 = startLocalFedWorkerThread(port3,
FED_WORKER_WAIT_S);
- t4 = startLocalFedWorkerThread(port4);
+ String[] otherargs = lineage ? new String[]
{"-lineage", "reuse_full"} : null;
+ Lineage.resetInternalState();
+ t1 = startLocalFedWorkerThread(port1, otherargs,
FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2, otherargs,
FED_WORKER_WAIT_S);
+ t3 = startLocalFedWorkerThread(port3, otherargs,
FED_WORKER_WAIT_S);
+ t4 = startLocalFedWorkerThread(port4, otherargs);
FileFormatPropertiesCSV ffpCSV = new
FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
DataExpression.DEFAULT_DELIM_FILL,
DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ?
@@ -241,12 +257,16 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
dataset.getNumColumns() - 1);
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ String[] lineageArgs = new String[] {"-lineage",
"reuse_full", "-stats"};
programArgs = new String[] {"-nvargs", "in_AH=" +
TestUtils.federatedAddress(port1, input("AH")),
"in_AL=" + TestUtils.federatedAddress(port2,
input("AL")),
"in_BH=" + TestUtils.federatedAddress(port3,
input("BH")),
"in_BL=" + TestUtils.federatedAddress(port4,
input("BL")), "rows=" + dataset.getNumRows(),
"cols=" + dataset.getNumColumns(), "TFSPEC=" +
HOME + "input/" + SPEC, "TFDATA1=" + output("tfout1"),
"TFDATA2=" + output("tfout2"), "OFMT=csv"};
+
+ if (lineage)
+ programArgs = (String[])
ArrayUtils.addAll(lineageArgs, programArgs);
runTest(true, false, null, -1);
@@ -275,6 +295,9 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
}
}
}
+ // assert reuse count
+ if (lineage)
+
Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);
}
catch(Exception ex) {
throw new RuntimeException(ex);
diff --git
a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
index 28cdcda..4291993 100644
--- a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
@@ -27,7 +27,8 @@ F1 = federated(type="frame", addresses=list($in_AH, $in_AL,
$in_BH, $in_BL), ran
jspec = read($TFSPEC, data_type="scalar", value_type="string");
-[X, M] = transformencode(target=F1, spec=jspec);
+for (i in 1:2)
+ [X, M] = transformencode(target=F1, spec=jspec);
while(FALSE){}