This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 6dd66b8 [SYSTEMDS-2784] Enable lineage-based reuse in federated
workers
6dd66b8 is described below
commit 6dd66b8257b43146fe3cd31bab61f3595184928a
Author: arnabp <[email protected]>
AuthorDate: Sat Jan 2 22:55:37 2021 +0100
[SYSTEMDS-2784] Enable lineage-based reuse in federated workers
This patch builds the initial infrastructure for lineage-based
reuse in federated workers. Changes include:
- Lineage tracing for InitFEDInstruction
- Lineage tracing of READ and PUT requests. For PUT, the LineageItem hash
is sent with the request; it will be replaced by an Adler32 checksum
in future commits.
- Disable compiler-assisted optimizations for lineage-based reuse
(e.g. mark for caching) in the workers.
- Testing infrastructure.
---
.../controlprogram/federated/FederatedRequest.java | 13 +++
.../federated/FederatedWorkerHandler.java | 18 +++
.../fed/AggregateBinaryFEDInstruction.java | 6 +
.../instructions/fed/InitFEDInstruction.java | 34 +++++-
.../org/apache/sysds/test/AutomatedTestBase.java | 13 ++-
.../test/functions/lineage/FedFullReuseTest.java | 128 +++++++++++++++++++++
.../scripts/functions/lineage/FedFullReuse1.dml | 30 +++++
.../functions/lineage/FedFullReuse1Reference.dml | 28 +++++
8 files changed, 268 insertions(+), 2 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 6c9be16..33dad44 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -23,8 +23,10 @@ import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.utils.Statistics;
public class FederatedRequest implements Serializable {
@@ -45,6 +47,7 @@ public class FederatedRequest implements Serializable {
private long _tid;
private List<Object> _data;
private boolean _checkPrivacy;
+ private List<Integer> _lineageHash;
public FederatedRequest(RequestType method) {
@@ -117,6 +120,16 @@ public class FederatedRequest implements Serializable {
return _checkPrivacy;
}
+	/**
+	 * Attaches the hash codes of the given lineage items to this request, one
+	 * entry per item and in the same order, so that the receiving worker can
+	 * build a lineage item for the shipped data (the worker's PUT handling
+	 * reads entry 0).
+	 *
+	 * @param liItems lineage items identifying the data shipped with this request
+	 */
+	public void setLineageHash(LineageItem[] liItems) {
+		// copy the hash of the corresponding lineage DAG
+		// TODO: copy both Adler32 checksum (on data) and hash (on lineage DAG)
+		_lineageHash = Arrays.stream(liItems)
+			.map(LineageItem::hashCode)
+			.collect(Collectors.toList());
+	}
+
+	/**
+	 * Tells whether lineage hashes were attached to this request, so that a
+	 * receiver can skip lineage tracing for requests sent without them.
+	 *
+	 * @return true if {@link #setLineageHash(LineageItem[])} attached at least one hash
+	 */
+	public boolean hasLineageHash() {
+		return _lineageHash != null && !_lineageHash.isEmpty();
+	}
+
+	/**
+	 * Returns the i-th lineage hash attached via {@link #setLineageHash(LineageItem[])}.
+	 *
+	 * @param i index of the lineage item
+	 * @return the hash code of the i-th lineage item
+	 * @throws IllegalStateException if no lineage hash was attached to this request
+	 * @throws IndexOutOfBoundsException if i is out of range
+	 */
+	public int getLineageHash(int i) {
+		// fail with a clear message instead of a bare NullPointerException
+		if(_lineageHash == null)
+			throw new IllegalStateException(
+				"FederatedRequest has no lineage hash; setLineageHash must be called before sending.");
+		return _lineageHash.get(i);
+	}
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder("FederatedRequest[");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index e3ec403..5c0a0bc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -48,6 +48,8 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
@@ -232,6 +234,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
cd.enableCleanup(false); // guard against deletion
_ecm.get(tid).setVariable(String.valueOf(id), cd);
+ if (DMLScript.LINEAGE)
+ // create a literal type lineage item with the file name
+ _ecm.get(tid).getLineage().set(String.valueOf(id), new
LineageItem(filename));
+
if(dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
frameObject.acquireRead();
@@ -264,6 +270,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
// set variable and construct empty response
ec.setVariable(varname, data);
+ if (DMLScript.LINEAGE)
+ // TODO: Identify MO uniquely. Use Adler32 checksum.
+ ec.getLineage().set(varname, new
LineageItem(String.valueOf(request.getLineageHash(0))));
+
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
@@ -299,6 +309,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
pb.getInstructions().clear();
Instruction receivedInstruction =
InstructionParser.parseSingleInstruction((String) request.getParam(0));
pb.getInstructions().add(receivedInstruction);
+
+ if (DMLScript.LINEAGE)
+ // Compiler assisted optimizations are not applicable
for Fed workers.
+ // e.g. isMarkedForCaching fails as output operands are
saved in the
+ // symbol table only after the instruction execution
finishes.
+ // NOTE: In shared JVM, this will disable compiler
assistance even for the coordinator
+ LineageCacheConfig.setCompAssRW(false);
+
try {
pb.execute(ec); // execute single instruction
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6ed642e..4a8194b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -31,6 +32,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -78,6 +80,10 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
else if(mo1.isFederated(FType.ROW)) { // MV + MM
//construct commands: broadcast rhs, fed mv, retrieve
results
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
+ if (DMLScript.LINEAGE)
+ //also copy the hash of the lineage DAG
+
fr1.setLineageHash(LineageItemUtils.getLineage(ec, input1));
+ //TODO: calculate Adler32 checksum on data, and
move this code inside FederationMap.
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
if( mo2.getNumColumns() == 1 ) { //MV
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 17e2855..bc16149 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -56,9 +56,11 @@ import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.meta.DataCharacteristics;
-public class InitFEDInstruction extends FEDInstruction {
+public class InitFEDInstruction extends FEDInstruction implements
LineageTraceable {
private static final Log LOG =
LogFactory.getLog(InitFEDInstruction.class.getName());
@@ -342,4 +344,34 @@ public class InitFEDInstruction extends FEDInstruction {
throw new DMLRuntimeException("Exception in frame
response from federated worker.", e);
}
}
+
+	/**
+	 * Builds the lineage item of this federated-init instruction. Every federated
+	 * partition becomes a literal input lineage item encoding the data type, the
+	 * worker address and the partition's begin/end row and column indexes, so two
+	 * federated() calls over the same addresses and ranges trace to equal lineage.
+	 *
+	 * @param ec execution context holding the type scalar and the address and range lists
+	 * @return pair of the output variable name and its lineage item
+	 * @throws DMLRuntimeException if an address is not a string
+	 */
+	@Override
+	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+		String type = ec.getScalarInput(_type).getStringValue();
+		ListObject addresses = ec.getListObject(_addresses.getName());
+		ListObject ranges = ec.getListObject(_ranges.getName());
+		List<Data> addressData = addresses.getData();
+		List<Data> rangesData = ranges.getData(); // loop-invariant
+		LineageItem[] liInputs = new LineageItem[addresses.getLength()];
+
+		for(int i = 0; i < liInputs.length; i++) {
+			Data addr = addressData.get(i);
+			if(!(addr instanceof StringObject))
+				throw new DMLRuntimeException("federated instruction only takes strings as addresses");
+			String address = ((StringObject) addr).getStringValue();
+			// ranges hold a (begin, end) pair of index lists per address
+			List<Data> beginDimsData = ((ListObject) rangesData.get(i * 2)).getData();
+			List<Data> endDimsData = ((ListObject) rangesData.get(i * 2 + 1)).getData();
+			// form a string with all the information (type, address, rl, cl, ru, cu)
+			// and create a literal lineage item
+			String data = InstructionUtils.concatOperands(type, address,
+				rangeIndex(beginDimsData, 0), rangeIndex(beginDimsData, 1),
+				rangeIndex(endDimsData, 0), rangeIndex(endDimsData, 1));
+			liInputs[i] = new LineageItem(data);
+		}
+		return Pair.of(_output.getName(), new LineageItem(getOpcode(), liInputs));
+	}
+
+	/**
+	 * Returns the range index at position {@code pos} as a canonical integer string,
+	 * so numerically equal bounds (e.g. {@code 50} and {@code 50.0}) yield identical lineage.
+	 */
+	private static String rangeIndex(List<Data> dims, int pos) {
+		return String.valueOf(((ScalarObject) dims.get(pos)).getLongValue());
+	}
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 0143fed..d51f05b 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
+import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
@@ -1530,7 +1531,11 @@ public abstract class AutomatedTestBase {
* @return the thread associated with the worker.
*/
protected Thread startLocalFedWorkerThread(int port) {
- return startLocalFedWorkerThread(port, FED_WORKER_WAIT);
+		return startLocalFedWorkerThread(port, null, FED_WORKER_WAIT);
+	}
+
+	/**
+	 * Start a thread for a federated worker that receives additional command-line
+	 * arguments, using the default wait {@code FED_WORKER_WAIT}.
+	 *
+	 * @param port      the port used by the worker.
+	 * @param otherArgs extra worker arguments (e.g. {@code -lineage reuse_full}), or null
+	 *                  for none; they are placed before the {@code -w <port>} arguments.
+	 * @return the thread associated with the worker.
+	 */
+	protected Thread startLocalFedWorkerThread(int port, String[] otherArgs) {
+		return startLocalFedWorkerThread(port, otherArgs, FED_WORKER_WAIT);
}
/**
@@ -1543,11 +1548,17 @@ public abstract class AutomatedTestBase {
* @return the thread associated with the worker.
*/
protected Thread startLocalFedWorkerThread(int port, int sleep) {
+ return startLocalFedWorkerThread(port, null, sleep); // null: no extra worker arguments
+ }
+ protected Thread startLocalFedWorkerThread(int port, String[]
otherArgs, int sleep) {
Thread t = null;
String[] fedWorkArgs = {"-w", Integer.toString(port)};
ArrayList<String> args = new ArrayList<>();
addProgramIndependentArguments(args);
+
+ if (otherArgs != null)
+
args.addAll(Arrays.stream(otherArgs).collect(Collectors.toList()));
for(int i = 0; i < fedWorkArgs.length; i++)
args.add(fedWorkArgs[i]);
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
new file mode 100644
index 0000000..00c6d6f
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FedFullReuseTest extends AutomatedTestBase {
+
+	private final static String TEST_DIR = "functions/lineage/";
+	private final static String TEST_NAME = "FedFullReuse1";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FedFullReuseTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows have to be even and > 1
+		return Arrays.asList(new Object[][] {
+			// {2, 1000}, {10, 100},
+			{100, 10},
+			//{1000, 1},
+			// {10, 2000}, {2000, 10}
+		});
+	}
+
+	@Test
+	public void federatedReuseMM() { //reuse inside federated workers
+		federatedReuse();
+	}
+
+	/**
+	 * Runs the local reference script and the federated script (X row-partitioned and
+	 * Y column-partitioned over two local workers started with {@code -lineage reuse_full}),
+	 * compares their outputs, and checks that the federated run executes ba+* exactly
+	 * twice as often as the local run, i.e., each worker reuses ba+* like the local run.
+	 * The worker threads are shut down even if a run or an assertion fails.
+	 */
+	public void federatedReuse() {
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write input matrices
+		int halfRows = rows / 2;
+		// share two matrices between two federated workers
+		double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+		double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+		double[][] Y1 = getRandomMatrix(cols, halfRows, 0, 1, 1, 44);
+		double[][] Y2 = getRandomMatrix(cols, halfRows, 0, 1, 1, 21);
+
+		writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		String[] otherargs = new String[] {"-lineage", "reuse_full"};
+		Lineage.resetInternalState();
+		Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S);
+		Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+
+		try {
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+
+			// Run reference dml script with normal matrix. Reuse of ba+*.
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-stats", "-lineage", "reuse_full",
+				"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+				"Y2=" + input("Y2"), "Z=" + expected("Z")};
+			runTest(true, false, null, -1);
+			long mmCount = Statistics.getCPHeavyHitterCount("ba+*");
+
+			// Run actual dml script with federated matrix
+			// The fed workers reuse ba+*
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats","-lineage", "reuse_full",
+				"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+				"Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+				"r=" + rows, "c=" + cols, "Z=" + output("Z")};
+			runTest(true, false, null, -1);
+			long mmCount_fed = Statistics.getCPHeavyHitterCount("ba+*");
+
+			// compare results
+			compareResults(1e-9);
+			// compare matrix multiplication count:
+			// #federated executions of ba+* = #workers (2) * #non-federated executions of ba+* (after reuse)
+			Assert.assertEquals("Violated reuse count", mmCount * 2, mmCount_fed);
+		}
+		finally {
+			// stop the workers even if a run or an assertion above failed
+			TestUtils.shutdownThreads(t1, t2);
+		}
+	}
+
+}
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1.dml b/src/test/scripts/functions/lineage/FedFullReuse1.dml
new file mode 100644
index 0000000..4597332
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X is partitioned by rows and Y by columns across the two workers;
+# each partition is given as a (begin, end) pair of (row, column) indexes.
+X = federated(addresses=list($X1, $X2),
+	ranges=list(list(0, 0), list($r / 2, $c),
+		list($r / 2, 0), list($r, $c)));
+Y = federated(addresses=list($Y1, $Y2),
+	ranges=list(list(0, 0), list($c, $r / 2),
+		list(0, $r / 2), list($c, $r)));
+
+# repeat the identical multiply so that the federated workers can
+# reuse ba+* across iterations
+for(i in 1:10) {
+	Z = X %*% Y;
+}
+
+write(Z, $Z);
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
new file mode 100644
index 0000000..6049f5d
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# local (non-federated) counterpart: X stacked by rows, Y joined by columns
+X1 = read($X1);
+X2 = read($X2);
+Y1 = read($Y1);
+Y2 = read($Y2);
+X = rbind(X1, X2);
+Y = cbind(Y1, Y2);
+
+# repeat the identical multiply so that lineage-based reuse
+# serves the repeated iterations
+for(i in 1:10) {
+	Z = X %*% Y;
+}
+
+write(Z, $Z);