This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 2439f75  [SYSTEMDS-325] Bug fix in transformencode post-processing
2439f75 is described below

commit 2439f75be2d03fb747e6909b8c542d41c15b4616
Author: arnabp <[email protected]>
AuthorDate: Sat Feb 5 18:16:10 2022 +0100

    [SYSTEMDS-325] Bug fix in transformencode post-processing
    
    This patch fixes a bug in the multithreaded compaction logic.
---
 .../runtime/transform/encode/ColumnEncoder.java    | 10 +++++++--
 .../transform/encode/ColumnEncoderComposite.java   |  5 +++--
 .../transform/encode/MultiColumnEncoder.java       | 24 ++++++++++++++--------
 3 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index d5aee51..a9f0c70 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -30,6 +30,8 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.HashSet;
+import java.util.Set;
 import java.util.concurrent.Callable;
 
 import org.apache.commons.logging.Log;
@@ -352,8 +354,12 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
                return new ColumnApplyTask<>(this, in, out, outputCol, 
startRow, blk);
        }
 
-       public List<Integer> getSparseRowsWZeros(){
-               return _sparseRowsWZeros;
+       public Set<Integer> getSparseRowsWZeros(){
+               if (_sparseRowsWZeros != null) {
+                       return new HashSet<Integer>(_sparseRowsWZeros);
+               }
+               else
+                       return null;
        }
 
        protected void addSparseRowsWZeros(ArrayList<Integer> sparseRowsWZeros){
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 4f72f9c..4eb57ba 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -27,6 +27,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.Objects;
 import java.util.concurrent.Callable;
 import java.util.stream.Collectors;
@@ -357,12 +358,12 @@ public class ColumnEncoderComposite extends ColumnEncoder {
        }
 
        @Override
-       public List<Integer> getSparseRowsWZeros(){
+       public Set<Integer> getSparseRowsWZeros(){
                return 
_columnEncoders.stream().map(ColumnEncoder::getSparseRowsWZeros).flatMap(l -> {
                                        if(l == null)
                                                return null;
                                        return l.stream();
-                               }).collect(Collectors.toList());
+                               }).collect(Collectors.toSet());
        }
 
 
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 2e71680..aa7f408 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -427,22 +427,28 @@ public class MultiColumnEncoder implements Encoder {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                int k = OptimizerUtils.getTransformNumThreads();
                ForkJoinPool myPool = new ForkJoinPool(k);
-               List<Integer> indexSet = _columnEncoders.stream().parallel()
-                               
.map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> {
-                                       if(l == null)
-                                               return null;
-                                       return l.stream();
-                               }).collect(Collectors.toList());
-
                if (k == 1) {
-                       
if(!indexSet.stream().parallel().allMatch(Objects::isNull)) {
+                       Set<Integer> indexSet = _columnEncoders.stream()
+                                       
.map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> {
+                                               if(l == null)
+                                                       return null;
+                                               return l.stream();
+                                       }).collect(Collectors.toSet());
+
+                       if(!indexSet.stream().allMatch(Objects::isNull)) {
                                for(Integer row : indexSet)
                                        
output.getSparseBlock().get(row).compact();
                        }
                }
                else {
                        try {
-                               
if(!indexSet.stream().allMatch(Objects::isNull)) {
+                               Set<Integer> indexSet = 
_columnEncoders.stream().parallel()
+                                       
.map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> {
+                                               if(l == null)
+                                                       return null;
+                                               return l.stream();
+                                       }).collect(Collectors.toSet());
+                               
if(!indexSet.stream().parallel().allMatch(Objects::isNull)) {
                                        myPool.submit(() -> {
                                                
indexSet.stream().parallel().forEach(row -> {
                                                        
output.getSparseBlock().get(row).compact();

Reply via email to