This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4942601fec [MINOR] Parallel NNz count MatrixBlock
4942601fec is described below
commit 4942601fec9d7418605a4e9dcc55e266ffb4fc3d
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Wed Sep 6 10:17:06 2023 +0200
[MINOR] Parallel NNz count MatrixBlock
This commit adds a method to count the non-zero (nnz) values of a
MatrixBlock in parallel. Parallelism gives no performance improvement
for sparse blocks, so in that case it falls back to the existing
single-threaded nnz count.
For a dense 10k x 1k matrix, the count time improves from 7.9 ms to 3.4 ms.
Closes #1894
---
.../sysds/runtime/matrix/data/LibMatrixReorg.java | 15 ++++--
.../sysds/runtime/matrix/data/MatrixBlock.java | 44 +++++++++++++++-
.../org/apache/sysds/runtime/matrix/data/Pair.java | 58 ++++++++++------------
.../org/apache/sysds/performance/simple/NNZ.java | 55 ++++++++++++++++++++
4 files changed, 133 insertions(+), 39 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 1917b564bc..2952c2b4c8 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -203,13 +203,18 @@ public class LibMatrixReorg {
}
public static MatrixBlock transpose(MatrixBlock in, int k) {
+ return transpose(in, k, false); // allowCSR defaults to false
+ }
+	/** Transpose with parallelization degree k; allowCSR is forwarded to the sparsity estimate and to transpose(in, out, k, allowCSR). */
+ public static MatrixBlock transpose(MatrixBlock in, int k, boolean
allowCSR) {
final int clen = in.getNumColumns();
final int rlen = in.getNumRows();
final long nnz = in.getNonZeros();
- final boolean sparseOut =
MatrixBlock.evalSparseFormatInMemory(clen, rlen, nnz, true);
- return transpose(in, new MatrixBlock(clen, rlen, sparseOut), k);
+ final boolean sparseOut =
MatrixBlock.evalSparseFormatInMemory(clen, rlen, nnz, allowCSR);
+ return transpose(in, new MatrixBlock(clen, rlen, sparseOut), k,
allowCSR);
}
+
public static MatrixBlock transpose( MatrixBlock in, MatrixBlock out,
int k ) {
return transpose(in, out, k, false);
}
@@ -238,8 +243,8 @@ public class LibMatrixReorg {
// Timing time = new Timing(true);
// CSR is only allowed in the transposed output if the number
of non zeros is counted in the columns
- allowCSR = allowCSR && out.nonZeros < Integer.MAX_VALUE &&
in.clen <= 4096;
-
+		// the CSR arrays below are sized with (int) out.nonZeros, so nnz must stay within int range
+		allowCSR = allowCSR && out.nonZeros < Integer.MAX_VALUE && (in.clen <= 4096 || out.nonZeros < 10000000);
if(out.sparse && allowCSR) {
int size = (int) out.nonZeros;
out.sparseBlock = new
SparseBlockCSR(in.getNumColumns(), size, size);
@@ -256,7 +261,7 @@ public class LibMatrixReorg {
int[] cnt = null;
// filter matrices with many columns since the
CountNnzTask would return
// null if the number of columns is larger than
threshold
- if(in.sparse && out.sparse && in.clen <= 4096) {
+ if(allowCSR) {
cnt = countNNZColumns(in, k, pool);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 0956950a3a..cb67fc3a68 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -32,6 +32,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
+import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
@@ -1412,6 +1413,48 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
return nonZeros;
}
+
+	/**
+	 * Recompute the number of non-zero values in parallel and store it in nonZeros.
+	 *
+	 * Sparse blocks, blocks with fewer than 10000 cells, and k <= 1 use the single-threaded recomputeNonZeros().
+	 * Dense blocks are split into tiles of about one million cells (whole rows unless a row is wider than that).
+	 *
+	 * @param k the parallelization degree
+	 * @return the number of non zeros
+	 */
+	public long recomputeNonZeros(int k) {
+		if(k <= 1 || isInSparseFormat() || (long) rlen * clen < 10000)
+			return recomputeNonZeros();
+		final ExecutorService pool = CommonThreadPool.get(k);
+		try {
+			// ~1M cells per task bounds the task count for tall-and-skinny as well as very wide inputs
+			final int bc = (int) Math.min(clen, 1000000L);
+			final int br = (int) Math.min(rlen, 1000000L / bc);
+			final List<Future<Long>> tasks = new ArrayList<>();
+			for(long r = 0; r < rlen; r += br) { // long counters avoid int overflow near Integer.MAX_VALUE
+				for(long c = 0; c < clen; c += bc) {
+					final int rl = (int) r, ru = (int) Math.min(r + br, rlen) - 1;
+					final int cl = (int) c, cu = (int) Math.min(c + bc, clen) - 1;
+					tasks.add(pool.submit(() -> recomputeNonZeros(rl, ru, cl, cu)));
+				}
+			}
+			long nnz = 0;
+			for(Future<Long> task : tasks)
+				nnz += task.get();
+			nonZeros = nnz;
+			return nonZeros;
+		}
+		catch(Exception e) {
+			if(e instanceof InterruptedException)
+				Thread.currentThread().interrupt(); // restore the interrupt flag for callers
+			LOG.warn("Failed Parallel non zero count fallback to singlethread", e);
+			return recomputeNonZeros();
+		}
+		finally {
+			pool.shutdown();
+		}
+	}
public long recomputeNonZeros(int rl, int ru) {
return recomputeNonZeros(rl, ru, 0, clen-1);
@@ -2986,7 +3029,6 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
@Override
public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue
result) {
MatrixBlock ret = checkType(result);
-
// estimate the sparsity structure of result matrix
// by default, we guess result.sparsity=input.sparsity, unless
not sparse safe
boolean sp = this.sparse && op.sparseSafe;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java
index c4762295cc..77172d1bed 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/Pair.java
@@ -17,54 +17,46 @@
* under the License.
*/
-
package org.apache.sysds.runtime.matrix.data;
-public class Pair<K, V>
-{
-
+public class Pair<K, V> { // mutable key/value holder; both fields may be null
+
private K key;
private V value;
-
- public Pair()
- {
- key=null;
- value=null;
+
+ public Pair() { // creates a pair with null key and null value
+ key = null;
+ value = null;
}
-
- public Pair(K k, V v)
- {
- set(k, v);
+
+ public Pair(K k, V v) { // assigns the fields directly rather than calling set()
+ key = k;
+ value = v;
}
-
- public void setKey(K k)
- {
- key=k;
+
+ public final void setKey(K k) { // accessors are final and cannot be overridden by subclasses
+ key = k;
}
-
- public void setValue(V v)
- {
- value=v;
+
+ public final void setValue(V v) {
+ value = v;
}
-
- public void set(K k, V v)
- {
- key=k;
- value=v;
+
+ public final void set(K k, V v) { // replaces key and value together
+ key = k;
+ value = v;
}
-
- public K getKey()
- {
+
+ public final K getKey() {
return key;
}
-
- public V getValue()
- {
+
+ public final V getValue() {
return value;
}
@Override
- public String toString(){
+ public String toString() { // formatted as key:value
return key + ":" + value;
}
}
diff --git a/src/test/java/org/apache/sysds/performance/simple/NNZ.java
b/src/test/java/org/apache/sysds/performance/simple/NNZ.java
new file mode 100644
index 0000000000..8ed77aea97
--- /dev/null
+++ b/src/test/java/org/apache/sysds/performance/simple/NNZ.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.performance.simple;
+
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.TestUtils;
+
+/**
+ * Micro benchmark comparing MatrixBlock.recomputeNonZeros() with the parallel recomputeNonZeros(16).
+ *
+ * Both variants are timed in alternation for two rounds on a generated 10000 x 1000 test matrix; each printed value
+ * is the measured time divided by the number of calls. Presumably the first round also serves as JVM warm-up.
+ */
+public class NNZ {
+	/** Number of calls per measurement; the printed time is divided by it. */
+	private static final int REPETITIONS = 1000;
+	/** Number of single/parallel measurement pairs. */
+	private static final int ROUNDS = 2;
+
+	public static void main(String[] args) {
+		final MatrixBlock mb = TestUtils.generateTestMatrixBlock(10000, 1000, 0, 103, 0.7, 421);
+		final Timing timer = new Timing();
+		for(int round = 0; round < ROUNDS; round++) {
+			// single-threaded baseline
+			timer.start();
+			for(int rep = 0; rep < REPETITIONS; rep++)
+				mb.recomputeNonZeros();
+			System.out.println("single: " + timer.stop() / REPETITIONS);
+
+			// parallel count with parallelization degree 16
+			timer.start();
+			for(int rep = 0; rep < REPETITIONS; rep++)
+				mb.recomputeNonZeros(16);
+			System.out.println("par: " + timer.stop() / REPETITIONS);
+		}
+	}
+}