This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4942601fec [MINOR] Parallel NNz count MatrixBlock
4942601fec is described below

commit 4942601fec9d7418605a4e9dcc55e266ffb4fc3d
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Wed Sep 6 10:17:06 2023 +0200

    [MINOR] Parallel NNz count MatrixBlock
    
    This commit adds another method to count the nnz values in a MatrixBlock
    in parallel. Parallel counting gives no performance improvement for sparse
    blocks, so in that case the method falls back to the default nnz count.
    For a dense block, I see an improvement from 7.9 ms to 3.4 ms on a 10k by 1k matrix.
    
    Closes #1894
---
 .../sysds/runtime/matrix/data/LibMatrixReorg.java  | 15 ++++--
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 44 +++++++++++++++-
 .../org/apache/sysds/runtime/matrix/data/Pair.java | 58 ++++++++++------------
 .../org/apache/sysds/performance/simple/NNZ.java   | 55 ++++++++++++++++++++
 4 files changed, 133 insertions(+), 39 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 1917b564bc..2952c2b4c8 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -203,13 +203,18 @@ public class LibMatrixReorg {
        }
 
        public static MatrixBlock transpose(MatrixBlock in, int k) {
+               return transpose(in, k, false);
+       }
+
+       public static MatrixBlock transpose(MatrixBlock in, int k, boolean 
allowCSR) {
                final int clen = in.getNumColumns();
                final int rlen = in.getNumRows();
                final long nnz = in.getNonZeros();
-               final boolean sparseOut = 
MatrixBlock.evalSparseFormatInMemory(clen, rlen, nnz, true);
-               return transpose(in, new MatrixBlock(clen, rlen, sparseOut), k);
+               final boolean sparseOut = 
MatrixBlock.evalSparseFormatInMemory(clen, rlen, nnz, allowCSR);
+               return transpose(in, new MatrixBlock(clen, rlen, sparseOut), k, 
allowCSR);
        }
 
+
        public static MatrixBlock transpose( MatrixBlock in, MatrixBlock out, 
int k ) {
                return transpose(in, out, k, false);
        }
@@ -238,8 +243,8 @@ public class LibMatrixReorg {
                // Timing time = new Timing(true);
 
                // CSR is only allowed in the transposed output if the number 
of non zeros is counted in the columns
-               allowCSR = allowCSR && out.nonZeros < Integer.MAX_VALUE && 
in.clen <= 4096;
-
+               allowCSR = allowCSR && (in.clen <= 4096 || out.nonZeros < 
10000000);
+               
                if(out.sparse && allowCSR) {
                        int size = (int) out.nonZeros;
                        out.sparseBlock = new 
SparseBlockCSR(in.getNumColumns(), size, size);
@@ -256,7 +261,7 @@ public class LibMatrixReorg {
                        int[] cnt = null;
                        // filter matrices with many columns since the 
CountNnzTask would return
                        // null if the number of columns is larger than 
threshold
-                       if(in.sparse && out.sparse && in.clen <= 4096) {
+                       if(allowCSR) {
                                
                                cnt = countNNZColumns(in, k, pool);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 0956950a3a..cb67fc3a68 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -32,6 +32,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
+import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.stream.IntStream;
@@ -1412,6 +1413,48 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                }
                return nonZeros;
        }
+
+       /**
+        * Recompute the number of nonZero values in parallel
+        * 
+        * @param k the parallelization degree
+        * @return the number of non zeros
+        */
+       public long recomputeNonZeros(int k) {
+               if(isInSparseFormat()) {
+                       return recomputeNonZeros();
+               }
+               else {
+                       if((long) rlen * clen < 10000)
+                               return recomputeNonZeros();
+                       final ExecutorService pool = CommonThreadPool.get(k);
+                       try {
+                               List<Future<Long>> f = new ArrayList<>();
+                               final int bz = 1000;
+                               for(int i = 0; i < rlen; i += bz) {
+                                       for(int ii = 0; ii < clen; ii += bz) {
+                                               final int j = i;
+                                               final int jj = ii;
+                                               f.add(pool.submit(() -> //
+                                               recomputeNonZeros(j, Math.min(j 
+ bz, rlen) - 1, jj, Math.min(jj + bz, clen) - 1)));
+                                       }
+                               }
+                               long nnz = 0;
+                               for(Future<Long> e : f)
+                                       nnz += e.get();
+                               nonZeros = nnz;
+                               return nonZeros;
+
+                       }
+                       catch(Exception e) {
+                               LOG.warn("Failed Parallel non zero count 
fallback to singlethread");
+                               return recomputeNonZeros();
+                       }
+                       finally {
+                               pool.shutdown();
+                       }
+               }
+       }
        
        public long recomputeNonZeros(int rl, int ru) {
                return recomputeNonZeros(rl, ru, 0, clen-1);
@@ -2986,7 +3029,6 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        @Override
        public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue 
result) {
                MatrixBlock ret = checkType(result);
-               
                // estimate the sparsity structure of result matrix
                // by default, we guess result.sparsity=input.sparsity, unless 
not sparse safe
                boolean sp = this.sparse && op.sparseSafe;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java
index c4762295cc..77172d1bed 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java
@@ -17,54 +17,46 @@
  * under the License.
  */
 
-
 package org.apache.sysds.runtime.matrix.data;
 
-public class Pair<K, V> 
-{
-       
+public class Pair<K, V> {
+
        private K key;
        private V value;
-       
-       public Pair()
-       {
-               key=null;
-               value=null;
+
+       public Pair() {
+               key = null;
+               value = null;
        }
-       
-       public Pair(K k, V v)
-       {
-               set(k, v);
+
+       public Pair(K k, V v) {
+               key = k;
+               value = v;
        }
-       
-       public void setKey(K k)
-       {
-               key=k;
+
+       public final void setKey(K k) {
+               key = k;
        }
-       
-       public void setValue(V v)
-       {
-               value=v;
+
+       public final void setValue(V v) {
+               value = v;
        }
-       
-       public void set(K k, V v)
-       {
-               key=k;
-               value=v;
+
+       public final void set(K k, V v) {
+               key = k;
+               value = v;
        }
-       
-       public K getKey()
-       {
+
+       public final K getKey() {
                return key;
        }
-       
-       public V getValue()
-       {
+
+       public final V getValue() {
                return value;
        }
 
        @Override
-       public String toString(){
+       public String toString() {
                return key + ":" + value;
        }
 }
diff --git a/src/test/java/org/apache/sysds/performance/simple/NNZ.java 
b/src/test/java/org/apache/sysds/performance/simple/NNZ.java
new file mode 100644
index 0000000000..8ed77aea97
--- /dev/null
+++ b/src/test/java/org/apache/sysds/performance/simple/NNZ.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.performance.simple;
+
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.TestUtils;
+
+public class NNZ {
+    public static void main(String[] args) {
+        MatrixBlock mb = TestUtils.generateTestMatrixBlock(10000, 1000, 0, 
103, 0.7, 421);
+        Timing t = new Timing();
+        t.start();
+        for(int i = 0; i < 1000; i++) {
+            mb.recomputeNonZeros();
+        }
+        System.out.println("single:   " + t.stop()/ 1000);
+        t.start();
+        for(int i = 0; i < 1000; i++) {
+
+            mb.recomputeNonZeros(16);
+        }
+
+        System.out.println("par:      " + t.stop()/ 1000);
+        t.start();
+        for(int i = 0; i < 1000; i++) {
+            mb.recomputeNonZeros();
+        }
+        System.out.println("single:   " + t.stop()/ 1000);
+        t.start();
+        for(int i = 0; i < 1000; i++) {
+
+            mb.recomputeNonZeros(16);
+        }
+
+        System.out.println("par:      " + t.stop()/ 1000);
+    }
+}

Reply via email to