This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 636a683  [SYSTEMDS-3267] Explain for transformencode task-graph
636a683 is described below

commit 636a683a07b0a377289f0c83922abcd44c37a7f8
Author: arnabp <[email protected]>
AuthorDate: Thu Jan 6 01:57:39 2022 +0100

    [SYSTEMDS-3267] Explain for transformencode task-graph
    
    This patch adds a method to print the task-graph of
    transformencode. Moreover, this commit integrates
    getMetadata tasks within the task-graph.
    
    Closes #1498
---
 .../runtime/transform/encode/ColumnEncoder.java    |  1 +
 .../transform/encode/ColumnEncoderDummycode.java   |  1 +
 .../transform/encode/ColumnEncoderPassThrough.java |  1 +
 .../transform/encode/MultiColumnEncoder.java       | 91 +++++++++++++++++-----
 .../sysds/runtime/util/DependencyThreadPool.java   | 67 ++++++++++++++--
 src/main/java/org/apache/sysds/utils/Explain.java  |  2 +-
 .../TransformFrameEncodeMultithreadedTest.java     |  9 +++
 7 files changed, 147 insertions(+), 25 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 9db3772..6c9ac6a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -132,6 +132,7 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
 
        protected void applySparse(CacheBlock in, MatrixBlock out, int 
outputCol, int rowStart, int blk){
                boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == 
SparseBlock.Type.MCSR;
+               mcsr = false; //force CSR for transformencode
                int index = _colID - 1;
                // Apply loop tiling to exploit CPU caches
                double[] codes = getCodeCol(in, rowStart, blk);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 63cf86c..5c1cd11 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -88,6 +88,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
                                        " and not MatrixBlock");
                }
                boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == 
SparseBlock.Type.MCSR;
+               mcsr = false; //force CSR for transformencode
                Set<Integer> sparseRowsWZeros = null;
                int index = _colID - 1;
                for(int r = rowStart; r < getEndIndex(in.getNumRows(), 
rowStart, blk); r++) {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index f8b467d..36784ab 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -80,6 +80,7 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
        protected void applySparse(CacheBlock in, MatrixBlock out, int 
outputCol, int rowStart, int blk){
                Set<Integer> sparseRowsWZeros = null;
                boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == 
SparseBlock.Type.MCSR;
+               mcsr = false; //force CSR for transformencode
                int index = _colID - 1;
                // Apply loop tiling to exploit CPU caches
                double[] codes = getCodeCol(in, rowStart, blk);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 76301ce..52bad53 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -43,6 +43,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -134,46 +135,67 @@ public class MultiColumnEncoder implements Encoder {
                return out;
        }
 
+       /* TASK DETAILS:
+        * InitOutputMatrixTask:        Allocate output matrix
+        * AllocMetaTask:               Allocate metadata frame
+        * BuildTask:                   Build an encoder
+        * ColumnCompositeUpdateDCTask: Update domain size of a DC encoder 
based on #distincts, #bins, K
+        * ColumnMetaDataTask:          Fill up metadata of an encoder
+        * ApplyTasksWrapperTask:       Wrapper task for an Apply
+        * UpdateOutputColTask:         Sets starting offsets of the DC columns
+        */
        private List<DependencyTask<?>> getEncodeTasks(CacheBlock in, 
MatrixBlock out, DependencyThreadPool pool) {
                List<DependencyTask<?>> tasks = new ArrayList<>();
                List<DependencyTask<?>> applyTAgg = null;
                Map<Integer[], Integer[]> depMap = new HashMap<>();
                boolean hasDC = 
getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
                boolean applyOffsetDep = false;
+               _meta = new FrameBlock(in.getNumColumns(), ValueType.STRING);
+               // Create the output and metadata allocation tasks
                tasks.add(DependencyThreadPool.createDependencyTask(new 
InitOutputMatrixTask(this, in, out)));
+               tasks.add(DependencyThreadPool.createDependencyTask(new 
AllocMetaTask(this, _meta)));
+
                for(ColumnEncoderComposite e : _columnEncoders) {
+                       // Create the build tasks
                        List<DependencyTask<?>> buildTasks = 
e.getBuildTasks(in);
-
                        tasks.addAll(buildTasks);
                        if(buildTasks.size() > 0) {
-                               // Apply Task dependency to build completion 
task
-                               depMap.put(new Integer[] {tasks.size(), 
tasks.size() + 1},
-                                       new Integer[] {tasks.size() - 1, 
tasks.size()});
+                               // Apply Task depends on build completion task
+                               depMap.put(new Integer[] {tasks.size(), 
tasks.size() + 1},      //ApplyTask
+                                       new Integer[] {tasks.size() - 1, 
tasks.size()});            //BuildTask
+                               // getMetaDataTask depends on build completion
+                               depMap.put(new Integer[] {tasks.size() + 1, 
tasks.size() + 2}, //MetaDataTask
+                                       new Integer[] {tasks.size() - 1, 
tasks.size()});           //BuildTask
+                               // getMetaDataTask depends on AllocMeta task
+                               depMap.put(new Integer[] {tasks.size() + 1, 
tasks.size() + 2}, //MetaDataTask
+                                       new Integer[] {1, 2});                  
                   //AllocMetaTask (2nd task)
+                               // AllocMetaTask depends on the build 
completion tasks
+                               depMap.put(new Integer[] {1, 2},                
               //AllocMetaTask (2nd task)
+                                       new Integer[] {tasks.size() - 1, 
tasks.size()});           //BuildTask
                        }
 
-                       // Apply Task dependency to InitOutputMatrixTask
-                       depMap.put(new Integer[] {tasks.size(), tasks.size() + 
1}, new Integer[] {0, 1});
+                       // Apply Task depends on InitOutputMatrixTask (output 
allocation)
+                       depMap.put(new Integer[] {tasks.size(), tasks.size() + 
1},         //ApplyTask
+                                       new Integer[] {0, 1});                  
                   //Allocation task (1st task)
                        ApplyTasksWrapperTask applyTaskWrapper = new 
ApplyTasksWrapperTask(e, in, out, pool);
 
                        if(e.hasEncoder(ColumnEncoderDummycode.class)) {
-                               // InitMatrix dependency to build of recode if 
a DC is present
-                               // Since they are the only ones that change the 
domain size which would influence the Matrix creation
-                               depMap.put(new Integer[] {0, 1}, // InitMatrix 
Task first in list
-                                       new Integer[] {tasks.size() - 1, 
tasks.size()});
-                               // output col update task dependent on Build 
completion only for Recode and binning since they can
-                               // change dummycode domain size
-                               // colUpdateTask can start when all domain 
sizes, because it can now calculate the offsets for
-                               // each column
-                               depMap.put(new Integer[] {-2, -1}, new 
Integer[] {tasks.size() - 1, tasks.size()});
+                               // Allocation depends on build if DC is in the 
list.
+                               // Note, DC is the only encoder that changes 
dimensionality
+                               depMap.put(new Integer[] {0, 1},                
               //Allocation task (1st task)
+                                       new Integer[] {tasks.size() - 1, 
tasks.size()});           //BuildTask
+                               // UpdateOutputColTask, that sets the starting 
offsets of the DC columns,
+                               // depends on the Build completion tasks
+                               depMap.put(new Integer[] {-2, -1},              
               //UpdateOutputColTask (last task) 
+                                               new Integer[] {tasks.size() - 
1, tasks.size()});       //BuildTask
                                buildTasks.forEach(t -> t.setPriority(5));
                                applyOffsetDep = true;
                        }
 
                        if(hasDC && applyOffsetDep) {
-                               // Apply Task dependency to output col update 
task (is last in list)
-                               // All ApplyTasks need to wait for this task, 
so they all have the correct offsets.
-                               // But only for the columns that come after the 
first DC coder since they don't have an offset
-                               depMap.put(new Integer[] {tasks.size(), 
tasks.size() + 1}, new Integer[] {-2, -1});
+                               // Apply tasks depend on UpdateOutputColTask
+                               depMap.put(new Integer[] {tasks.size(), 
tasks.size() + 1},     //ApplyTask 
+                                               new Integer[] {-2, -1});        
                       //UpdateOutputColTask (last task)
 
                                applyTAgg = applyTAgg == null ? new 
ArrayList<>() : applyTAgg;
                                applyTAgg.add(applyTaskWrapper);
@@ -181,9 +203,13 @@ public class MultiColumnEncoder implements Encoder {
                        else {
                                applyTaskWrapper.setOffset(0);
                        }
+                       // Create the ApplyTask (wrapper)
                        tasks.add(applyTaskWrapper);
+                       // Create the getMetadata task
+                       tasks.add(DependencyThreadPool.createDependencyTask(new 
ColumnMetaDataTask<ColumnEncoder>(e, _meta)));
                }
                if(hasDC)
+                       // Create the last task, UpdateOutputColTask
                        tasks.add(DependencyThreadPool.createDependencyTask(new 
UpdateOutputColTask(this, applyTAgg)));
 
                List<List<? extends Callable<?>>> deps = new 
ArrayList<>(Collections.nCopies(tasks.size(), null));
@@ -330,6 +356,7 @@ public class MultiColumnEncoder implements Encoder {
                                        && MatrixBlock.DEFAULT_SPARSEBLOCK != 
SparseBlock.Type.MCSR)
                                throw new RuntimeException("Transformapply is 
only supported for MCSR and CSR output matrix");
                        boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == 
SparseBlock.Type.MCSR;
+                       mcsr = false; //force CSR for transformencode
                        if (mcsr) {
                                output.allocateBlock();
                                SparseBlock block = output.getSparseBlock();
@@ -933,6 +960,27 @@ public class MultiColumnEncoder implements Encoder {
                        return null;
                }
        }
+
+       private static class AllocMetaTask implements Callable<Object> {
+               private final MultiColumnEncoder _encoder;
+               private final FrameBlock _meta;
+               
+               private AllocMetaTask (MultiColumnEncoder encoder, FrameBlock 
meta) {
+                       _encoder = encoder;
+                       _meta = meta;
+               }
+
+               @Override
+               public Object call() throws Exception {
+                       _encoder.allocateMetaData(_meta);
+                       return null;
+               }
+
+               @Override
+               public String toString() {
+                       return getClass().getSimpleName();
+               }
+       }
        
        private static class ColumnMetaDataTask<T extends ColumnEncoder> 
implements Callable<Object> {
                private final T _colEncoder;
@@ -948,6 +996,11 @@ public class MultiColumnEncoder implements Encoder {
                        _colEncoder.getMetaData(_out);
                        return null;
                }
+
+               @Override
+               public String toString() {
+                       return getClass().getSimpleName() + "<ColId: " + 
_colEncoder._colID + ">";
+               }
        }
 
 }
diff --git a/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java
index 50675d6..90d1dfc 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java
@@ -22,8 +22,10 @@ package org.apache.sysds.runtime.util;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.utils.Explain;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -90,7 +92,10 @@ public class DependencyThreadPool {
        public List<Object> submitAllAndWait(List<DependencyTask<?>> dtasks)
                throws ExecutionException, InterruptedException {
                List<Object> res = new ArrayList<>();
-               // printDependencyGraph(dtasks);
+               if(DependencyTask.ENABLE_DEBUG_DATA) {
+                       if (dtasks != null && dtasks.size() > 0)
+                               explainTaskGraph(dtasks);
+               }
                List<Future<Future<?>>> futures = submitAll(dtasks);
                int i = 0;
                for(Future<Future<?>> ff : futures) {
@@ -112,10 +117,12 @@ public class DependencyThreadPool {
        }
 
        /*
-        * Creates the Dependency list from a map and the tasks. The map 
specifies which tasks should have a Dependency on
-        * which other task. e.g.
-        * ([0, 3], [4, 6])   means the first 3 tasks in the tasks list are 
dependent on tasks at index 4 and 5
-        * ([-2, -1], [0, 5]) means the last task has a Dependency on the first 
5 tasks.
+        * Creates the Dependency list from a map and the tasks. The map 
specifies which tasks 
+        * should have a Dependency on which other task. e.g.
+        * ([0, 3], [4, 6])   means the 1st 3 tasks in the list are dependent 
on tasks at index 4 and 5
+        * ([-2, -1], [0, 5]) means the last task depends on the first 5 tasks.
+        * ([dependent start index, dependent end index (excluding)], 
+        *  [parent start index, parent end index (excluding)])
         */
        public static List<List<? extends Callable<?>>> 
createDependencyList(List<? extends Callable<?>> tasks,
                Map<Integer[], Integer[]> depMap, List<List<? extends 
Callable<?>>> dep) {
@@ -175,4 +182,54 @@ public class DependencyThreadPool {
                }
                return ret;
        }
+
+       /*
+        * Prints the task-graph level-wise, however, the printed
+        * output doesn't specify which task of level l depends
+        * on which task of level (l-1).
+        */
+       public static void explainTaskGraph(List<DependencyTask<?>> tasks) {
+               Map<DependencyTask<?>, Integer> levelMap = new HashMap<>();
+               int depth = 1;
+               while (levelMap.size() < tasks.size()) {
+                       for (int i=0; i<tasks.size(); i++) {
+                               DependencyTask<?> dt = tasks.get(i);
+                               if (dt._dependencyTasks == null || 
dt._dependencyTasks.size() == 0)
+                                       levelMap.put(dt, 0);
+                               if (dt._dependencyTasks != null) {
+                                       List<DependencyTask<?>> parents = 
dt._dependencyTasks;
+                                       int[] parentLevels = new 
int[parents.size()];
+                                       boolean missing = false;
+                                       for (int p=0; p<parents.size(); p++) {
+                                               if 
(!levelMap.containsKey(parents.get(p)))
+                                                       missing = true;
+                                               else
+                                                       parentLevels[p] = 
levelMap.get(parents.get(p));
+                                       }
+                                       if (missing)
+                                               continue;
+                                       int maxParentLevel = 
Arrays.stream(parentLevels).max().getAsInt();
+                                       levelMap.put(dt, maxParentLevel+1);
+                                       if (maxParentLevel+1 == depth)
+                                               depth++;
+                               }
+                       }
+               }
+               StringBuilder sbs[] = new StringBuilder[depth];
+               String offsets[] = new String[depth];
+               for (Map.Entry<DependencyTask<?>, Integer> entry : 
levelMap.entrySet()) {
+                       int level = entry.getValue();
+                       if (sbs[level] == null) {
+                               sbs[level] = new StringBuilder();
+                               offsets[level] = Explain.createOffset(level);
+                       }
+                       sbs[level].append(offsets[level]);
+                       sbs[level].append(entry.getKey().toString()+"\n");
+               }
+               System.out.println("EXPlAIN (TASK-GRAPH):");
+               for (int i=0; i<sbs.length; i++) {
+                       System.out.println(sbs[i].toString());
+               }
+
+       }
 }
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index ae6a523..ba6fb71 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -830,7 +830,7 @@ public class Explain
                return OptimizerUtils.toMB(mem) + (units?"MB":"");
        }
 
-       private static String createOffset( int level )
+       public static String createOffset( int level )
        {
                StringBuilder sb = new StringBuilder();
                for( int i=0; i<level; i++ )
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
index fbf7111..560f934 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
@@ -24,6 +24,7 @@ import java.nio.file.Paths;
 
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.FrameReaderFactory;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -211,11 +212,19 @@ public class TransformFrameEncodeMultithreadedTest extends AutomatedTestBase {
                        MultiColumnEncoder.MULTI_THREADED_STAGES = staged;
 
                        MatrixBlock outputS = encoder.encode(input, 1);
+                       FrameBlock metaS = encoder.getMetaData(new 
FrameBlock(input.getNumColumns(), ValueType.STRING), 1);
                        MatrixBlock outputM = encoder.encode(input, 12);
+                       FrameBlock metaM = encoder.getMetaData(new 
FrameBlock(input.getNumColumns(), ValueType.STRING), 12);
 
+                       // Match encoded matrices
                        double[][] R1 = 
DataConverter.convertToDoubleMatrix(outputS);
                        double[][] R2 = 
DataConverter.convertToDoubleMatrix(outputM);
                        TestUtils.compareMatrices(R1, R2, R1.length, 
R1[0].length, 0);
+                       // Match the metadata frames
+                       String[][] M1 = 
DataConverter.convertToStringFrame(metaS);
+                       String[][] M2 = 
DataConverter.convertToStringFrame(metaM);
+                       TestUtils.compareFrames(M1, M2, M1.length, 
M1[0].length);
+
                        Assert.assertEquals(outputS.getNonZeros(), 
outputM.getNonZeros());
                        Assert.assertTrue(outputM.getNonZeros() > 0);
 

Reply via email to