This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 854b4e9 [SYSTEMDS-2549,2624] Fix federated binary matrix-vector, var
cleanup
854b4e9 is described below
commit 854b4e94f0e8f4c8b8e0f2867558cb90e4e8e552
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Aug 16 17:30:03 2020 +0200
[SYSTEMDS-2549,2624] Fix federated binary matrix-vector, var cleanup
This patch fixes two correctness issues related to (1) the cleanup of
federated matrices, and (2) federated binary matrix-row vector
operators. Furthermore, it also includes a new federated Kmeans test,
some minor fixes for row aggregates, and improvements to federated
matrix multiplications.
---
scripts/builtin/kmeans.dml | 2 +-
.../controlprogram/context/ExecutionContext.java | 9 +---
.../controlprogram/federated/FederatedRange.java | 8 +++-
.../federated/FederatedWorkerHandler.java | 2 +-
.../controlprogram/federated/FederationMap.java | 8 ++++
.../fed/AggregateBinaryFEDInstruction.java | 23 +++++++---
.../fed/BinaryMatrixMatrixFEDInstruction.java | 24 ++++++++---
.../sysds/runtime/meta/DataCharacteristics.java | 2 +-
...eratedPCATest.java => FederatedKmeansTest.java} | 50 ++++++++++++----------
.../test/functions/federated/FederatedPCATest.java | 5 +++
.../functions/federated/FederatedKmeansTest.dml | 25 +++++++++++
.../federated/FederatedKmeansTestReference.dml | 24 +++++++++++
12 files changed, 132 insertions(+), 50 deletions(-)
diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index f18466d..90a7222 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -160,7 +160,7 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10,
Integer runs = 10, Integer
C_old = C; C = C_new;
}
- if(is_verbose == TRUE)
+ if(is_verbose)
print ("Run " + run_index + ", Iteration " + iter_count + ": Terminated
with code = "
+ term_code + ", Centroid WCSS = " + wcss);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 31a467f..fcb5db3 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -59,9 +59,7 @@ import org.apache.sysds.utils.Statistics;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.HashSet;
import java.util.List;
-import java.util.Set;
import java.util.stream.Collectors;
public class ExecutionContext {
@@ -73,7 +71,6 @@ public class ExecutionContext {
//symbol table
protected LocalVariableMap _variables;
protected boolean _autoCreateVars;
- protected Set<String> _guardedFiles = new HashSet<>();
//lineage map, cache, prepared dedup blocks
protected Lineage _lineage;
@@ -134,10 +131,6 @@ public class ExecutionContext {
public void setAutoCreateVars(boolean flag) {
_autoCreateVars = flag;
}
-
- public void addGuardedFilename(String fname) {
- _guardedFiles.add(fname);
- }
/**
* Get the i-th GPUContext
@@ -758,7 +751,7 @@ public class ExecutionContext {
//compute ref count only if matrix cleanup actually
necessary
if ( mo.isCleanupEnabled() &&
!getVariables().hasReferences(mo) ) {
mo.clearData(); //clean cached data
- if( fileExists &&
!_guardedFiles.contains(mo.getFileName()) ) {
+ if( fileExists ) {
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName());
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName()+".mtd");
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 6571666..46ebce2 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -41,8 +41,12 @@ public class FederatedRange implements
Comparable<FederatedRange> {
* @param other the <code>FederatedRange</code> to copy
*/
public FederatedRange(FederatedRange other) {
- _beginDims = other._beginDims.clone();
- _endDims = other._endDims.clone();
+ this(other._beginDims.clone(), other._endDims.clone());
+ }
+
+ public FederatedRange(FederatedRange other, long clen) {
+ this(other._beginDims.clone(), other._endDims.clone());
+ _endDims[1] = clen;
}
public void setBeginDim(int dim, long value) {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 47ca43c..1afbfb1 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -181,7 +181,7 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
//TODO spawn async load of data, otherwise on first access
_ec.setVariable(String.valueOf(id), cd);
- _ec.addGuardedFilename(filename);
+ cd.enableCleanup(false); //guard against deletion
if (dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 04532fd..d323bad 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -149,6 +149,14 @@ public class FederationMap
map.put(new FederatedRange(e.getKey()), new
FederatedData(e.getValue(), id));
return new FederationMap(id, map);
}
+
+ public FederationMap copyWithNewID(long id, long clen) {
+ Map<FederatedRange, FederatedData> map = new TreeMap<>();
+ //TODO handling of file path, but no danger as never written
+ for( Entry<FederatedRange, FederatedData> e :
_fedMap.entrySet() )
+ map.put(new FederatedRange(e.getKey(), clen), new
FederatedData(e.getValue(), id));
+ return new FederationMap(id, map);
+ }
public FederationMap rbind(long offset, FederationMap that) {
for( Entry<FederatedRange, FederatedData> e :
that._fedMap.entrySet() ) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 3fe1004..14f81bf 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -66,13 +66,22 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
- FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(fr1, fr2, fr3);
- MatrixBlock ret = FederationUtils.rbind(tmp);
- mo1.getFedMapping().cleanup(fr1.getID(), fr2.getID());
- ec.setMatrixOutput(output.getName(), ret);
- //TODO should remain federated matrix (no need for agg)
+ if( mo2.getNumColumns() == 1 ) { //MV
+ FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.rbind(tmp);
+ mo1.getFedMapping().cleanup(fr1.getID(),
fr2.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ else { //MM
+ //execute federated operations and aggregate
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1.getID());
+ MatrixObject out = ec.getMatrixObject(output);
+
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(),
(int)mo1.getBlocksize());
+
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(),
mo2.getNumColumns()));
+ }
}
//#2 vector - federated matrix multiplication
else if (mo2.isFederated()) {// VM + MM
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index d124c76..7813f6a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -45,13 +45,23 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
}
//matrix-matrix binary operations -> lhs fed input -> fed output
- FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
- FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
-
- //execute federated instruction and cleanup intermediates
- mo1.getFedMapping().execute(fr1, fr2);
- mo1.getFedMapping().cleanup(fr1.getID());
+ FederatedRequest fr2 = null;
+ if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV
row vector
+ FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
+ fr2 = FederationUtils.callInstruction(instString,
output, new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(),
fr1[0].getID()});
+ //execute federated instruction and cleanup
intermediates
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1[0].getID());
+ }
+ else { //MM or MV col vector
+ FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
+ fr2 = FederationUtils.callInstruction(instString,
output, new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(),
fr1.getID()});
+ //execute federated instruction and cleanup
intermediates
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1.getID());
+ }
//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
diff --git
a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
index 58bdcd0..d71ce9d 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
@@ -31,7 +31,7 @@ public abstract class DataCharacteristics implements
Serializable {
protected int _blocksize;
- public DataCharacteristics set(long nr, long nc, int len) {
+ public DataCharacteristics set(long nr, long nc, int blen) {
throw new DMLRuntimeException("DataCharacteristics.set(long,
long, int): should never get called in the base class");
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
similarity index 73%
copy from
src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
copy to
src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
index bf674a8..1ef2384 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -27,6 +27,7 @@ import org.junit.runners.Parameterized;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -36,11 +37,11 @@ import java.util.Collection;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedPCATest extends AutomatedTestBase {
+public class FederatedKmeansTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/";
- private final static String TEST_NAME = "FederatedPCATest";
- private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedPCATest.class.getSimpleName() + "/";
+ private final static String TEST_NAME = "FederatedKMeansTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedKmeansTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@Parameterized.Parameter()
@@ -48,7 +49,7 @@ public class FederatedPCATest extends AutomatedTestBase {
@Parameterized.Parameter(1)
public int cols;
@Parameterized.Parameter(2)
- public boolean scaleAndShift;
+ public int runs;
@Override
public void setUp() {
@@ -60,22 +61,23 @@ public class FederatedPCATest extends AutomatedTestBase {
public static Collection<Object[]> data() {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
- {10000, 10, false}, {2000, 50, false}, {1000, 100,
false},
- {10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+ {10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
+ //TODO support for multi-threaded federated interactions
+ //{10000, 10, 16}, {2000, 50, 16}, {1000, 100, 16},
//concurrent requests
});
}
@Test
- public void federatedPCASinglenode() {
- federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+ public void federatedKmeansSinglenode() {
+ federatedKmeans(Types.ExecMode.SINGLE_NODE);
}
@Test
- public void federatedPCAHybrid() {
- federatedL2SVM(Types.ExecMode.HYBRID);
+ public void federatedKmeansHybrid() {
+ federatedKmeans(Types.ExecMode.HYBRID);
}
- public void federatedL2SVM(Types.ExecMode execMode) {
+ public void federatedKmeans(Types.ExecMode execMode) {
ExecMode platformOld = setExecMode(execMode);
getAndLoadTestConfiguration(TEST_NAME);
@@ -98,11 +100,12 @@ public class FederatedPCATest extends AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
+ setOutputBuffering(false);
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"),
- String.valueOf(scaleAndShift).toUpperCase(),
expected("Z")};
+ String.valueOf(runs), expected("Z")};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
@@ -110,24 +113,25 @@ public class FederatedPCATest extends AutomatedTestBase {
programArgs = new String[] {"-stats",
"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")), "rows=" + rows, "cols=" + cols,
- "scaleAndShift=" +
String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")};
+ "runs=" + String.valueOf(runs), "out=" + output("Z")};
runTest(true, false, null, -1);
// compare via files
- compareResults(1e-9);
+ //compareResults(1e-9); --> randomized
TestUtils.shutdownThreads(t1, t2);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
- Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
- Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
- if( scaleAndShift ) {
-
Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));
-
Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
- Assert.assertTrue(heavyHittersContainsString("fed_-"));
- Assert.assertTrue(heavyHittersContainsString("fed_/"));
-
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
- }
+ Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+ Assert.assertTrue(heavyHittersContainsString("fed_*"));
+ Assert.assertTrue(heavyHittersContainsString("fed_+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_<="));
+ Assert.assertTrue(heavyHittersContainsString("fed_/"));
+
+ //check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
resetExecMode(platformOld);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index bf674a8..53eac1e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -27,6 +27,7 @@ import org.junit.runners.Parameterized;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -129,6 +130,10 @@ public class FederatedPCATest extends AutomatedTestBase {
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
}
+ //check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
resetExecMode(platformOld);
}
}
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTest.dml
b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
new file mode 100644
index 0000000..95f136c
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)))
+[C,Y] = kmeans(X=X, k=4, runs=$runs)
+write(C, $out)
diff --git
a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
new file mode 100644
index 0000000..da32c8b
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+[C,Y] = kmeans(X=X, k=4, runs=$3)
+write(C, $4)