This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 7f61827 [SYSTEMDS-2795] Reuse of FED instruction results in
Coordinator
7f61827 is described below
commit 7f6182715414e9685d13d7a9f02dbe95a8f599c0
Author: arnabp <[email protected]>
AuthorDate: Fri Jan 15 16:00:10 2021 +0100
[SYSTEMDS-2795] Reuse of FED instruction results in Coordinator
This patch introduces lineage-based caching of FED instructions
in the coordinator. This enables skipping federated execution entirely
if the output is available in the cache. To avoid unnecessarily pulling
federated data to the coordinator, we only cache non-federated outputs.
---
.../runtime/instructions/cp/CPInstruction.java | 2 +
.../apache/sysds/runtime/lineage/LineageCache.java | 43 ++++++++++++++++-----
.../sysds/runtime/lineage/LineageCacheConfig.java | 17 ++++++++-
.../test/functions/lineage/FedFullReuseTest.java | 44 +++++++++++++++++-----
.../scripts/functions/lineage/FedFullReuse2.dml | 30 +++++++++++++++
.../functions/lineage/FedFullReuse2Reference.dml | 28 ++++++++++++++
6 files changed, 145 insertions(+), 19 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index 565ac8f..b2a14ce 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -95,6 +95,8 @@ public abstract class CPInstruction extends Instruction
//robustness federated instructions (runtime assignment)
tmp = FEDInstructionUtils.checkAndReplaceCP(tmp, ec);
+ //NOTE: Retracing of lineage is not needed as the lineage trace
+ //is the same for an instruction and its FED version.
tmp = PrivacyPropagator.preprocessInstruction(tmp, ec);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index d6e74e7..287f9aa 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -36,12 +36,14 @@ import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
import org.apache.sysds.runtime.instructions.CPInstructionParser;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import
org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -84,8 +86,10 @@ public class LineageCache
//NOTE: the check for computation CP instructions ensures that
the output
// will always fit in memory and hence can be pinned
unconditionally
if (LineageCacheConfig.isReusable(inst, ec)) {
- ComputationCPInstruction cinst =
(ComputationCPInstruction) inst;
- LineageItem instLI =
cinst.getLineageItem(ec).getValue();
+ ComputationCPInstruction cinst = inst instanceof
ComputationCPInstruction ? (ComputationCPInstruction)inst : null;
+ ComputationFEDInstruction cfinst = inst instanceof
ComputationFEDInstruction ? (ComputationFEDInstruction)inst : null;
+
+ LineageItem instLI = (cinst != null) ?
cinst.getLineageItem(ec).getValue():cfinst.getLineageItem(ec).getValue();
List<MutablePair<LineageItem, LineageCacheEntry>>
liList = null;
if (inst instanceof MultiReturnBuiltinCPInstruction) {
liList = new ArrayList<>();
@@ -119,7 +123,10 @@ public class LineageCache
//create a placeholder if no reuse to
avoid redundancy
//(e.g., concurrent threads that try to
start the computation)
if(e == null &&
isMarkedForCaching(inst, ec)) {
- putIntern(item.getKey(),
cinst.output.getDataType(), null, null, 0);
+ if (cinst != null)
+
putIntern(item.getKey(), cinst.output.getDataType(), null, null, 0);
+ else
+
putIntern(item.getKey(), cfinst.output.getDataType(), null, null, 0);
//FIXME: different o/p
datatypes for MultiReturnBuiltins.
}
}
@@ -134,9 +141,11 @@ public class LineageCache
if (inst instanceof
MultiReturnBuiltinCPInstruction)
outName =
((MultiReturnBuiltinCPInstruction)inst).
getOutput(entry.getKey().getOpcode().charAt(entry.getKey().getOpcode().length()-1)-'0').getName();
- else
+ else if (inst instanceof
ComputationCPInstruction)
outName =
cinst.output.getName();
-
+ else
+ outName =
cfinst.output.getName();
+
if (e.isMatrixValue())
ec.setMatrixOutput(outName,
e.getMBValue());
else
@@ -248,7 +257,9 @@ public class LineageCache
if (LineageCacheConfig.isReusable(inst, ec) ) {
LineageItem item = ((LineageTraceable)
inst).getLineageItem(ec).getValue();
//This method is called only to put matrix value
- MatrixObject mo =
ec.getMatrixObject(((ComputationCPInstruction) inst).output);
+ MatrixObject mo = inst instanceof
ComputationCPInstruction ?
+
ec.getMatrixObject(((ComputationCPInstruction) inst).output) :
+
ec.getMatrixObject(((ComputationFEDInstruction) inst).output);
synchronized( _cache ) {
putIntern(item, DataType.MATRIX,
mo.acquireReadAndRelease(), null, computetime);
}
@@ -278,7 +289,9 @@ public class LineageCache
}
}
else
- liData = Arrays.asList(Pair.of(instLI,
ec.getVariable(((ComputationCPInstruction) inst).output)));
+ liData = inst instanceof
ComputationCPInstruction ?
+ Arrays.asList(Pair.of(instLI,
ec.getVariable(((ComputationCPInstruction) inst).output))) :
+ Arrays.asList(Pair.of(instLI,
ec.getVariable(((ComputationFEDInstruction) inst).output)));
synchronized( _cache ) {
for (Pair<LineageItem, Data> entry : liData) {
LineageItem item = entry.getKey();
@@ -290,6 +303,13 @@ public class LineageCache
_cache.remove(item);
continue;
}
+
+ if
(LineageCacheConfig.isOutputFederated(inst, data)) {
+ // Do not cache federated
outputs (in the coordinator)
+ // Cannot skip putting the
placeholder, as whether the output is federated is only known after execution
+ _cache.remove(item);
+ continue;
+ }
MatrixBlock mb = (data instanceof
MatrixObject) ?
((MatrixObject)data).acquireReadAndRelease() : null;
@@ -456,8 +476,13 @@ public class LineageCache
if (!LineageCacheConfig.getCompAssRW())
return true;
- if (((ComputationCPInstruction)inst).output.isMatrix()) {
- MatrixObject mo =
ec.getMatrixObject(((ComputationCPInstruction)inst).output);
+ CPOperand output = inst instanceof ComputationCPInstruction ?
+ ((ComputationCPInstruction)inst).output :
+ ((ComputationFEDInstruction)inst).output;
+ if (output.isMatrix()) {
+ MatrixObject mo = inst instanceof
ComputationCPInstruction ?
+
ec.getMatrixObject(((ComputationCPInstruction)inst).output) :
+
ec.getMatrixObject(((ComputationFEDInstruction)inst).output);
//limit this to full reuse as partial reuse is
applicable even for loop dependent operation
return !(LineageCacheConfig.getCacheType() ==
ReuseCacheType.REUSE_FULL
&& !mo.isMarked());
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 0b22b28..61dfc8e 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -21,12 +21,15 @@ package org.apache.sysds.runtime.lineage;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
+import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import java.util.Comparator;
@@ -186,6 +189,7 @@ public class LineageCacheConfig
public static boolean isReusable (Instruction inst, ExecutionContext
ec) {
boolean insttype = inst instanceof ComputationCPInstruction
+ || inst instanceof ComputationFEDInstruction
&& !(inst instanceof ListIndexingCPInstruction);
boolean rightop = (ArrayUtils.contains(REUSE_OPCODES,
inst.getOpcode())
|| (inst.getOpcode().equals("append") &&
isVectorAppend(inst, ec))
@@ -193,7 +197,8 @@ public class LineageCacheConfig
|| (inst instanceof DataGenCPInstruction) &&
((DataGenCPInstruction) inst).isMatrixCall());
boolean updateInplace = (inst instanceof
MatrixIndexingCPInstruction)
&&
ec.getMatrixObject(((ComputationCPInstruction)inst).input1).getUpdateType().isInPlace();
- return insttype && rightop && !updateInplace;
+ boolean federatedOutput = false;
+ return insttype && rightop && !updateInplace &&
!federatedOutput;
}
private static boolean isVectorAppend(Instruction inst,
ExecutionContext ec) {
@@ -205,6 +210,16 @@ public class LineageCacheConfig
return(c1 == 1 || c2 == 1);
}
+ public static boolean isOutputFederated(Instruction inst, Data data) {
+ if (!(inst instanceof ComputationFEDInstruction))
+ return false;
+ // return true if the output matrixobject is federated
+ if (inst instanceof ComputationFEDInstruction)
+ if (data instanceof MatrixObject && ((MatrixObject)
data).isFederated())
+ return true;
+ return false;
+ }
+
public static void setConfigTsmmCbind(ReuseCacheType ct) {
_cacheType = ct;
_itemH = CachedItemHead.TSMM;
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
index 00c6d6f..1051f5c 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
@@ -19,6 +19,8 @@
package org.apache.sysds.test.functions.lineage;
+import static org.junit.Assert.assertTrue;
+
import java.util.Arrays;
import java.util.Collection;
@@ -38,7 +40,8 @@ import org.junit.runners.Parameterized;
public class FedFullReuseTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/lineage/";
- private final static String TEST_NAME = "FedFullReuse1";
+ private final static String TEST_NAME1 = "FedFullReuse1";
+ private final static String TEST_NAME2 = "FedFullReuse2";
private final static String TEST_CLASS_DIR = TEST_DIR +
FedFullReuseTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -50,7 +53,8 @@ public class FedFullReuseTest extends AutomatedTestBase {
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"Z"}));
}
@Parameterized.Parameters
@@ -65,12 +69,20 @@ public class FedFullReuseTest extends AutomatedTestBase {
}
@Test
- public void federatedReuseMM() { //reuse inside federated workers
- federatedReuse();
+ public void federatedOutputReuse() {
+ //don't cache federated outputs in the coordinator
+ //reuse inside federated workers
+ federatedReuse(TEST_NAME1);
+ }
+
+ @Test
+ public void nonfederatedOutputReuse() {
+ //cache non-federated outputs in the coordinator
+ federatedReuse(TEST_NAME2);
}
- public void federatedReuse() {
- getAndLoadTestConfiguration(TEST_NAME);
+ public void federatedReuse(String test) {
+ getAndLoadTestConfiguration(test);
String HOME = SCRIPT_DIR + TEST_DIR;
// write input matrices
@@ -93,11 +105,11 @@ public class FedFullReuseTest extends AutomatedTestBase {
Thread t1 = startLocalFedWorkerThread(port1, otherargs,
FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2, otherargs);
- TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
+ TestConfiguration config =
availableTestConfigurations.get(test);
loadTestConfiguration(config);
// Run reference dml script with normal matrix. Reuse of ba+*.
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ fullDMLScriptName = HOME + test + "Reference.dml";
programArgs = new String[] {"-stats", "-lineage", "reuse_full",
"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"),
"Y1=" + input("Y1"),
"Y2=" + input("Y2"), "Z=" + expected("Z")};
@@ -106,7 +118,7 @@ public class FedFullReuseTest extends AutomatedTestBase {
// Run actual dml script with federated matrix
// The fed workers reuse ba+*
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ fullDMLScriptName = HOME + test + ".dml";
programArgs = new String[] {"-stats","-lineage", "reuse_full",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
@@ -114,6 +126,7 @@ public class FedFullReuseTest extends AutomatedTestBase {
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
runTest(true, false, null, -1);
long mmCount_fed = Statistics.getCPHeavyHitterCount("ba+*");
+ long fedMMCount = Statistics.getCPHeavyHitterCount("fed_ba+*");
// compare results
compareResults(1e-9);
@@ -121,6 +134,19 @@ public class FedFullReuseTest extends AutomatedTestBase {
// #federated execution of ba+* = #threads times #non-federated
execution of ba+* (after reuse)
Assert.assertTrue("Violated reuse count: "+mmCount_fed+" ==
"+mmCount*2,
mmCount_fed == mmCount * 2); // #threads = 2
+ switch(test) {
+ case TEST_NAME1:
+ // If the o/p is federated, fed_ba+* will be
called every time
+ // but the workers should be able to reuse ba+*
+ assertTrue(fedMMCount > mmCount_fed);
+ break;
+ case TEST_NAME2:
+ // If the o/p is non-federated, fed_ba+* will
be called once
+ // and each worker will call ba+* once.
+ assertTrue(fedMMCount < mmCount_fed);
+ break;
+ }
+
TestUtils.shutdownThreads(t1, t2);
}
diff --git a/src/test/scripts/functions/lineage/FedFullReuse2.dml
b/src/test/scripts/functions/lineage/FedFullReuse2.dml
new file mode 100644
index 0000000..f57863a
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse2.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)));
+
+vec = rand(rows=1, cols=100, seed=42);
+
+for(i in 1:10)
+ Z = vec %*% X;
+
+write(Z, $Z);
diff --git a/src/test/scripts/functions/lineage/FedFullReuse2Reference.dml
b/src/test/scripts/functions/lineage/FedFullReuse2Reference.dml
new file mode 100644
index 0000000..261dca5
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse2Reference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2));
+vec = rand(rows=1, cols=100, seed=42);
+
+for(i in 1:10)
+ Z = vec %*% X;
+
+write(Z, $Z);