This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 36e84d5  [SYSTEMDS-3185] Transfer lineage traces to federated workers
36e84d5 is described below

commit 36e84d5fc020a150dc7d9952b24c652b453e4b13
Author: ywcb00 <[email protected]>
AuthorDate: Fri Oct 22 18:02:04 2021 +0200

    [SYSTEMDS-3185] Transfer lineage traces to federated workers
    
    This patch introduces the mechanics for transferring the lineage trace
    of a data object to the federated worker.
    For now, we are including only the lineage trace of matrices which come
    from the datagen operation (e.g. rand()), as they have the respective
    lineage item set in their CacheableData objects.
    
    Closes #1544
---
 .../runtime/controlprogram/LocalVariableMap.java   |   7 +-
 .../controlprogram/federated/FederatedData.java    |   3 +-
 .../controlprogram/federated/FederatedRequest.java |  12 +-
 .../federated/FederatedStatistics.java             |  93 ++++++++
 .../federated/FederatedWorkerHandler.java          |  17 +-
 .../controlprogram/federated/FederationMap.java    |  20 +-
 .../sysds/runtime/instructions/cp/CPOperand.java   |   8 +-
 .../org/apache/sysds/runtime/lineage/Lineage.java  |  13 +-
 .../runtime/lineage/LineageCacheStatistics.java    |  32 ++-
 .../apache/sysds/runtime/lineage/LineageItem.java  |  12 +-
 .../FederatedLineageTraceReuseTest.java            | 252 +++++++++++++++++++++
 .../multitenant/FederatedLineageTraceReuseTest.dml |  69 ++++++
 12 files changed, 514 insertions(+), 24 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
index bac6759..3f92c9e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram;
 
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
@@ -44,19 +45,19 @@ public class LocalVariableMap implements Cloneable
        private static final IDSequence _seq = new IDSequence();
        
        //variable map data and id
-       private final HashMap<String, Data> localMap;
+       private final ConcurrentHashMap<String, Data> localMap;
        private final long localID;
        
        //optional set of registered outputs
        private HashSet<String> outputs = null;
        
        public LocalVariableMap() {
-               localMap = new HashMap<>();
+               localMap = new ConcurrentHashMap<>();
                localID = _seq.getNextID();
        }
        
        public LocalVariableMap(LocalVariableMap vars) {
-               localMap = new HashMap<>(vars.localMap);
+               localMap = new ConcurrentHashMap<>(vars.localMap);
                localID = _seq.getNextID();
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 07079b7..3f4173d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -33,9 +33,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
-
-import org.apache.sysds.runtime.DMLRuntimeException;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.meta.MetaData;
 
 import io.netty.bootstrap.Bootstrap;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 6e9b388..1566a65 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -57,6 +57,7 @@ public class FederatedRequest implements Serializable {
        private boolean _checkPrivacy;
        private List<Long> _checksums;
        private long _pid;
+       private String _lineageTrace; // the serialized lineage trace of a put 
object
 
        public FederatedRequest(RequestType method) {
                this(method, FederationUtils.getNextFedDataID(), new 
ArrayList<>());
@@ -70,6 +71,11 @@ public class FederatedRequest implements Serializable {
                this(method, id, Arrays.asList(data));
        }
 
+       public FederatedRequest(RequestType method, String linTrace, long id, 
Object ... data) {
+               this(method, id, Arrays.asList(data));
+               _lineageTrace = linTrace;
+       }
+
        public FederatedRequest(RequestType method, long id, List<Object> data) 
{
                if(DMLScript.STATISTICS)
                        FederatedStatistics.incFederated(method, data);
@@ -78,8 +84,6 @@ public class FederatedRequest implements Serializable {
                _data = data;
                _pid = Long.valueOf(IDHandler.obtainProcessID());
                setCheckPrivacy();
-               if (DMLScript.LINEAGE && method == RequestType.PUT_VAR)
-                       setChecksum();
        }
 
        public RequestType getType() {
@@ -185,6 +189,10 @@ public class FederatedRequest implements Serializable {
                }
        }
 
+       public String getLineageTrace() {
+               return _lineageTrace;
+       }
+
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder("FederatedRequest[");
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index 5597d18..9a9cfa7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -44,11 +44,13 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.CacheStatsCollection;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.LineageCacheStatsCollection;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.MultiTenantStatsCollection;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -77,6 +79,8 @@ public class FederatedStatistics {
        private static final LongAdder fedLookupTableEntryCount = new 
LongAdder();
        private static final LongAdder fedReuseReadHitCount = new LongAdder();
        private static final LongAdder fedReuseReadBytesCount = new LongAdder();
+       private static final LongAdder fedPutLineageCount = new LongAdder();
+       private static final LongAdder fedPutLineageItems = new LongAdder();
 
        public static synchronized void incFederated(RequestType rqt, 
List<Object> data){
                switch (rqt) {
@@ -141,6 +145,8 @@ public class FederatedStatistics {
                fedLookupTableEntryCount.reset();
                fedReuseReadHitCount.reset();
                fedReuseReadBytesCount.reset();
+               fedPutLineageCount.reset();
+               fedPutLineageItems.reset();
        }
 
        public static String displayFedIOExecStatistics() {
@@ -193,6 +199,7 @@ public class FederatedStatistics {
                sb.append(displayCacheStats(fedStats.cacheStats));
                sb.append(String.format("Total JIT compile time:\t\t%.3f 
sec.\n", fedStats.jitCompileTime));
                sb.append(displayGCStats(fedStats.gcStats));
+               sb.append(displayLinCacheStats(fedStats.linCacheStats));
                sb.append(displayMultiTenantStats(fedStats.mtStats));
                sb.append(displayHeavyHitters(fedStats.heavyHitters, 
numHeavyHitters));
                return sb.toString();
@@ -216,10 +223,22 @@ public class FederatedStatistics {
                return sb.toString();
        }
 
+       private static String displayLinCacheStats(LineageCacheStatsCollection 
lcsc) {
+               StringBuilder sb = new StringBuilder();
+               sb.append(String.format("LinCache hits 
(Mem/FS/Del):\t%d/%d/%d.\n",
+                       lcsc.numHitsMem, lcsc.numHitsFS, lcsc.numHitsDel));
+               sb.append(String.format("LinCache MultiLvl 
(Ins/SB/Fn):\t%d/%d/%d.\n",
+                       lcsc.numHitsInst, lcsc.numHitsSB, lcsc.numHitsFunc));
+               sb.append(String.format("LinCache writes 
(Mem/FS/Del):\t%d/%d/%d.\n",
+                       lcsc.numWritesMem, lcsc.numWritesFS, lcsc.numMemDel));
+               return sb.toString();
+       }
+
        private static String 
displayMultiTenantStats(MultiTenantStatsCollection mtsc) {
                StringBuilder sb = new StringBuilder();
                sb.append(displayFedLookupTableStats(mtsc.fLTGetCount, 
mtsc.fLTEntryCount, mtsc.fLTGetTime));
                sb.append(displayFedReuseReadStats(mtsc.reuseReadHits, 
mtsc.reuseReadBytes));
+               sb.append(displayFedPutLineageStats(mtsc.putLineageCount, 
mtsc.putLineageItems));
                return sb.toString();
        }
 
@@ -349,6 +368,14 @@ public class FederatedStatistics {
                return fedReuseReadBytesCount.longValue();
        }
 
+       public static long getFedPutLineageCount() {
+               return fedPutLineageCount.longValue();
+       }
+
+       public static long getFedPutLineageItems() {
+               return fedPutLineageItems.longValue();
+       }
+
        public static void incFedLookupTableGetCount() {
                fedLookupTableGetCount.increment();
        }
@@ -373,6 +400,11 @@ public class FederatedStatistics {
                fedReuseReadBytesCount.add(cb.getInMemorySize());
        }
 
+       public static void aggFedPutLineage(String serializedLineage) {
+               fedPutLineageCount.increment();
+               fedPutLineageItems.add(serializedLineage.lines().count());
+       }
+
        public static String displayFedLookupTableStats() {
                return 
displayFedLookupTableStats(fedLookupTableGetCount.longValue(),
                        fedLookupTableEntryCount.longValue(), 
fedLookupTableGetTime.doubleValue() / 1000000000);
@@ -403,6 +435,20 @@ public class FederatedStatistics {
                return "";
        }
 
+       public static String displayFedPutLineageStats() {
+               return displayFedPutLineageStats(fedPutLineageCount.longValue(),
+                       fedPutLineageItems.longValue());
+       }
+
+       public static String displayFedPutLineageStats(long plCount, long 
plItems) {
+               if(plCount > 0) {
+                       StringBuilder sb = new StringBuilder();
+                       sb.append("Fed PutLineage (Count, Items):\t" +
+                               plCount + "/" + plItems + ".\n");
+                       return sb.toString();
+               }
+               return "";
+       }
 
        private static class FedStatsCollectFunction extends FederatedUDF {
                private static final long serialVersionUID = 1L;
@@ -431,6 +477,7 @@ public class FederatedStatistics {
                        cacheStats.collectStats();
                        jitCompileTime = 
((double)Statistics.getJITCompileTime()) / 1000; // in sec
                        gcStats.collectStats();
+                       linCacheStats.collectStats();
                        mtStats.collectStats();
                        heavyHitters = Statistics.getHeavyHittersHashMap();
                }
@@ -439,6 +486,7 @@ public class FederatedStatistics {
                        cacheStats.aggregate(that.cacheStats);
                        jitCompileTime += that.jitCompileTime;
                        gcStats.aggregate(that.gcStats);
+                       linCacheStats.aggregate(that.linCacheStats);
                        mtStats.aggregate(that.mtStats);
                        that.heavyHitters.forEach(
                                (key, value) -> heavyHitters.merge(key, value, 
(v1, v2) ->
@@ -513,6 +561,44 @@ public class FederatedStatistics {
                        private double gcTime = 0;
                }
 
+               protected static class LineageCacheStatsCollection implements 
Serializable {
+                       private static final long serialVersionUID = 1L;
+
+                       private void collectStats() {
+                               numHitsMem = 
LineageCacheStatistics.getMemHits();
+                               numHitsFS = LineageCacheStatistics.getFSHits();
+                               numHitsDel = 
LineageCacheStatistics.getDelHits();
+                               numHitsInst = 
LineageCacheStatistics.getInstHits();
+                               numHitsSB = LineageCacheStatistics.getSBHits();
+                               numHitsFunc = 
LineageCacheStatistics.getFuncHits();
+                               numWritesMem = 
LineageCacheStatistics.getMemWrites();
+                               numWritesFS = 
LineageCacheStatistics.getFSWrites();
+                               numMemDel = 
LineageCacheStatistics.getMemDeletes();
+                       }
+
+                       private void aggregate(LineageCacheStatsCollection 
that) {
+                               numHitsMem += that.numHitsMem;
+                               numHitsFS += that.numHitsFS;
+                               numHitsDel += that.numHitsDel;
+                               numHitsInst += that.numHitsInst;
+                               numHitsSB += that.numHitsSB;
+                               numHitsFunc += that.numHitsFunc;
+                               numWritesMem += that.numWritesMem;
+                               numWritesFS += that.numWritesFS;
+                               numMemDel += that.numMemDel;
+                       }
+
+                       private long numHitsMem = 0;
+                       private long numHitsFS = 0;
+                       private long numHitsDel = 0;
+                       private long numHitsInst = 0;
+                       private long numHitsSB = 0;
+                       private long numHitsFunc = 0;
+                       private long numWritesMem = 0;
+                       private long numWritesFS = 0;
+                       private long numMemDel = 0;
+               }
+
                protected static class MultiTenantStatsCollection implements 
Serializable {
                        private static final long serialVersionUID = 1L;
 
@@ -522,6 +608,8 @@ public class FederatedStatistics {
                                fLTEntryCount = getFedLookupTableEntryCount();
                                reuseReadHits = getFedReuseReadHitCount();
                                reuseReadBytes = getFedReuseReadBytesCount();
+                               putLineageCount = getFedPutLineageCount();
+                               putLineageItems = getFedPutLineageItems();
                        }
 
                        private void aggregate(MultiTenantStatsCollection that) 
{
@@ -530,6 +618,8 @@ public class FederatedStatistics {
                                fLTEntryCount += that.fLTEntryCount;
                                reuseReadHits += that.reuseReadHits;
                                reuseReadBytes += that.reuseReadBytes;
+                               putLineageCount += that.putLineageCount;
+                               putLineageItems += that.putLineageItems;
                        }
 
                        private long fLTGetCount = 0;
@@ -537,11 +627,14 @@ public class FederatedStatistics {
                        private long fLTEntryCount = 0;
                        private long reuseReadHits = 0;
                        private long reuseReadBytes = 0;
+                       private long putLineageCount = 0;
+                       private long putLineageItems = 0;
                }
 
                private CacheStatsCollection cacheStats = new 
CacheStatsCollection();
                private double jitCompileTime = 0;
                private GCStatsCollection gcStats = new GCStatsCollection();
+               private LineageCacheStatsCollection linCacheStats = new 
LineageCacheStatsCollection();
                private MultiTenantStatsCollection mtStats = new 
MultiTenantStatsCollection();
                private HashMap<String, Pair<Long, Double>> heavyHitters = new 
HashMap<>();
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 9a758ea..1cbfb5a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -51,11 +51,13 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.Respo
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.Instruction.IType;
 import org.apache.sysds.runtime.instructions.InstructionParser;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.Lineage;
 import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
@@ -380,9 +382,18 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
                // set variable and construct empty response
                ec.setVariable(varName, data);
-               if(DMLScript.LINEAGE && request.getNumParams()==1)
-                       // don't trace if the data contains only metadata
-                       ec.getLineage().set(varName, new 
LineageItem(String.valueOf(request.getChecksum(0))));
+
+               if(DMLScript.LINEAGE) {
+                       if(request.getParam(0) instanceof CacheBlock && 
request.getLineageTrace() != null) {
+                               ec.getLineage().set(varName, 
Lineage.deserializeSingleTrace(request.getLineageTrace()));
+                               if(DMLScript.STATISTICS)
+                                       
FederatedStatistics.aggFedPutLineage(request.getLineageTrace());
+                       }
+                       else if(request.getParam(0) instanceof ScalarObject)
+                               ec.getLineage().set(varName, new 
LineageItem(CPOperand.getLineageLiteral((ScalarObject)request.getParam(0), 
true)));
+                       else if(request.getNumParams()==1) // don't trace if 
the data contains only metadata
+                               ec.getLineage().set(varName, new 
LineageItem(String.valueOf(request.getChecksum(0))));
+               }
 
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 7e3c101..1676078 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -32,12 +32,14 @@ import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.lineage.Lineage;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -196,11 +198,15 @@ public class FederationMap {
                // prepare single request for all federated data
                long id = FederationUtils.getNextFedDataID();
                CacheBlock cb = data.acquireReadAndRelease();
+
+               final String lineageTrace = (DMLScript.LINEAGE && 
data.getCacheLineage() != null) ?
+                       Lineage.serializeSingleTrace(data.getCacheLineage()) : 
null;
+
                // create new fed mapping for broadcast (a potential overwrite
                // is fine, because with broadcast all data on all workers)
                data.setFedMapping(copyWithNewIDAndRange(
                        cb.getNumRows(), cb.getNumColumns(), id, 
FType.BROADCAST));
-               return new FederatedRequest(RequestType.PUT_VAR, id, cb);
+               return new FederatedRequest(RequestType.PUT_VAR, lineageTrace, 
id, cb);
        }
 
        public FederatedRequest broadcast(ScalarObject scalar) {
@@ -210,7 +216,7 @@ public class FederationMap {
        }
 
        /**
-        * Creates separate slices of an input data object according to the 
index ranges of federated data. Theses slices
+        * Creates separate slices of an input data object according to the 
index ranges of federated data. These slices
         * are then wrapped in separate federated requests for broadcasting.
         *
         * @param data       input data object (matrix, tensor, frame)
@@ -253,8 +259,11 @@ public class FederationMap {
                }
                // multi-threaded block slicing and federation request creation
                else {
+                       final String lineageTrace = (DMLScript.LINEAGE && 
data.getCacheLineage() != null) ?
+                               
Lineage.serializeSingleTrace(data.getCacheLineage()) : null;
+
                        Arrays.parallelSetAll(ret,
-                               i -> new FederatedRequest(RequestType.PUT_VAR, 
id,
+                               i -> new FederatedRequest(RequestType.PUT_VAR, 
lineageTrace, id,
                                cb.slice(ix[i][0], ix[i][1], ix[i][2], 
ix[i][3], new MatrixBlock())));
                }
                return ret;
@@ -268,10 +277,13 @@ public class FederationMap {
                long id = FederationUtils.getNextFedDataID();
                CacheBlock cb = data.acquireReadAndRelease();
 
+               final String lineageTrace = (DMLScript.LINEAGE && 
data.getCacheLineage() != null) ?
+                       Lineage.serializeSingleTrace(data.getCacheLineage()) : 
null;
+
                // multi-threaded block slicing and federation request creation
                FederatedRequest[] ret = new FederatedRequest[ix.length];
                Arrays.setAll(ret,
-                       i -> new FederatedRequest(RequestType.PUT_VAR, id,
+                       i -> new FederatedRequest(RequestType.PUT_VAR, 
lineageTrace, id,
                                cb.slice(ix[i][0], ix[i][1], ix[i][2], 
ix[i][3], isFrame ? new FrameBlock() : new MatrixBlock())));
                return ret;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
index 30ded87..6570259 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
@@ -181,8 +181,12 @@ public class CPOperand
        }
        
        public String getLineageLiteral(ScalarObject so) {
+               return getLineageLiteral(so, isLiteral());
+       }
+
+       public static String getLineageLiteral(ScalarObject so, boolean 
isLiteral) {
                return InstructionUtils.concatOperandParts(
-                       so.toString(), getDataType().name(),
-                       getValueType().name(), String.valueOf(isLiteral()));
+                       so.toString(), so.getDataType().name(),
+                       so.getValueType().name(), String.valueOf(isLiteral));
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java 
b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
index af7b6e0..f7387db 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
@@ -24,8 +24,8 @@ import 
org.apache.sysds.runtime.controlprogram.ForProgramBlock;
 import org.apache.sysds.runtime.controlprogram.ProgramBlock;
 import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 
 import java.util.ArrayList;
@@ -169,6 +169,17 @@ public class Lineage {
                }
                return ret;
        }
+
+       public static String serializeSingleTrace(LineageItem linItem) {
+               if(linItem == null)
+                       throw new DMLRuntimeException("Cannot serialize null 
lineage object.");
+
+               return explain(linItem);
+       }
+
+       public static LineageItem deserializeSingleTrace(String serialLinTrace) 
{
+               return LineageParser.parseLineageTrace(serialLinTrace);
+       }
        
        public static void resetInternalState() {
                LineageItem.resetIDSequence();
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index 3382365..fc34f73 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -71,16 +71,28 @@ public class LineageCacheStatistics {
                _numHitsMem.increment();
        }
 
+       public static long getMemHits() {
+               return _numHitsMem.longValue();
+       }
+
        public static void incrementFSHits() {
                // Number of times found in local FS.
                _numHitsFS.increment();
        }
 
+       public static long getFSHits() {
+               return _numHitsFS.longValue();
+       }
+
        public static void incrementDelHits() {
                // Number of times entry is removed from cache but sought again 
later.
                _numHitsDel.increment();
        }
 
+       public static long getDelHits() {
+               return _numHitsDel.longValue();
+       }
+
        public static void incrementInstHits() {
                // Number of times single instruction results are reused (full 
and partial).
                _numHitsInst.increment();
@@ -95,16 +107,28 @@ public class LineageCacheStatistics {
                _numHitsSB.increment();
        }
 
+       public static long getSBHits() {
+               return _numHitsSB.longValue();
+       }
+
        public static void incrementFuncHits() {
                // Number of times function results are reused.
                _numHitsFunc.increment();
        }
 
+       public static long getFuncHits() {
+               return _numHitsFunc.longValue();
+       }
+
        public static void incrementMemWrites() {
                // Number of times written in cache.
                _numWritesMem.increment();
        }
 
+       public static long getMemWrites() {
+               return _numWritesMem.longValue();
+       }
+
        public static void incrementPRewrites() {
                // Number of partial rewrites.
                _numRewrites.increment();
@@ -114,12 +138,16 @@ public class LineageCacheStatistics {
                // Number of times written in local FS.
                _numWritesFS.increment();
        }
-       
+
+       public static long getFSWrites() {
+               return _numWritesFS.longValue();
+       }
+
        public static void incrementMemDeletes() {
                // Number of deletions from cache (including spilling).
                _numMemDel.increment();
        }
-       
+
        public static long getMemDeletes() {
                return _numMemDel.longValue();
        }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index 14b9894..31284f7 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.lineage;
 
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Stack;
@@ -38,9 +39,8 @@ public class LineageItem {
        private LineageItem _dedupPatch;
        private long _distLeaf2Node;
        private final BooleanArray32 _specialValueBits;  // TODO: Move this to 
a new subclass
-       // init visited to true to ensure visited items are
-       // not hidden when used as inputs to new items
-       private boolean _visited = true;
+       // map from thread id to visited flag to allow concurrent checks 
through the lineage trace
+       private Map<Long, Boolean> _visited = new ConcurrentHashMap<>();
        
        public enum LineageItemType {Literal, Creation, Instruction, Dedup}
        public static final String dedupItemOpcode = "dedup";
@@ -133,7 +133,9 @@ public class LineageItem {
        }
 
        public boolean isVisited() {
-               return _visited;
+               // default value (e.g., not set value) is true to ensure 
visited items are
+               // not hidden when used as inputs to new items
+               return _visited.getOrDefault(Thread.currentThread().getId(), 
true);
        }
        
        public void setVisited() {
@@ -141,7 +143,7 @@ public class LineageItem {
        }
        
        public void setVisited(boolean flag) {
-               _visited = flag;
+               _visited.put(Thread.currentThread().getId(), flag);
        }
        
        public void setSpecialValueBit(int pos, boolean flag) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
new file mode 100644
index 0000000..717f53b
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.multitenant;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
+       private final static String TEST_NAME = "FederatedLineageTraceReuseTest";
+
+       private final static String TEST_DIR = "functions/federated/multitenant/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedLineageTraceReuseTest.class.getSimpleName() + "/";
+
+       private final static double TOLERANCE = 0;
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public double sparsity;
+       @Parameterized.Parameter(3)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(
+                       new Object[][] {
+                               // {100, 200, 0.9, false},
+                               {200, 100, 0.9, true},
+                               // {100, 1000, 0.01, false},
+                               // {1000, 100, 0.01, true},
+               });
+       }
+
+       private enum OpType {
+               EW_PLUS,
+               MM,
+               PARFOR_ADD,
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+       }
+
+       @Test
+       @Ignore
+       public void testElementWisePlusCP() {
+               runLineageTraceReuseTest(OpType.EW_PLUS, 4, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void testElementWisePlusSP() {
+               runLineageTraceReuseTest(OpType.EW_PLUS, 4, ExecMode.SPARK);
+       }
+
+       @Test
+       public void testMatrixMultCP() {
+               runLineageTraceReuseTest(OpType.MM, 4, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       @Ignore // TODO: allow for reuse of respective spark instructions
+       public void testMatrixMultSP() {
+               runLineageTraceReuseTest(OpType.MM, 4, ExecMode.SPARK);
+       }
+
+       @Test
+       @Ignore
+       public void testParforAddCP() {
+               runLineageTraceReuseTest(OpType.PARFOR_ADD, 3, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void testParforAddSP() {
+               runLineageTraceReuseTest(OpType.PARFOR_ADD, 3, ExecMode.SPARK);
+       }
+
+       private void runLineageTraceReuseTest(OpType opType, int numCoordinators, ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 0, 3, sparsity, 3);
+               double[][] X2 = getRandomMatrix(r, c, 0, 3, sparsity, 7);
+               double[][] X3 = getRandomMatrix(r, c, 0, 3, sparsity, 8);
+               double[][] X4 = getRandomMatrix(r, c, 0, 3, sparsity, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+
+               int[] workerPorts = startFedWorkers(4, new String[]{"-lineage", "reuse"});
+
+               rtplatform = execMode;
+               if(rtplatform == ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // start the coordinator processes
+               String scriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-config", CONFIG_DIR + "SystemDS-MultiTenant-config.xml",
+                       "-lineage", "reuse", "-stats", "100", "-fedStats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(workerPorts[0], input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(workerPorts[1], input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(workerPorts[2], input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(workerPorts[3], input("X4")),
+                       "rows=" + rows, "cols=" + cols, "testnum=" + Integer.toString(opType.ordinal()),
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase()};
+               for(int counter = 0; counter < numCoordinators; counter++)
+                       startCoordinator(execMode, scriptName,
+                               ArrayUtils.addAll(programArgs, "out_S=" + output("S" + counter)));
+
+               // wait for the coordinator processes to end and verify the results
+               String coordinatorOutput = waitForCoordinators();
+               verifyResults(opType, coordinatorOutput, execMode);
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               TestUtils.shutdownThreads(workerProcesses.toArray(new Process[0]));
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
+
+       private void verifyResults(OpType opType, String outputLog, ExecMode execMode) {
+               Assert.assertTrue(checkForHeavyHitter(opType, outputLog, execMode));
+               // verify that the matrix object has been taken from cache
+               Assert.assertTrue(checkForReuses(opType, outputLog, execMode));
+
+               // compare the results via files
+               HashMap<CellIndex, Double> refResults   = readDMLMatrixFromOutputDir("S" + 0);
+               Assert.assertFalse("The result of the first coordinator, which is taken as reference, is empty.",
+                       refResults.isEmpty());
+               for(int counter = 1; counter < coordinatorProcesses.size(); counter++) {
+                       HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir("S" + counter);
+                       TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed" + counter, "FedRef");
+               }
+       }
+
+       private boolean checkForHeavyHitter(OpType opType, String outputLog, ExecMode execMode) {
+               boolean retVal = false;
+               switch(opType) {
+                       case EW_PLUS:
+                               retVal = checkForHeavyHitter(outputLog, "fed_+");
+                               if(execMode == ExecMode.SINGLE_NODE)
+                                       retVal &= checkForHeavyHitter(outputLog, "fed_uak+");
+                               break;
+                       case MM:
+                               retVal = checkForHeavyHitter(outputLog, (execMode == ExecMode.SPARK) ? "fed_mapmm" : "fed_ba+*");
+                               if(rowPartitioned)
+                                       retVal &= checkForHeavyHitter(outputLog, (execMode == ExecMode.SPARK) ? "fed_rblk" : "fed_uak+");
+                               break;
+                       case PARFOR_ADD:
+                               retVal = checkForHeavyHitter(outputLog, "fed_-");
+                               retVal &= checkForHeavyHitter(outputLog, "fed_+");
+                               retVal &= checkForHeavyHitter(outputLog, (execMode == ExecMode.SPARK) ? "fed_rblk" : "fed_uak+");
+                               break;
+               }
+               return retVal;
+       }
+
+       private boolean checkForHeavyHitter(String outputLog, String hhString) {
+               int occurrences = StringUtils.countMatches(outputLog, hhString);
+               return (occurrences == coordinatorProcesses.size());
+       }
+
+       private boolean checkForReuses(OpType opType, String outputLog, ExecMode execMode) {
+               final String LINCACHE_MULTILVL = "LinCache MultiLvl (Ins/SB/Fn):\t";
+               final String LINCACHE_WRITES = "LinCache writes (Mem/FS/Del):\t";
+               boolean retVal = false;
+               int numInst = -1;
+               switch(opType) {
+                       case EW_PLUS:
+                               numInst = (execMode == ExecMode.SPARK) ? 1 : 2;
+                               break;
+                       case MM:
+                               numInst = rowPartitioned ? 2 : 1;
+                               break;
+                       case PARFOR_ADD: // number of instructions times number of iterations of the parfor loop
+                               numInst = ((execMode == ExecMode.SPARK) ? 2 : 3) * 3;
+                               break;
+               }
+               retVal = outputLog.contains(LINCACHE_MULTILVL
+                       + Integer.toString(numInst * (coordinatorProcesses.size()-1) * workerProcesses.size()) + "/");
+               retVal &= outputLog.contains(LINCACHE_WRITES
+                       + Integer.toString((1 + numInst) * workerProcesses.size()) + "/"); // read + instructions
+               return retVal;
+       }
+}
diff --git a/src/test/scripts/functions/federated/multitenant/FederatedLineageTraceReuseTest.dml b/src/test/scripts/functions/federated/multitenant/FederatedLineageTraceReuseTest.dml
new file mode 100644
index 0000000..c23d674
--- /dev/null
+++ b/src/test/scripts/functions/federated/multitenant/FederatedLineageTraceReuseTest.dml
@@ -0,0 +1,69 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+rowPart = $rP;
+
+if (rowPart) {
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+                 list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+testnum = $testnum;
+
+Y = rand(rows=$rows, cols=$cols, seed=1234);
+
+for(counter in 1:50) {
+  while(FALSE) { }
+  Y = Y + 0.1;
+}
+
+while(FALSE) { }
+
+if(testnum == 0) { # EW_PLUS
+  Z = X + Y;
+  while(FALSE) { }
+  S = as.matrix(sum(Z));
+}
+else if(testnum == 1) { # MM
+  Z = X %*% t(Y);
+  while(FALSE) { }
+  S = as.matrix(sum(Z));
+}
+else if(testnum == 2) { # PARFOR_ADD
+  numiter = 3;
+  Z = matrix(0, rows=numiter, cols=1);
+  parfor(i in 1:numiter) {
+    X_tmp = X - i;
+    Y_vec = rowSums(Y + i);
+    while(FALSE) { }
+    Z_tmp = X_tmp + Y_vec;
+    while(FALSE) { }
+    Z[i, 1] = sum(Z_tmp);
+  }
+  S = as.matrix(sum(Z));
+}
+
+write(S, $out_S);

Reply via email to