This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e62034d [SYSTEMDS-3185] Read cache/event loop for multi-tenant
federated learning
e62034d is described below
commit e62034d5f017d24e2d616fb7abd44b5d187bb58b
Author: ywcb00 <[email protected]>
AuthorDate: Sun Jan 23 19:29:51 2022 +0100
[SYSTEMDS-3185] Read cache/event loop for multi-tenant federated learning
* Parallel event loops for multiple tenants
* Read cache for reuse across tenants
* Fix several federated instructions (missing wait)
Closes #1521.
---
.../perftest/fed/data/splitAndMakeFederated.dml | 4 +-
scripts/perftest/fed/runALS_CG_Fed.sh | 4 +-
scripts/perftest/fed/runAllFed.sh | 4 +-
scripts/perftest/fed/utils/killFedWorkers.sh | 4 +-
scripts/perftest/fed/utils/startFedWorkers.sh | 4 +-
scripts/perftest/runAllBinomial.sh | 4 +-
scripts/perftest/todo/runAllTrees.sh | 4 +-
.../apache/sysds/runtime/codegen/CodegenUtils.java | 2 +-
.../federated/FederatedLocalData.java | 3 +-
.../federated/FederatedLookupTable.java | 6 +-
.../federated/FederatedReadCache.java | 125 +++++++++++++
.../federated/FederatedStatistics.java | 97 +++++++++-
.../controlprogram/federated/FederatedWorker.java | 16 +-
.../federated/FederatedWorkerHandler.java | 117 +++++++-----
.../fed/AggregateBinaryFEDInstruction.java | 2 +-
.../fed/AggregateUnaryFEDInstruction.java | 2 +-
.../java/org/apache/sysds/utils/Statistics.java | 1 +
.../multitenant/FederatedMultiTenantTest.java | 140 ++++----------
.../multitenant/FederatedReadCacheTest.java | 208 +++++++++++++++++++++
.../federated/multitenant/MultiTenantTestBase.java | 143 ++++++++++++++
.../functions/binary/frame/frameComparisonTest.dml | 2 +-
...FullMatrixVectorRowCellwiseOperation_Addition.R | 6 +-
.../multitenant/FederatedMultiTenantTest.dml | 9 +-
...tiTenantTest.dml => FederatedReadCacheTest.dml} | 48 ++---
24 files changed, 734 insertions(+), 221 deletions(-)
diff --git a/scripts/perftest/fed/data/splitAndMakeFederated.dml
b/scripts/perftest/fed/data/splitAndMakeFederated.dml
index c301bd0..281fedc 100644
--- a/scripts/perftest/fed/data/splitAndMakeFederated.dml
+++ b/scripts/perftest/fed/data/splitAndMakeFederated.dml
@@ -31,6 +31,7 @@ nSplit = $nSplit;
transposed = ifdef($transposed, FALSE);
target = ifdef($target, data + "_fed.json");
fmt = ifdef($fmt, "text");
+hostOffset = ifdef($hostOffset, 0);
if(transposed) # for column partitions we simply transpose the data before
splitting
X = t(X);
@@ -50,7 +51,8 @@ for (counter in 1:nSplit) {
X_part = X[beginDim:endDim, ]; # select the partition from the dataset
write(X_part, data + counter, format=fmt); # write the partition to disk
# collect the addresses and ranges for creating a federated object
- addresses = append(addresses, as.scalar(hosts[counter]) + "/" + data +
counter);
+ hostIX = counter + hostOffset;
+ addresses = append(addresses, as.scalar(hosts[hostIX]) + "/" + data +
counter);
if(transposed) {
ranges = append(ranges, list(0, beginDim - 1));
ranges = append(ranges, list(M, endDim));
diff --git a/scripts/perftest/fed/runALS_CG_Fed.sh
b/scripts/perftest/fed/runALS_CG_Fed.sh
index 99a101a..d2cd388 100755
--- a/scripts/perftest/fed/runALS_CG_Fed.sh
+++ b/scripts/perftest/fed/runALS_CG_Fed.sh
@@ -8,9 +8,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/scripts/perftest/fed/runAllFed.sh
b/scripts/perftest/fed/runAllFed.sh
index 5c5f46e..546ac6f 100755
--- a/scripts/perftest/fed/runAllFed.sh
+++ b/scripts/perftest/fed/runAllFed.sh
@@ -8,9 +8,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/scripts/perftest/fed/utils/killFedWorkers.sh
b/scripts/perftest/fed/utils/killFedWorkers.sh
index 5c71aa8..6634b35 100755
--- a/scripts/perftest/fed/utils/killFedWorkers.sh
+++ b/scripts/perftest/fed/utils/killFedWorkers.sh
@@ -8,9 +8,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/scripts/perftest/fed/utils/startFedWorkers.sh
b/scripts/perftest/fed/utils/startFedWorkers.sh
index 691ce73..b39251d 100755
--- a/scripts/perftest/fed/utils/startFedWorkers.sh
+++ b/scripts/perftest/fed/utils/startFedWorkers.sh
@@ -8,9 +8,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/scripts/perftest/runAllBinomial.sh
b/scripts/perftest/runAllBinomial.sh
index a40c6a7..3e52fcd 100755
--- a/scripts/perftest/runAllBinomial.sh
+++ b/scripts/perftest/runAllBinomial.sh
@@ -8,9 +8,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/scripts/perftest/todo/runAllTrees.sh
b/scripts/perftest/todo/runAllTrees.sh
index 1671d26..01e47ae 100755
--- a/scripts/perftest/todo/runAllTrees.sh
+++ b/scripts/perftest/todo/runAllTrees.sh
@@ -8,9 +8,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java
b/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java
index d5bbdfa..9a10315 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java
@@ -25,9 +25,9 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.SpoofCompiler.CompilerType;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;
import org.apache.sysds.runtime.codegen.SpoofOperator.SideInputSparseCell;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.LocalFileUtils;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
index b1a4c6d..8f1a52d 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
@@ -30,7 +30,8 @@ public class FederatedLocalData extends FederatedData {
protected final static Logger log =
Logger.getLogger(FederatedWorkerHandler.class);
private static final FederatedLookupTable _flt = new
FederatedLookupTable();
- private static final FederatedWorkerHandler _fwh = new
FederatedWorkerHandler(_flt);
+ private static final FederatedReadCache _frc = new FederatedReadCache();
+ private static final FederatedWorkerHandler _fwh = new
FederatedWorkerHandler(_flt, _frc);
private final CacheableData<?> _data;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index c3fc1f5..63defe4 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -38,7 +38,7 @@ public class FederatedLookupTable {
// is no actual network connection (and hence no host either)
public static final String NOHOST = "nohost";
- protected static Logger log =
Logger.getLogger(FederatedLookupTable.class);
+ private static final Logger LOG =
Logger.getLogger(FederatedLookupTable.class);
// stores the mapping between the funCID and the corresponding
ExecutionContextMap
private final Map<FedUniqueCoordID, ExecutionContextMap> _lookup_table;
@@ -57,13 +57,13 @@ public class FederatedLookupTable {
* @return ExecutionContextMap the ECM corresponding to the requesting
coordinator
*/
public ExecutionContextMap getECM(String host, long pid) {
- log.trace("Getting the ExecutionContextMap for coordinator " +
pid + "@" + host);
+ LOG.trace("Getting the ExecutionContextMap for coordinator " +
pid + "@" + host);
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
FedUniqueCoordID funCID = new FedUniqueCoordID(host, pid);
ExecutionContextMap ecm = _lookup_table.computeIfAbsent(funCID,
k -> createNewECM());
if(ecm == null) {
- log.error("Computing federated execution context map
failed. "
+ LOG.error("Computing federated execution context map
failed. "
+ "No valid resolution for " +
funCID.toString() + " found.");
throw new FederatedWorkerHandlerException("Computing
federated execution context map failed. "
+ "No valid resolution for " +
funCID.toString() + " found.");
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedReadCache.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedReadCache.java
new file mode 100644
index 0000000..a7180d7
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedReadCache.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.Map;
+
+import org.apache.log4j.Logger;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+
+public class FederatedReadCache {
+ private static final Logger LOG =
Logger.getLogger(FederatedReadCache.class);
+
+ private Map<String, ReadCacheEntry> _rmap = new ConcurrentHashMap<>(); // filename -> cache entry
+
+ /**
+ * Get the data from the ReadCacheEntry corresponding to the specified
+ * filename, if the data from this filename has already been read.
+ * If another thread is still reading the same filename, this call
+ * blocks until that thread has set the data or marked the entry invalid.
+ * Otherwise, create a new ReadCacheEntry for the filename and return
+ * null to indicate that the data is not cached yet.
+ *
+ * <p>A null return obliges the caller to read the data itself and then
+ * call {@link #setData(String, CacheableData)} on success or
+ * {@link #setInvalid(String)} on failure; until one of them is called,
+ * other threads requesting the same filename stay blocked in
+ * {@link ReadCacheEntry#get()}. Null is also returned after a failed
+ * read (to one waiting thread, or to the next caller if none is
+ * waiting), and that caller must then retry the read in the same way.
+ *
+ * @param fname the filename of the read data
+ * @return the CacheableData object if it is cached, otherwise null
+ * @throws DMLRuntimeException if interrupted while waiting for the data
+ */
+ public CacheableData<?> get(String fname) {
+ ReadCacheEntry tmp = _rmap.putIfAbsent(fname, new
ReadCacheEntry());
+ return (tmp != null) ? tmp.get() : null;
+ }
+
+ /**
+ * Set the data for the ReadCacheEntry with specified filename and wake
+ * up all threads waiting for it in {@link ReadCacheEntry#get()}.
+ *
+ * @param fname the filename of the read data
+ * @param data the CacheableData object for setting the ReadCacheEntry
+ * @throws DMLRuntimeException if no entry was registered for fname via
+ * {@link #get(String)}, or if the data of the entry has already been set
+ */
+ public void setData(String fname, CacheableData<?> data) {
+ LOG.trace("Setting the data for the ReadCacheEntry of file " +
fname);
+ ReadCacheEntry rce = _rmap.get(fname);
+ if(rce == null)
+ throw new DMLRuntimeException("Tried to set the data
for an unregistered ReadCacheEntry.");
+ rce.setValue(data);
+ }
+
+ /**
+ * Set the ReadCacheEntry of a given filename to invalid. Usually done
+ * after a failing read attempt so that the threads waiting for the data
+ * can continue. Only one waiting thread is woken up; it receives null
+ * from {@link #get(String)} and is expected to retry the read itself.
+ *
+ * @param fname the filename of the read data
+ * @throws DMLRuntimeException if no entry was registered for fname
+ */
+ public void setInvalid(String fname) {
+ LOG.debug("Read of file " + fname + " failed. Setting the
corresponding ReadCacheEntry to invalid.");
+ ReadCacheEntry rce = _rmap.get(fname);
+ if(rce == null)
+ throw new DMLRuntimeException("Tried to set an
unexisting ReadCacheEntry to invalid.");
+ rce.setInvalid();
+ }
+
+ /**
+ * Class representing an entry of the federated read cache.
+ *
+ * <p>All methods of this class are synchronized on the entry. While
+ * {@code _data} is null and the entry is valid, readers in
+ * {@link #get()} wait. {@link #setValue(CacheableData)} publishes the
+ * data (at most once) and wakes all readers, whereas
+ * {@link #setInvalid()} wakes a single reader, which resets the entry to
+ * valid and receives null, i.e., it has to retry the read itself.
+ */
+ public static class ReadCacheEntry {
+ protected CacheableData<?> _data = null;
+ private boolean _is_valid = true;
+
+ public synchronized CacheableData<?> get() {
+ try {
+ //wait until the reading thread either places the data
+ //or marks the entry invalid (avoids redundant reads)
+ while(_data == null && _is_valid) {
+ wait();
+ }
+ if(!_is_valid) { // previous thread failed when
trying to read the data
+ _is_valid = true; // later callers now wait for the current thread's retry
+ return null; // caller must retry reading the data
with the current thread
+ }
+ }
+ catch( InterruptedException ex ) { // NOTE(review): the thread's interrupt flag is not restored before rethrowing
+ throw new DMLRuntimeException(ex);
+ }
+
+ if(DMLScript.STATISTICS) { // only cache hits reach this point
+ FederatedStatistics.incFedReadCacheHitCount();
+
FederatedStatistics.incFedReadCacheBytesCount(_data);
+ }
+
+ //comes here once the data has been placed
by the running thread
+ return _data;
+ }
+
+ public synchronized void setValue(CacheableData<?> val) {
+ if(_data != null)
+ throw new DMLRuntimeException("Tried to set the
value of a ReadCacheEntry twice. "
+ + "Should only be performed once.");
+
+ _data = val;
+ //resume all threads waiting for _data
+ notifyAll();
+ }
+
+ public synchronized void setInvalid() {
+ _is_valid = false;
+ notify(); // resume one waiting thread so it can try
reading the data
+ }
+ }
+}
+
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index b5e69b0..9ef0518 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -38,10 +38,12 @@ import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.CacheStatsCollection;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.MultiTenantStatsCollection;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
@@ -72,6 +74,8 @@ public class FederatedStatistics {
private static final LongAdder fedLookupTableGetCount = new LongAdder();
private static final LongAdder fedLookupTableGetTime = new LongAdder();
// in milli sec
private static final LongAdder fedLookupTableEntryCount = new
LongAdder();
+ private static final LongAdder fedReadCacheHitCount = new LongAdder();
+ private static final LongAdder fedReadCacheBytesCount = new LongAdder();
public static synchronized void incFederated(RequestType rqt,
List<Object> data){
switch (rqt) {
@@ -134,6 +138,8 @@ public class FederatedStatistics {
fedLookupTableGetCount.reset();
fedLookupTableGetTime.reset();
fedLookupTableEntryCount.reset();
+ fedReadCacheHitCount.reset();
+ fedReadCacheBytesCount.reset();
}
public static String displayFedIOExecStatistics() {
@@ -186,6 +192,7 @@ public class FederatedStatistics {
sb.append(displayCacheStats(fedStats.cacheStats));
sb.append(String.format("Total JIT compile time:\t\t%.3f
sec.\n", fedStats.jitCompileTime));
sb.append(displayGCStats(fedStats.gcStats));
+ sb.append(displayMultiTenantStats(fedStats.mtStats));
sb.append(displayHeavyHitters(fedStats.heavyHitters,
numHeavyHitters));
return sb.toString();
}
@@ -208,6 +215,13 @@ public class FederatedStatistics {
return sb.toString();
}
+ private static String
displayMultiTenantStats(MultiTenantStatsCollection mtsc) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(displayFedLookupTableStats(mtsc.fLTGetCount,
mtsc.fLTEntryCount, mtsc.fLTGetTime));
+ sb.append(displayFedReadCacheStats(mtsc.readCacheHits,
mtsc.readCacheBytes));
+ return sb.toString();
+ }
+
@SuppressWarnings("unused")
private static String displayHeavyHitters(HashMap<String, Pair<Long,
Double>> heavyHitters) {
return displayHeavyHitters(heavyHitters, 10);
@@ -314,6 +328,25 @@ public class FederatedStatistics {
return retArr;
}
+ public static long getFedLookupTableGetCount() {
+ return fedLookupTableGetCount.longValue();
+ }
+
+ public static long getFedLookupTableGetTime() {
+ return fedLookupTableGetTime.longValue();
+ }
+
+ public static long getFedLookupTableEntryCount() {
+ return fedLookupTableEntryCount.longValue();
+ }
+
+ public static long getFedReadCacheHitCount() {
+ return fedReadCacheHitCount.longValue();
+ }
+
+ public static long getFedReadCacheBytesCount() {
+ return fedReadCacheBytesCount.longValue();
+ }
public static void incFedLookupTableGetCount() {
fedLookupTableGetCount.increment();
@@ -327,14 +360,39 @@ public class FederatedStatistics {
fedLookupTableEntryCount.increment();
}
+ public static void incFedReadCacheHitCount() {
+ fedReadCacheHitCount.increment();
+ }
+
+ public static void incFedReadCacheBytesCount(CacheableData<?> data) {
+ fedReadCacheBytesCount.add(data.getDataSize());
+ }
+
public static String displayFedLookupTableStats() {
- if(fedLookupTableGetCount.longValue() > 0) {
+ return
displayFedLookupTableStats(fedLookupTableGetCount.longValue(),
+ fedLookupTableEntryCount.longValue(),
fedLookupTableGetTime.doubleValue() / 1000000000);
+ }
+
+ public static String displayFedLookupTableStats(long fltGetCount, long
fltEntryCount, double fltGetTime) {
+ if(fltGetCount > 0) {
StringBuilder sb = new StringBuilder();
sb.append("Fed LookupTable (Get, Entries):\t" +
- fedLookupTableGetCount.longValue() + "/" +
- fedLookupTableEntryCount.longValue() + ".\n");
- // sb.append(String.format("Fed LookupTable Get
Time:\t%.3f sec.\n",
- // fedLookupTableGetTime.doubleValue() /
1000000000));
+ fltGetCount + "/" + fltEntryCount + ".\n");
+ return sb.toString();
+ }
+ return "";
+ }
+
+ public static String displayFedReadCacheStats() {
+ return
displayFedReadCacheStats(fedReadCacheHitCount.longValue(),
+ fedReadCacheBytesCount.longValue());
+ }
+
+ public static String displayFedReadCacheStats(long rcHits, long
rcBytes) {
+ if(rcHits > 0) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("Fed ReadCache (Hits, Bytes):\t" +
+ rcHits + "/" + rcBytes + ".\n");
return sb.toString();
}
return "";
@@ -368,6 +426,7 @@ public class FederatedStatistics {
cacheStats.collectStats();
jitCompileTime =
((double)Statistics.getJITCompileTime()) / 1000; // in sec
gcStats.collectStats();
+ mtStats.collectStats();
heavyHitters = Statistics.getHeavyHittersHashMap();
}
@@ -375,6 +434,7 @@ public class FederatedStatistics {
cacheStats.aggregate(that.cacheStats);
jitCompileTime += that.jitCompileTime;
gcStats.aggregate(that.gcStats);
+ mtStats.aggregate(that.mtStats);
that.heavyHitters.forEach(
(key, value) -> heavyHitters.merge(key, value,
(v1, v2) ->
new ImmutablePair<>(v1.getLeft() +
v2.getLeft(), v1.getRight() + v2.getRight()))
@@ -448,9 +508,36 @@ public class FederatedStatistics {
private double gcTime = 0;
}
+ protected static class MultiTenantStatsCollection implements
Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private void collectStats() { // copy the current values of the static FederatedStatistics counters
+ fLTGetCount = getFedLookupTableGetCount();
+ fLTGetTime =
((double)getFedLookupTableGetTime()) / 1000000000; // in sec; dividing by 1e9 assumes the counter holds ns (getECM uses System.nanoTime), although its field comment says milli sec -- TODO confirm
+ fLTEntryCount = getFedLookupTableEntryCount();
+ readCacheHits = getFedReadCacheHitCount();
+ readCacheBytes = getFedReadCacheBytesCount();
+ }
+
+ private void aggregate(MultiTenantStatsCollection that)
{ // element-wise sum of all counters of that into this collection
+ fLTGetCount += that.fLTGetCount;
+ fLTGetTime += that.fLTGetTime;
+ fLTEntryCount += that.fLTEntryCount;
+ readCacheHits += that.readCacheHits;
+ readCacheBytes += that.readCacheBytes;
+ }
+
+ private long fLTGetCount = 0; // number of FederatedLookupTable get calls
+ private double fLTGetTime = 0; // FederatedLookupTable get time, converted to sec in collectStats
+ private long fLTEntryCount = 0; // number of FederatedLookupTable entries created
+ private long readCacheHits = 0; // number of FederatedReadCache hits
+ private long readCacheBytes = 0; // sum of getDataSize() over all FederatedReadCache hits
+ }
+
private CacheStatsCollection cacheStats = new
CacheStatsCollection();
private double jitCompileTime = 0;
private GCStatsCollection gcStats = new GCStatsCollection();
+ private MultiTenantStatsCollection mtStats = new
MultiTenantStatsCollection();
private HashMap<String, Pair<Long, Double>> heavyHitters = new
HashMap<>();
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index d324ee6..de0e1d8 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -20,6 +20,9 @@
package org.apache.sysds.runtime.controlprogram.federated;
import java.security.cert.CertificateException;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
@@ -28,7 +31,6 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
-import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
@@ -47,16 +49,22 @@ public class FederatedWorker {
private int _port;
private final FederatedLookupTable _flt;
+ private final FederatedReadCache _frc;
public FederatedWorker(int port) {
_flt = new FederatedLookupTable();
+ _frc = new FederatedReadCache();
_port = (port == -1) ? DMLConfig.DEFAULT_FEDERATED_PORT : port;
}
public void run() throws CertificateException, SSLException {
log.info("Setting up Federated Worker");
- EventLoopGroup bossGroup = new NioEventLoopGroup(1);
- EventLoopGroup workerGroup = new NioEventLoopGroup(1);
+ final int EVENT_LOOP_THREADS = Math.max(4,
Runtime.getRuntime().availableProcessors() * 4);
+ NioEventLoopGroup bossGroup = new NioEventLoopGroup(1);
+ ThreadPoolExecutor workerTPE = new ThreadPoolExecutor(1,
Integer.MAX_VALUE,
+ 10, TimeUnit.SECONDS, new
SynchronousQueue<Runnable>(true));
+ NioEventLoopGroup workerGroup = new
NioEventLoopGroup(EVENT_LOOP_THREADS, workerTPE);
+
ServerBootstrap b = new ServerBootstrap();
// TODO add ability to use real ssl files, not self signed
certificates.
SelfSignedCertificate cert = new SelfSignedCertificate();
@@ -77,7 +85,7 @@ public class FederatedWorker {
new
ObjectDecoder(Integer.MAX_VALUE,
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
cp.addLast("ObjectEncoder", new
ObjectEncoder());
-
cp.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_flt));
+
cp.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_flt, _frc));
}
}).option(ChannelOption.SO_BACKLOG,
128).childOption(ChannelOption.SO_KEEPALIVE, true);
log.info("Starting Federated Worker server at port: " +
_port);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 7461765..9f7de1f 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -76,6 +76,7 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
private static final Logger LOG =
Logger.getLogger(FederatedWorkerHandler.class);
private final FederatedLookupTable _flt;
+ private final FederatedReadCache _frc;
/**
* Create a Federated Worker Handler.
@@ -84,9 +85,11 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
* separate execution contexts at the federated sites too
*
* @param flt The Federated Lookup Table of the current Federated
Worker.
+ * @param frc read cache shared by all worker handlers
*/
- public FederatedWorkerHandler(FederatedLookupTable flt) {
+ public FederatedWorkerHandler(FederatedLookupTable flt,
FederatedReadCache frc) {
_flt = flt;
+ _frc = frc;
}
@Override
@@ -232,62 +235,78 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
return readData(filename, dt, request.getID(),
request.getTID(), ecm);
}
- private FederatedResponse readData(String filename, Types.DataType
dataType,
+ private FederatedResponse readData(String filename, DataType dataType,
long id, long tid, ExecutionContextMap ecm) {
MatrixCharacteristics mc = new MatrixCharacteristics();
mc.setBlocksize(ConfigurationManager.getBlocksize());
- CacheableData<?> cd;
- switch(dataType) {
- case MATRIX:
- cd = new MatrixObject(Types.ValueType.FP64,
filename);
- break;
- case FRAME:
- cd = new FrameObject(filename);
- break;
- default:
- throw new
FederatedWorkerHandlerException("Could not recognize datatype");
- }
- FileFormat fmt = null;
- boolean header = false;
- String delim = null;
- FileSystem fs = null;
- MetaDataAll mtd;
+ if(dataType != DataType.MATRIX && dataType != DataType.FRAME)
+ // early throwing of exception to avoid infinitely
waiting threads for data
+ throw new FederatedWorkerHandlerException("Could not
recognize datatype");
+
+ CacheableData<?> cd = _frc.get(filename);
+ if(cd == null) {
+ try {
+ switch(dataType) {
+ case MATRIX:
+ cd = new
MatrixObject(Types.ValueType.FP64, filename);
+ break;
+ case FRAME:
+ cd = new FrameObject(filename);
+ break;
+ default:
+ throw new
FederatedWorkerHandlerException("Could not recognize datatype");
+ }
- try {
- final String mtdName =
DataExpression.getMTDFileName(filename);
- Path path = new Path(mtdName);
- fs = IOUtilFunctions.getFileSystem(mtdName);
- try(BufferedReader br = new BufferedReader(new
InputStreamReader(fs.open(path)))) {
- mtd = new MetaDataAll(br);
- if(!mtd.mtdExists())
- throw new
FederatedWorkerHandlerException("Could not parse metadata file");
- mc.setRows(mtd.getDim1());
- mc.setCols(mtd.getDim2());
- mc.setNonZeros(mtd.getNnz());
- header = mtd.getHasHeader();
- cd = mtd.parseAndSetPrivacyConstraint(cd);
- fmt = mtd.getFileFormat();
- delim = mtd.getDelim();
+ FileFormat fmt = null;
+ boolean header = false;
+ String delim = null;
+ FileSystem fs = null;
+ MetaDataAll mtd;
+
+ try {
+ final String mtdName =
DataExpression.getMTDFileName(filename);
+ Path path = new Path(mtdName);
+ fs =
IOUtilFunctions.getFileSystem(mtdName);
+ try(BufferedReader br = new
BufferedReader(new InputStreamReader(fs.open(path)))) {
+ mtd = new MetaDataAll(br);
+ if(!mtd.mtdExists())
+ throw new
FederatedWorkerHandlerException("Could not parse metadata file");
+ mc.setRows(mtd.getDim1());
+ mc.setCols(mtd.getDim2());
+ mc.setNonZeros(mtd.getNnz());
+ header = mtd.getHasHeader();
+ cd =
mtd.parseAndSetPrivacyConstraint(cd);
+ fmt = mtd.getFileFormat();
+ delim = mtd.getDelim();
+ }
+ }
+ catch(DMLPrivacyException |
FederatedWorkerHandlerException ex) {
+ throw ex;
+ }
+ catch(Exception ex) {
+ String msg = "Exception of type " +
ex.getClass() + " thrown when processing READ request";
+ LOG.error(msg, ex);
+ throw new DMLRuntimeException(msg);
+ }
+ finally {
+ IOUtilFunctions.closeSilently(fs);
+ }
+
+ // put meta data object in symbol table, read
on first operation
+ cd.setMetaData(new MetaDataFormat(mc, fmt));
+ if(fmt == FileFormat.CSV)
+ cd.setFileFormatProperties(new
FileFormatPropertiesCSV(header, delim,
+
DataExpression.DEFAULT_DELIM_SPARSE));
+ cd.enableCleanup(false); // guard against
deletion
+
+ _frc.setData(filename, cd);
+ } catch(Exception ex) {
+ _frc.setInvalid(filename);
+ throw ex;
}
}
- catch(DMLPrivacyException | FederatedWorkerHandlerException ex)
{
- throw ex;
- }
- catch(Exception ex) {
- String msg = "Exception of type " + ex.getClass() + "
thrown when processing READ request";
- LOG.error(msg, ex);
- throw new DMLRuntimeException(msg);
- }
- finally {
- IOUtilFunctions.closeSilently(fs);
- }
- // put meta data object in symbol table, read on first operation
- cd.setMetaData(new MetaDataFormat(mc, fmt));
- if(fmt == FileFormat.CSV)
- cd.setFileFormatProperties(new
FileFormatPropertiesCSV(header, delim, DataExpression.DEFAULT_DELIM_SPARSE));
- cd.enableCleanup(false); // guard against deletion
ecm.get(tid).setVariable(String.valueOf(id), cd);
if(DMLScript.LINEAGE)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 48b31a1..8ea77c4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -80,7 +80,7 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()}, true);
if ( _fedOut.isForcedFederated() ){
- mo1.getFedMapping().execute(getTID(), fr1);
+ mo1.getFedMapping().execute(getTID(), true,
fr1);
setPartialOutput(mo1.getFedMapping(), mo1, mo2,
fr1.getID(), ec);
}
else {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 3183194..d329f44 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -132,7 +132,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
+ ", is a scalar and the output is set to be
federated. Scalars cannot be federated. ");
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new
long[]{in.getFedMapping().getID()}, true);
- map.execute(getTID(), fr1);
+ map.execute(getTID(), true, fr1);
MatrixObject out = ec.getMatrixObject(output);
deriveNewOutputFedMapping(in, out, fr1);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 12748a0..4382138 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -655,6 +655,7 @@ public class Statistics
sb.append(FederatedStatistics.displayFedIOExecStatistics());
sb.append(FederatedStatistics.displayFedLookupTableStats());
+
sb.append(FederatedStatistics.displayFedReadCacheStats());
sb.append(TransformStatistics.displayStatistics());
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedMultiTenantTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedMultiTenantTest.java
index 8852931..ca0d1ce 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedMultiTenantTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedMultiTenantTest.java
@@ -19,26 +19,22 @@
package org.apache.sysds.test.functions.federated.multitenant;
-import static org.junit.Assert.fail;
-
-import java.io.IOException;
-import java.nio.charset.Charset;
-import java.util.ArrayList;
+import java.lang.Math;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
-import org.apache.commons.io.IOUtils;
+import static org.junit.Assert.fail;
+
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.StringUtils;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
-import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.junit.After;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
@@ -47,7 +43,7 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedMultiTenantTest extends AutomatedTestBase {
+public class FederatedMultiTenantTest extends MultiTenantTestBase {
private final static String TEST_NAME = "FederatedMultiTenantTest";
private final static String TEST_DIR =
"functions/federated/multitenant/";
@@ -72,9 +68,6 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
});
}
- private ArrayList<Process> workerProcesses = new ArrayList<>();
- private ArrayList<Process> coordinatorProcesses = new ArrayList<>();
-
private enum OpType {
SUM,
PARFOR_SUM,
@@ -105,6 +98,7 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
}
@Test
+ @Ignore
public void testSumSharedWorkersSP() {
runMultiTenantSharedWorkerTest(OpType.SUM, 3, 9,
ExecMode.SPARK);
}
@@ -116,6 +110,7 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
}
@Test
+ @Ignore
public void testParforSumSharedWorkersCP() {
runMultiTenantSharedWorkerTest(OpType.PARFOR_SUM, 3, 9,
ExecMode.SINGLE_NODE);
}
@@ -149,19 +144,11 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
}
@Test
+ @Ignore
public void testWSigmoidSharedWorkersSP() {
runMultiTenantSharedWorkerTest(OpType.WSIGMOID, 3, 9,
ExecMode.SPARK);
}
- // ensure that the processes are killed - even if the test throws an
exception
- @After
- public void stopAllProcesses() {
- for(Process p : coordinatorProcesses)
- p.destroyForcibly();
- for(Process p : workerProcesses)
- p.destroyForcibly();
- }
-
private void runMultiTenantSameWorkerTest(OpType opType, int
numCoordinators, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -206,17 +193,20 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
// start the coordinator processes
String scriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "100", "-fedStats",
"100", "-nvargs",
- "in_X1=" + TestUtils.federatedAddress(workerPorts[0],
input("X1")),
- "in_X2=" + TestUtils.federatedAddress(workerPorts[1],
input("X2")),
- "in_X3=" + TestUtils.federatedAddress(workerPorts[2],
input("X3")),
- "in_X4=" + TestUtils.federatedAddress(workerPorts[3],
input("X4")),
+ "in_X1=" + TestUtils.federatedAddress(workerPorts[0],
""),
+ "in_X2=" + TestUtils.federatedAddress(workerPorts[1],
""),
+ "in_X3=" + TestUtils.federatedAddress(workerPorts[2],
""),
+ "in_X4=" + TestUtils.federatedAddress(workerPorts[3],
""),
+ "in=" + (baseDirectory+INPUT_DIR),
"rows=" + rows, "cols=" + cols, "testnum=" +
Integer.toString(opType.ordinal()),
"rP=" + Boolean.toString(rowPartitioned).toUpperCase()};
for(int counter = 0; counter < numCoordinators; counter++)
- coordinatorProcesses.add(startCoordinator(execMode,
scriptName,
- ArrayUtils.addAll(programArgs, "out_S=" +
output("S" + counter))));
+ startCoordinator(execMode, scriptName,
+ ArrayUtils.addAll(programArgs, "out_S=" +
output("S" + counter)));
- joinCoordinatorsAndVerify(opType, execMode);
+ // wait for the coordinator processes to end and verify the
results
+ String coordinatorOutput = waitForCoordinators();
+ verifyResults(opType, coordinatorOutput, execMode);
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
@@ -287,10 +277,13 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
"in_X4=" +
TestUtils.federatedAddress(workerPorts[workerIndexOffset + 3], input("X4")),
"rows=" + rows, "cols=" + cols, "testnum=" +
Integer.toString(opType.ordinal()),
"rP=" +
Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S" +
counter)};
- coordinatorProcesses.add(startCoordinator(execMode,
scriptName, programArgs));
+ startCoordinator(execMode, scriptName, programArgs);
}
- joinCoordinatorsAndVerify(opType, execMode);
+ // wait for the coordinator processes to end and verify the
results
+ String coordinatorOutput = waitForCoordinators();
+ System.out.println(coordinatorOutput);
+ verifyResults(opType, coordinatorOutput, execMode);
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
@@ -304,78 +297,8 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
- private int[] startFedWorkers(int numFedWorkers) {
- int[] ports = new int[numFedWorkers];
- for(int counter = 0; counter < numFedWorkers; counter++) {
- ports[counter] = getRandomAvailablePort();
- @SuppressWarnings("deprecation")
- Process tmpProcess =
startLocalFedWorker(ports[counter]);
- workerProcesses.add(tmpProcess);
- }
- return ports;
- }
-
- private Process startCoordinator(ExecMode execMode, String scriptPath,
String[] args) {
- String separator = System.getProperty("file.separator");
- String classpath = System.getProperty("java.class.path");
- String path = System.getProperty("java.home") + separator +
"bin" + separator + "java";
-
- String em = null;
- switch(execMode) {
- case SINGLE_NODE:
- em = "singlenode";
- break;
- case HYBRID:
- em = "hybrid";
- break;
- case SPARK:
- em = "spark";
- break;
- }
-
- ArrayList<String> argsList = new ArrayList<>();
- argsList.add("-f");
- argsList.add(scriptPath);
- argsList.add("-exec");
- argsList.add(em);
- argsList.addAll(Arrays.asList(args));
-
- ProcessBuilder processBuilder = new
ProcessBuilder(ArrayUtils.addAll(new String[]{
- path, "-cp", classpath, DMLScript.class.getName()},
argsList.toArray(new String[0])));
-
- Process process = null;
- try {
- process = processBuilder.start();
- } catch(IOException ioe) {
- ioe.printStackTrace();
- }
-
- return process;
- }
-
- private void joinCoordinatorsAndVerify(OpType opType, ExecMode
execMode) {
- // join the coordinator processes
- for(int counter = 0; counter < coordinatorProcesses.size();
counter++) {
- Process coord = coordinatorProcesses.get(counter);
-
- //wait for process, but obtain logs before to avoid
blocking
- String outputLog = null, errorLog = null;
- try {
- outputLog =
IOUtils.toString(coord.getInputStream(), Charset.defaultCharset());
- errorLog =
IOUtils.toString(coord.getErrorStream(), Charset.defaultCharset());
-
- coord.waitFor();
- }
- catch(Exception ex) {
- ex.printStackTrace();
- }
-
- // get and print the output
- System.out.println("Output of coordinator #" +
Integer.toString(counter + 1) + ":\n");
- System.out.println(outputLog);
- System.out.println(errorLog);
- Assert.assertTrue(checkForHeavyHitter(opType,
outputLog, execMode));
- }
+ private void verifyResults(OpType opType, String outputLog, ExecMode
execMode) {
+ Assert.assertTrue(checkForHeavyHitter(opType, outputLog,
execMode));
// compare the results via files
HashMap<CellIndex, Double> refResults =
readDMLMatrixFromOutputDir("S" + 0);
@@ -387,16 +310,21 @@ public class FederatedMultiTenantTest extends
AutomatedTestBase {
}
}
- private static boolean checkForHeavyHitter(OpType opType, String
outputLog, ExecMode execMode) {
+ private boolean checkForHeavyHitter(OpType opType, String outputLog,
ExecMode execMode) {
switch(opType) {
case SUM:
- return outputLog.contains("fed_uak+");
+ return checkForHeavyHitter(outputLog,
"fed_uak+");
case PARFOR_SUM:
- return outputLog.contains(execMode ==
ExecMode.SPARK ? "fed_rblk" : "fed_uak+");
+ return checkForHeavyHitter(outputLog, execMode
== ExecMode.SPARK ? "fed_rblk" : "fed_uak+");
case WSIGMOID:
- return outputLog.contains("fed_wsigmoid");
+ return checkForHeavyHitter(outputLog,
"fed_wsigmoid");
default:
return false;
}
}
+
+ private boolean checkForHeavyHitter(String outputLog, String hhString) {
+ int occurrences = StringUtils.countMatches(outputLog, hhString);
+ return (occurrences == coordinatorProcesses.size());
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReadCacheTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReadCacheTest.java
new file mode 100644
index 0000000..67e4c1a
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReadCacheTest.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.multitenant;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class FederatedReadCacheTest extends MultiTenantTestBase {
+ private final static String TEST_NAME = "FederatedReadCacheTest";
+
+ private final static String TEST_DIR =
"functions/federated/multitenant/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedReadCacheTest.class.getSimpleName() + "/";
+
+ private final static double TOLERANCE = 0;
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public double sparsity;
+ @Parameterized.Parameter(3)
+ public boolean rowPartitioned;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(
+ new Object[][] {
+ {100, 1000, 0.9, false},
+ // {1000, 100, 0.9, true},
+ // {100, 1000, 0.01, false},
+ // {1000, 100, 0.01, true},
+ });
+ }
+
+ private enum OpType {
+ PLUS_SCALAR,
+ MODIFIED_VAL,
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+ }
+
+ @Test
+ public void testPlusScalarCP() {
+ runReadCacheTest(OpType.PLUS_SCALAR, 3, ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void testPlusScalarSP() {
+ runReadCacheTest(OpType.PLUS_SCALAR, 3, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testModifiedValCP() {
+ //TODO with 4 runs sporadically into non-terminating state
+ runReadCacheTest(OpType.MODIFIED_VAL, 3, ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void testModifiedValSP() {
+ runReadCacheTest(OpType.MODIFIED_VAL, 4, ExecMode.SPARK);
+ }
+
+ private void runReadCacheTest(OpType opType, int numCoordinators,
ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int r = rows;
+ int c = cols / 4;
+ if(rowPartitioned) {
+ r = rows / 4;
+ c = cols;
+ }
+
+ double[][] X1 = getRandomMatrix(r, c, 0, 3, sparsity, 3);
+ double[][] X2 = getRandomMatrix(r, c, 0, 3, sparsity, 7);
+ double[][] X3 = getRandomMatrix(r, c, 0, 3, sparsity, 8);
+ double[][] X4 = getRandomMatrix(r, c, 0, 3, sparsity, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
+
+ int[] workerPorts = startFedWorkers(4);
+
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // start the coordinator processes
+ String scriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-fedStats",
"100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(workerPorts[0],
""),
+ "in_X2=" + TestUtils.federatedAddress(workerPorts[1],
""),
+ "in_X3=" + TestUtils.federatedAddress(workerPorts[2],
""),
+ "in_X4=" + TestUtils.federatedAddress(workerPorts[3],
""),
+ "in=" + (baseDirectory+INPUT_DIR),
+ "rows=" + rows, "cols=" + cols, "testnum=" +
Integer.toString(opType.ordinal()),
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase()};
+ for(int counter = 0; counter < numCoordinators; counter++)
+ startCoordinator(execMode, scriptName,
+ ArrayUtils.addAll(programArgs, "out_S=" +
output("S" + counter)));
+
+ // wait for the coordinator processes to end and verify the
results
+ String coordinatorOutput = waitForCoordinators();
+ System.out.println(coordinatorOutput);
+ verifyResults(opType, coordinatorOutput, execMode);
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(workerProcesses.toArray(new
Process[0]));
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+
+ private void verifyResults(OpType opType, String outputLog, ExecMode
execMode) {
+ Assert.assertTrue(checkForHeavyHitter(opType, outputLog,
execMode));
+ // verify that the matrix object has been taken from cache
+ Assert.assertTrue(outputLog.contains("Fed ReadCache (Hits,
Bytes):\t"
+ + Integer.toString((coordinatorProcesses.size()-1) *
workerProcesses.size()) + "/"));
+
+ // compare the results via files
+ HashMap<CellIndex, Double> refResults =
readDMLMatrixFromOutputDir("S" + 0);
+ Assert.assertFalse("The result of the first coordinator, which
is taken as reference, is empty.",
+ refResults.isEmpty());
+ for(int counter = 1; counter < coordinatorProcesses.size();
counter++) {
+ HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir("S" + counter);
+ TestUtils.compareMatrices(fedResults, refResults,
TOLERANCE, "Fed" + counter, "FedRef");
+ }
+ }
+
+ private boolean checkForHeavyHitter(OpType opType, String outputLog,
ExecMode execMode) {
+ switch(opType) {
+ case PLUS_SCALAR:
+ return checkForHeavyHitter(outputLog, "fed_+");
+ case MODIFIED_VAL:
+ return checkForHeavyHitter(outputLog, "fed_*")
&& checkForHeavyHitter(outputLog, "fed_+");
+ }
+ return false;
+ }
+
+ private boolean checkForHeavyHitter(String outputLog, String hhString) {
+ int occurrences = StringUtils.countMatches(outputLog, hhString);
+ return (occurrences == coordinatorProcesses.size());
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
new file mode 100644
index 0000000..90f50d4
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.multitenant;
+
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.junit.After;
+
+public abstract class MultiTenantTestBase extends AutomatedTestBase {
+ protected ArrayList<Process> workerProcesses = new ArrayList<>();
+ protected ArrayList<Process> coordinatorProcesses = new ArrayList<>();
+
+ @Override
+ public abstract void setUp();
+
+ // ensure that the processes are killed - even if the test throws an
exception
+ @After
+ public void stopAllProcesses() {
+ for(Process p : coordinatorProcesses)
+ p.destroyForcibly();
+ for(Process p : workerProcesses)
+ p.destroyForcibly();
+ }
+
+ /**
+ * Start numFedWorkers federated worker processes on available ports
and add
+ * them to the workerProcesses
+ *
+ * @param numFedWorkers the number of federated workers to start
+ * @return int[] the ports of the created federated workers
+ */
+ protected int[] startFedWorkers(int numFedWorkers) {
+ int[] ports = new int[numFedWorkers];
+ for(int counter = 0; counter < numFedWorkers; counter++) {
+ ports[counter] = getRandomAvailablePort();
+ @SuppressWarnings("deprecation")
+ Process tmpProcess =
startLocalFedWorker(ports[counter]);
+ workerProcesses.add(tmpProcess);
+ }
+ return ports;
+ }
+
+ /**
+ * Start a coordinator process running the specified script with given
arguments
+ * and add it to the coordinatorProcesses
+ *
+ * @param execMode the execution mode of the coordinator
+ * @param scriptPath the path to the dml script
+ * @param args the program arguments for running the dml script
+ */
+ protected void startCoordinator(ExecMode execMode, String scriptPath,
String[] args) {
+ String separator = System.getProperty("file.separator");
+ String classpath = System.getProperty("java.class.path");
+ String path = System.getProperty("java.home") + separator +
"bin" + separator + "java";
+
+ String em = null;
+ switch(execMode) {
+ case SINGLE_NODE:
+ em = "singlenode";
+ break;
+ case HYBRID:
+ em = "hybrid";
+ break;
+ case SPARK:
+ em = "spark";
+ break;
+ }
+
+ ArrayList<String> argsList = new ArrayList<>();
+ argsList.add("-f");
+ argsList.add(scriptPath);
+ argsList.add("-exec");
+ argsList.add(em);
+ argsList.addAll(Arrays.asList(args));
+
+ // create the processBuilder and redirect the stderr to its
stdout
+ ProcessBuilder processBuilder = new
ProcessBuilder(ArrayUtils.addAll(new String[]{
+ path, "-cp", classpath, DMLScript.class.getName()},
argsList.toArray(new String[0])));
+
+ Process process = null;
+ try {
+ process = processBuilder.start();
+ } catch(IOException ioe) {
+ ioe.printStackTrace();
+ fail("Can't start the coordinator process.");
+ }
+ coordinatorProcesses.add(process);
+ }
+
+ /**
+ * Wait for all processes of coordinatorProcesses to terminate and
collect
+ * their output
+ *
+ * @return String the collected output of the coordinator processes
+ */
+ protected String waitForCoordinators() {
+ // wait for the coordinator processes to finish and collect
their output
+ StringBuilder outputLog = new StringBuilder();
+ for(int counter = 0; counter < coordinatorProcesses.size();
counter++) {
+ Process coord = coordinatorProcesses.get(counter);
+ try {
+ outputLog.append("\n");
+ outputLog.append("Output of coordinator #" +
Integer.toString(counter + 1) + ":\n");
+
outputLog.append(IOUtils.toString(coord.getInputStream(),
Charset.defaultCharset()));
+
outputLog.append(IOUtils.toString(coord.getErrorStream(),
Charset.defaultCharset()));
+
+ coord.waitFor();
+ } catch(Exception ex) {
+ fail(ex.getClass().getSimpleName() + " thrown
while collecting log output of coordinator #"
+ + Integer.toString(counter+1) + ".\n");
+ ex.printStackTrace();
+ }
+ }
+ return outputLog.toString();
+ }
+}
diff --git a/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
b/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
index c43a614..44b23e8 100644
--- a/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
+++ b/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
@@ -40,4 +40,4 @@ else if (test == "LESS_EQUALS")
C = as.matrix(C)
# print("this is C "+toString(C))
-write(C, $C);
\ No newline at end of file
+write(C, $C);
diff --git
a/src/test/scripts/functions/binary/matrix_full_cellwise/FullMatrixVectorRowCellwiseOperation_Addition.R
b/src/test/scripts/functions/binary/matrix_full_cellwise/FullMatrixVectorRowCellwiseOperation_Addition.R
index a582836..49b5436 100644
---
a/src/test/scripts/functions/binary/matrix_full_cellwise/FullMatrixVectorRowCellwiseOperation_Addition.R
+++
b/src/test/scripts/functions/binary/matrix_full_cellwise/FullMatrixVectorRowCellwiseOperation_Addition.R
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -35,5 +35,3 @@ B <- as.matrix(B1);
C <- t(t(A)+as.vector(t(B)))
writeMM(as(C, "CsparseMatrix"), paste(args[2], "C", sep=""));
-
-
diff --git
a/src/test/scripts/functions/federated/multitenant/FederatedMultiTenantTest.dml
b/src/test/scripts/functions/federated/multitenant/FederatedMultiTenantTest.dml
index e0bef2e..963d5ca 100644
---
a/src/test/scripts/functions/federated/multitenant/FederatedMultiTenantTest.dml
+++
b/src/test/scripts/functions/federated/multitenant/FederatedMultiTenantTest.dml
@@ -19,12 +19,17 @@
#
#-------------------------------------------------------------
+in_X1 = $in_X1 + $in + "/X1";
+in_X2 = $in_X2 + $in + "/X2";
+in_X3 = $in_X3 + $in + "/X3";
+in_X4 = $in_X4 + $in + "/X4";
+
if ($rP) {
- X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ X = federated(addresses=list(in_X1, in_X2, in_X3, in_X4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
} else {
- X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ X = federated(addresses=list(in_X1, in_X2, in_X3, in_X4),
ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
}
diff --git
a/src/test/scripts/functions/federated/multitenant/FederatedMultiTenantTest.dml
b/src/test/scripts/functions/federated/multitenant/FederatedReadCacheTest.dml
similarity index 54%
copy from
src/test/scripts/functions/federated/multitenant/FederatedMultiTenantTest.dml
copy to
src/test/scripts/functions/federated/multitenant/FederatedReadCacheTest.dml
index e0bef2e..2814dbc 100644
---
a/src/test/scripts/functions/federated/multitenant/FederatedMultiTenantTest.dml
+++
b/src/test/scripts/functions/federated/multitenant/FederatedReadCacheTest.dml
@@ -19,45 +19,33 @@
#
#-------------------------------------------------------------
+
+in_X1 = $in_X1 + $in + "/X1";
+in_X2 = $in_X2 + $in + "/X2";
+in_X3 = $in_X3 + $in + "/X3";
+in_X4 = $in_X4 + $in + "/X4";
+
if ($rP) {
- X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ X = federated(addresses=list(in_X1, in_X2, in_X3, in_X4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
- list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4,
0), list($rows, $cols)));
} else {
- X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
- list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ X = federated(addresses=list(in_X1, in_X2, in_X3, in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
}
testnum = $testnum;
-if(testnum == 0) { # SUM
+if(testnum == 0) { # PLUS_SCALAR
+ X = X + 1;
+ while(FALSE) { }
S = as.matrix(sum(X));
}
-else if(testnum == 1) { # PARFOR_SUM
- numiter = 5;
- Z = matrix(0, rows=numiter, cols=1);
- parfor( i in 1:numiter ) {
- while(FALSE) { }
- Y = X + i;
- while(FALSE) { }
- Z[i, 1] = sum(Y);
- }
- S = as.matrix(0);
- for( i in 1:numiter ) {
- while(FALSE) { }
- S = S + Z[i, 1];
- }
-}
-else if(testnum == 2) { # WSIGMOID
- N = nrow(X);
- M = ncol(X);
-
- U = rand(rows=N, cols=15, seed=123);
- V = rand(rows=M, cols=15, seed=456);
-
- UV = U %*% t(V);
- S = X * log(1 / (1 + exp(-UV)));
+else if(testnum == 1) { # MODIFIED_VAL
+ X[nrow(X)/2, ncol(X)/2] = (X[nrow(X)/2, ncol(X)/2] + 1) * 10;
+ while(FALSE) { }
+ S = as.matrix(sum(X));
}
write(S, $out_S);