This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 3537383147 [MINOR] Add LibReplace for MatrixBlock
3537383147 is described below

commit 3537383147d2bb9e716ea8ec5b80738ecf4e8b14
Author: Sebastian Baunsgaard <baunsga...@apache.org>
AuthorDate: Thu Jul 4 23:18:37 2024 +0200

    [MINOR] Add LibReplace for MatrixBlock
    
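    Adds a library class, LibMatrixReplace, for replace operations on MatrixBlock.
    
    The main benefit of this commit is for replacement on sparse inputs, in
    particular when the output becomes dense (e.g., replacing zeros); replacing
    non-zero values in sparse inputs (sparse output) is faster as well.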
    
    after:
    
    ```
    java -jar target/systemds-3.3.0-SNAPSHOT-perf.jar 17 10000 10000 0.1 16
    Profiling started
                            replaceZero,  132.883+-  9.335 ms,
                             replaceOne,   59.612+-  3.295 ms,
                             replaceNaN,   10.651+-  0.470 ms,
    ```
    
    before:
    
    ```
    java -jar target/systemds-3.3.0-SNAPSHOT-perf.jar 17 10000 10000 0.1 16
    Profiling started
                            replaceZero,  228.727+- 11.965 ms,
                             replaceOne,  163.212+-  4.993 ms,
                             replaceNaN,   10.602+-  0.437 ms,
    ```
    
    Closes #2043
---
 .../org/apache/sysds/runtime/data/DenseBlock.java  |   7 +
 .../apache/sysds/runtime/data/DenseBlockBool.java  |   7 +
 .../apache/sysds/runtime/data/DenseBlockFP32.java  |   7 +
 .../apache/sysds/runtime/data/DenseBlockFP64.java  |   8 +
 .../sysds/runtime/data/DenseBlockFP64DEDUP.java    |   6 +
 .../apache/sysds/runtime/data/DenseBlockInt32.java |   7 +
 .../apache/sysds/runtime/data/DenseBlockInt64.java |   8 +
 .../apache/sysds/runtime/data/DenseBlockLBool.java |   7 +
 .../apache/sysds/runtime/data/DenseBlockLFP32.java |   7 +
 .../apache/sysds/runtime/data/DenseBlockLFP64.java |   7 +
 .../sysds/runtime/data/DenseBlockLFP64DEDUP.java   |   5 +
 .../sysds/runtime/data/DenseBlockLInt32.java       |   7 +
 .../sysds/runtime/data/DenseBlockLInt64.java       |   7 +
 .../sysds/runtime/data/DenseBlockLString.java      |   7 +
 .../sysds/runtime/data/DenseBlockString.java       |   8 +
 .../runtime/matrix/data/LibMatrixReplace.java      | 211 +++++++++++++++++++++
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  88 +--------
 .../java/org/apache/sysds/performance/Main.java    |  14 ++
 .../performance/matrix/MatrixReplacePerf.java      |  69 +++++++
 19 files changed, 400 insertions(+), 87 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
index 0baf881936..8796717422 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
@@ -418,6 +418,13 @@ public abstract class DenseBlock implements Serializable, 
Block
         */
        public abstract void fillBlock(int bix, int fromIndex, int toIndex, 
double v);
 
+       /**
+        * Overwrite every cell of row {@code r} with {@code v}, converted to the value type
+        * of the concrete block (e.g. float, int, long, boolean or String).
+        * Implementations that do not support this operation may throw instead.
+        *
+        * @param r index of the row to overwrite
+        * @param v value written into each cell of that row
+        */
+       public abstract void fillRow(int r, double v);
+
        /**
         * Set a value at a position given by block index and index in that 
block.
         * @param bix   block index
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java
index 3d94fcf8af..3244fb75b1 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java
@@ -146,6 +146,13 @@ public class DenseBlockBool extends DenseBlockDRB
                _data.set(fromIndex, toIndex, v != 0);
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // the row is a contiguous run of getDim(1) bits starting at pos(r)
+               final int from = pos(r);
+               _data.set(from, from + getDim(1), v != 0);
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                _data.set(ix, v != 0);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java
index 519c17d83d..8d2b9e61d9 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java
@@ -132,6 +132,13 @@ public class DenseBlockFP32 extends DenseBlockDRB
                Arrays.fill(_data, fromIndex, toIndex, (float)v);
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // row r covers getDim(1) consecutive floats starting at pos(r)
+               final int from = pos(r);
+               final float value = (float) v;
+               Arrays.fill(_data, from, from + getDim(1), value);
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                _data[ix] = (float)v;
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
index eb93777fa4..9490944419 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
@@ -139,6 +139,14 @@ public class DenseBlockFP64 extends DenseBlockDRB
                Arrays.fill(_data, fromIndex, toIndex, v);
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // row r covers getDim(1) consecutive cells of _data starting at pos(r)
+               final int rowStart = pos(r);
+               Arrays.fill(_data, rowStart, rowStart + getDim(1), v);
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                _data[ix] = v;
diff --git 
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
index 2f3008c727..49d591a01e 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
@@ -317,6 +317,12 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
                }
        }
 
+       /**
+        * Not supported for this block type.
+        *
+        * @throws NotImplementedException always, naming the unsupported operation
+        */
+       @Override
+       public void fillRow(int r, double v) {
+               throw new NotImplementedException("fillRow is not supported by DenseBlockFP64DEDUP");
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                set(bix, ix, v);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java
index 6f3c2a6622..c10072ec17 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java
@@ -132,6 +132,13 @@ public class DenseBlockInt32 extends DenseBlockDRB
                Arrays.fill(_data, fromIndex, toIndex, UtilFunctions.toInt(v));
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // row r covers getDim(1) consecutive ints starting at pos(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final int value = UtilFunctions.toInt(v);
+               Arrays.fill(_data, from, to, value);
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                _data[ix] = UtilFunctions.toInt(v);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java
index ffe790e81a..23930926a9 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java
@@ -133,6 +133,14 @@ public class DenseBlockInt64 extends DenseBlockDRB
                Arrays.fill(_data, fromIndex, toIndex, UtilFunctions.toLong(v));
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // row r covers getDim(1) consecutive longs starting at pos(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final long value = UtilFunctions.toLong(v);
+               Arrays.fill(_data, from, to, value);
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                _data[ix] = UtilFunctions.toLong(v);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java
index ab3bd98b9c..73a93d9a87 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java
@@ -150,6 +150,13 @@ public class DenseBlockLBool extends DenseBlockLDRB
                _blocks[bix].set(fromIndex, toIndex, v != 0);
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // pos(r) is the row offset inside the block selected by index(r)
+               final int from = pos(r);
+               _blocks[index(r)].set(from, from + getDim(1), v != 0);
+       }
+
        @Override
        public DenseBlock set(String s) {
                boolean b = Boolean.parseBoolean(s);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java
index 4122db2fae..b9e9d602b5 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java
@@ -113,6 +113,13 @@ public class DenseBlockLFP32 extends DenseBlockLDRB
                Arrays.fill(_blocks[bix], fromIndex, toIndex, (float)v);
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // pos(r) is the row offset inside the block selected by index(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final float[] block = _blocks[index(r)];
+               Arrays.fill(block, from, to, (float) v);
+       }
+
        @Override
        public DenseBlock set(int r, int c, double v) {
                _blocks[index(r)][pos(r, c)] = (float)v;
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java
index 1d0bc3ccfb..3a8091938b 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java
@@ -105,6 +105,13 @@ public class DenseBlockLFP64 extends DenseBlockLDRB
                Arrays.fill(_blocks[bix], fromIndex,toIndex, v);
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // pos(r) is the row offset inside the block selected by index(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final double[] block = _blocks[index(r)];
+               Arrays.fill(block, from, to, v);
+       }
+
        @Override
        public DenseBlock set(int r, int c, double v) {
                _blocks[index(r)][pos(r, c)] = v;
diff --git 
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java
index 79c02b7ac6..782c87f786 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java
@@ -178,6 +178,11 @@ public class DenseBlockLFP64DEDUP extends DenseBlockLDRB{
                throw new NotImplementedException();
        }
 
+       /**
+        * Not supported for this block type.
+        *
+        * @throws NotImplementedException always, naming the unsupported operation
+        */
+       @Override
+       public void fillRow(int r, double v) {
+               throw new NotImplementedException("fillRow is not supported by DenseBlockLFP64DEDUP");
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                throw new NotImplementedException();
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java
index 0880440e46..6c63f31d20 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java
@@ -109,6 +109,13 @@ public class DenseBlockLInt32 extends DenseBlockLDRB
                Arrays.fill(_blocks[bix], fromIndex, toIndex, 
UtilFunctions.toInt(v));
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // pos(r) is the row offset inside the block selected by index(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final int[] block = _blocks[index(r)];
+               Arrays.fill(block, from, to, UtilFunctions.toInt(v));
+       }
+
        @Override
        public DenseBlock set(int r, int c, double v) {
                _blocks[index(r)][pos(r, c)] = UtilFunctions.toInt(v);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java
index d9ffc26176..79b79e1760 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java
@@ -109,6 +109,13 @@ public class DenseBlockLInt64 extends DenseBlockLDRB
                Arrays.fill(_blocks[bix], fromIndex, toIndex, 
UtilFunctions.toLong(v));
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // pos(r) is the row offset inside the block selected by index(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final long[] block = _blocks[index(r)];
+               Arrays.fill(block, from, to, UtilFunctions.toLong(v));
+       }
+
        @Override
        public DenseBlock set(int r, int c, double v) {
                _blocks[index(r)][pos(r, c)] = UtilFunctions.toLong(v);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java
index 0ab267abec..013ae366bb 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java
@@ -109,6 +109,13 @@ public class DenseBlockLString extends DenseBlockLDRB
                Arrays.fill(_blocks[bix], fromIndex, toIndex, 
String.valueOf(v));
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // pos(r) is the row offset inside the block selected by index(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final String[] block = _blocks[index(r)];
+               Arrays.fill(block, from, to, String.valueOf(v));
+       }
+
        @Override
        public DenseBlock set(String s) {
                for (int i = 0; i < numBlocks() - 1; i++) {
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java
index 657a0f6559..396c3496be 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java
@@ -118,6 +118,14 @@ public class DenseBlockString extends DenseBlockDRB {
                Arrays.fill(_data, fromIndex, toIndex, String.valueOf(v));
        }
 
+       @Override
+       public void fillRow(int r, double v) {
+               // row r covers getDim(1) consecutive strings starting at pos(r)
+               final int from = pos(r);
+               final int to = from + getDim(1);
+               final String text = String.valueOf(v);
+               Arrays.fill(_data, from, to, text);
+       }
+
        @Override
        protected void setInternal(int bix, int ix, double v) {
                _data[ix] = String.valueOf(v);
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReplace.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReplace.java
new file mode 100644
index 0000000000..b40cea9c4b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReplace.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data;
+
+import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+
+public class LibMatrixReplace {
+
+       private LibMatrixReplace() {
+               // static utility class; not meant to be instantiated
+       }
+
+       /**
+        * Replace every cell of {@code in} that equals {@code pattern} with {@code replacement}, using
+        * {@link InfrastructureAnalyzer#getLocalParallelism()} as the degree of parallelism.
+        *
+        * @param in          The input matrix; examSparsity is invoked on it and may change its format
+        * @param ret         An output block to reuse, or null to allocate a new one
+        * @param pattern     The value to replace; a NaN pattern replaces all NaN cells
+        * @param replacement The value written into every matching cell
+        * @return The result block; this is {@code in} itself if it does not contain {@code pattern}
+        */
+       public static MatrixBlock replaceOperations(MatrixBlock in, MatrixBlock ret, double pattern, double replacement) {
+               return replaceOperations(in, ret, pattern, replacement, InfrastructureAnalyzer.getLocalParallelism());
+       }
+
+       /**
+        * Replace every cell of {@code in} that equals {@code pattern} with {@code replacement}.
+        *
+        * @param in          The input matrix; examSparsity is invoked on it and may change its format
+        * @param ret         An output block to reuse, or null to allocate a new one
+        * @param pattern     The value to replace; a NaN pattern replaces all NaN cells
+        * @param replacement The value written into every matching cell
+        * @param k           The degree of parallelism, forwarded to examSparsity
+        * @return The result block; this is {@code in} itself if it does not contain {@code pattern}
+        */
+       public static MatrixBlock replaceOperations(MatrixBlock in, MatrixBlock ret, double pattern, double replacement,
+               int k) {
+
+               // ensure the input is in the right format
+               in.examSparsity(k);
+
+               final int rlen = in.getNumRows();
+               final int clen = in.getNumColumns();
+               final long nonZeros = in.getNonZeros();
+               final boolean sparse = in.isInSparseFormat();
+
+               if(ret != null)
+                       ret.reset(rlen, clen, sparse);
+               else
+                       ret = new MatrixBlock(rlen, clen, sparse);
+
+               // probe early abort conditions
+               if(nonZeros == 0 && pattern != 0)
+                       return ret; // all-zero input cannot contain a non-zero (or NaN) pattern
+               if(!in.containsValue(pattern))
+                       return in; // avoid allocation + copy
+               if(in.isEmpty() && pattern == 0) {
+                       ret.reset(rlen, clen, replacement);
+                       return ret;
+               }
+
+               final boolean replaceNaN = Double.isNaN(pattern);
+
+               final long nnz;
+               if(sparse)
+                       nnz = replaceSparse(in, ret, pattern, replacement, replaceNaN);
+               else if(replaceNaN)
+                       nnz = replaceDenseNaN(in, ret, replacement);
+               else
+                       nnz = replaceDense(in, ret, pattern, replacement);
+
+               ret.setNonZeros(nnz);
+               ret.examSparsity(k);
+               return ret;
+       }
+
+       /**
+        * Dispatch for sparse inputs. NaN and non-zero patterns never match the absent (zero) cells, so the output
+        * stays sparse; a zero pattern matches every absent cell and is written to a dense output.
+        */
+       private static long replaceSparse(MatrixBlock in, MatrixBlock ret, double pattern, double replacement,
+               boolean replaceNaN) {
+               if(replaceNaN)
+                       return replaceSparseInSparseOutReplaceNaN(in, ret, replacement);
+               else if(pattern != 0d) // sparse safe
+                       return replaceSparseInSparseOut(in, ret, pattern, replacement);
+               else // sparse unsafe
+                       return replace0InSparse(in, ret, replacement);
+       }
+
+       /** Sparse to sparse replacement of all NaN values; returns the number of non-zeros written. */
+       private static long replaceSparseInSparseOutReplaceNaN(MatrixBlock in, MatrixBlock ret, double replacement) {
+               ret.allocateSparseRowsBlock();
+               final SparseBlock a = in.sparseBlock;
+               final SparseBlock c = ret.sparseBlock;
+               long nnz = 0;
+               for(int i = 0; i < in.rlen; i++) {
+                       if(!a.isEmpty(i)) {
+                               final int apos = a.pos(i);
+                               final int alen = a.size(i);
+                               c.allocate(i, alen);
+                               final int[] aix = a.indexes(i);
+                               final double[] avals = a.values(i);
+                               for(int j = apos; j < apos + alen; j++) {
+                                       final double val = avals[j];
+                                       if(Double.isNaN(val))
+                                               c.append(i, aix[j], replacement);
+                                       else
+                                               c.append(i, aix[j], val);
+                               }
+                               // drop zeros introduced when replacement == 0
+                               c.compact(i);
+                               nnz += c.size(i);
+                       }
+               }
+               return nnz;
+       }
+
+       /** Sparse to sparse replacement of a non-zero, non-NaN pattern over all rows. */
+       private static long replaceSparseInSparseOut(MatrixBlock in, MatrixBlock ret, double pattern, double replacement) {
+               ret.allocateSparseRowsBlock();
+               final SparseBlock a = in.sparseBlock;
+               final SparseBlock c = ret.sparseBlock;
+
+               return replaceSparseInSparseOut(a, c, pattern, replacement, 0, in.rlen);
+       }
+
+       /**
+        * Sparse to sparse replacement over the row range [s, e).
+        *
+        * @return The number of non-zeros written into rows [s, e) of c
+        */
+       private static long replaceSparseInSparseOut(SparseBlock a, SparseBlock c, double pattern, double replacement,
+               int s, int e) {
+               long nnz = 0;
+               for(int i = s; i < e; i++) {
+                       if(!a.isEmpty(i)) {
+                               final int apos = a.pos(i);
+                               final int alen = a.size(i);
+                               final int[] aix = a.indexes(i);
+                               final double[] avals = a.values(i);
+                               c.allocate(i, alen);
+                               for(int j = apos; j < apos + alen; j++) {
+                                       final double val = avals[j];
+                                       if(val == pattern)
+                                               c.append(i, aix[j], replacement);
+                                       else
+                                               c.append(i, aix[j], val);
+                               }
+                               // drop zeros introduced when replacement == 0
+                               c.compact(i);
+                               nnz += c.size(i);
+                       }
+               }
+               return nnz;
+       }
+
+       /**
+        * Replace zeros of a sparse input, writing into a dense output. Every row is first filled with the
+        * replacement and then overwritten with the stored non-zero values (scatter).
+        *
+        * @return The exact number of non-zeros in the output
+        */
+       private static long replace0InSparse(MatrixBlock in, MatrixBlock ret, double replacement) {
+               ret.sparse = false;
+               ret.allocateDenseBlock();
+               final SparseBlock a = in.sparseBlock;
+               final DenseBlock c = ret.getDenseBlock();
+               final int clen = in.clen;
+               final boolean nonZeroReplacement = replacement != 0;
+
+               long nnz = 0;
+               for(int i = 0; i < in.rlen; i++) {
+                       // fill unconditionally, so a missing sparse block (a == null) still yields the replacement
+                       c.fillRow(i, replacement);
+                       int written = 0;
+                       if(a != null && !a.isEmpty(i)) {
+                               final int apos = a.pos(i);
+                               final int cpos = c.pos(i);
+                               final int alen = a.size(i);
+                               final int[] aix = a.indexes(i);
+                               final double[] avals = a.values(i);
+                               final double[] cvals = c.values(i);
+                               for(int j = apos; j < apos + alen; j++) {
+                                       if(avals[j] != 0) {
+                                               cvals[cpos + aix[j]] = avals[j];
+                                               written++;
+                                       }
+                               }
+                       }
+                       // all other cells hold the replacement, which only counts if it is non-zero
+                       nnz += nonZeroReplacement ? clen : written;
+               }
+               return nnz;
+       }
+
+       /** Dense to dense replacement of a non-NaN pattern; returns the number of non-zeros written. */
+       private static long replaceDense(MatrixBlock in, MatrixBlock ret, double pattern, double replacement) {
+               final DenseBlock a = in.getDenseBlock();
+               final DenseBlock c = ret.allocateDenseBlock().getDenseBlock();
+               long nnz = 0;
+               for(int bi = 0; bi < a.numBlocks(); bi++) {
+                       final int len = a.size(bi);
+                       final double[] avals = a.valuesAt(bi);
+                       final double[] cvals = c.valuesAt(bi);
+                       for(int i = 0; i < len; i++) {
+                               cvals[i] = avals[i] == pattern ? replacement : avals[i];
+                               nnz += cvals[i] != 0 ? 1 : 0;
+                       }
+               }
+               return nnz;
+       }
+
+       /** Dense to dense replacement of all NaN values; returns the number of non-zeros written. */
+       private static long replaceDenseNaN(MatrixBlock in, MatrixBlock ret, double replacement) {
+               final DenseBlock a = in.getDenseBlock();
+               final DenseBlock c = ret.allocateDenseBlock().getDenseBlock();
+               long nnz = 0;
+               for(int bi = 0; bi < a.numBlocks(); bi++) {
+                       final int len = a.size(bi);
+                       final double[] avals = a.valuesAt(bi);
+                       final double[] cvals = c.valuesAt(bi);
+                       for(int i = 0; i < len; i++) {
+                               cvals[i] = Double.isNaN(avals[i]) ? replacement : avals[i];
+                               nnz += cvals[i] != 0 ? 1 : 0;
+                       }
+               }
+               return nnz;
+       }
+
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 668906ee2e..054edf06a2 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5207,93 +5207,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        @Override
        public MatrixBlock replaceOperations(MatrixValue result, double 
pattern, double replacement) {
                MatrixBlock ret = checkType(result);
-               examSparsity(); //ensure its in the right format
-               if(ret != null)
-                       ret.reset(rlen, clen, sparse);
-               else
-                       ret = new MatrixBlock(rlen, clen, sparse);
-               
-               //probe early abort conditions
-               if( nonZeros == 0 && pattern != 0  )
-                       return ret;
-               if( !containsValue(pattern) )
-                       return this; //avoid allocation + copy
-               if( isEmpty() && pattern==0 ) {
-                       ret.reset(rlen, clen, replacement);
-                       return ret;
-               }
-
-               boolean NaNpattern = Double.isNaN(pattern);
-               if( sparse ) //SPARSE
-               {
-                       if( pattern != 0d ) //SPARSE <- SPARSE (sparse-safe)
-                       {
-                               ret.allocateSparseRowsBlock();
-                               SparseBlock a = sparseBlock;
-                               SparseBlock c = ret.sparseBlock;
-                               
-                               for( int i=0; i<rlen; i++ ) {
-                                       if( !a.isEmpty(i) ) {
-                                               c.allocate(i);
-                                               int apos = a.pos(i);
-                                               int alen = a.size(i);
-                                               int[] aix = a.indexes(i);
-                                               double[] avals = a.values(i);
-                                               for( int j=apos; j<apos+alen; 
j++ ) {
-                                                       double val = avals[j];
-                                                       if( val== pattern || 
(NaNpattern && Double.isNaN(val)) )
-                                                               c.append(i, 
aix[j], replacement);
-                                                       else
-                                                               c.append(i, 
aix[j], val);
-                                               }
-                                       }
-                               }
-                       }
-                       else //DENSE <- SPARSE
-                       {
-                               ret.sparse = false;
-                               ret.allocateDenseBlock();
-                               SparseBlock a = sparseBlock;
-                               DenseBlock c = ret.getDenseBlock();
-                               
-                               //initialize with replacement (since all 0 
values, see SPARSITY_TURN_POINT)
-                               c.reset(rlen, clen, replacement);
-                               
-                               //overwrite with existing values (via scatter)
-                               if( a != null  ) //check for empty matrix
-                                       for( int i=0; i<rlen; i++ ) {
-                                               if( !a.isEmpty(i) ) {
-                                                       int apos = a.pos(i);
-                                                       int cpos = c.pos(i);
-                                                       int alen = a.size(i);
-                                                       int[] aix = 
a.indexes(i);
-                                                       double[] avals = 
a.values(i);
-                                                       double[] cvals = 
c.values(i);
-                                                       for( int j=apos; 
j<apos+alen; j++ )
-                                                               if( avals[ j ] 
!= 0 )
-                                                                       cvals[ 
cpos+aix[j] ] = avals[ j ];
-                                               }
-                                       }
-                       }
-               }
-               else { //DENSE <- DENSE
-                       DenseBlock a = getDenseBlock();
-                       DenseBlock c = ret.allocateDenseBlock().getDenseBlock();
-                       for( int bi=0; bi<a.numBlocks(); bi++ ) {
-                               int len = a.size(bi);
-                               double[] avals = a.valuesAt(bi);
-                               double[] cvals = c.valuesAt(bi);
-                               for( int i=0; i<len; i++ ) {
-                                       cvals[i] = (avals[i]== pattern 
-                                               || (NaNpattern && 
Double.isNaN(avals[i]))) ?
-                                               replacement : avals[i];
-                                       }
-                               }
-                       }
-               
-               ret.recomputeNonZeros();
-               ret.examSparsity();
-               return ret;
+               return LibMatrixReplace.replaceOperations(this, ret, pattern, 
replacement); // delegates with the local parallelism as k; returns this itself (no copy) if pattern is not contained
        }
        
        public MatrixBlock extractTriangular(MatrixBlock ret, boolean lower, 
boolean diag, boolean values) {
diff --git a/src/test/java/org/apache/sysds/performance/Main.java 
b/src/test/java/org/apache/sysds/performance/Main.java
index 9959192188..a281dd2cf0 100644
--- a/src/test/java/org/apache/sysds/performance/Main.java
+++ b/src/test/java/org/apache/sysds/performance/Main.java
@@ -31,6 +31,7 @@ import org.apache.sysds.performance.generators.GenMatrices;
 import org.apache.sysds.performance.generators.IGenerate;
 import org.apache.sysds.performance.generators.MatrixFile;
 import org.apache.sysds.performance.matrix.MatrixMulPerformance;
+import org.apache.sysds.performance.matrix.MatrixReplacePerf;
 import org.apache.sysds.performance.matrix.MatrixStorage;
 import org.apache.sysds.performance.matrix.SparseAppend;
 import org.apache.sysds.runtime.data.SparseBlock;
@@ -104,6 +105,9 @@ public class Main {
                        case 16:
                                run16(args);
                                break;
+                       case 17: 
+                               run17(args);
+                               break;
                        case 1000:
                                run1000(args);
                                break;
@@ -204,6 +208,16 @@ public class Main {
                System.out.println(mb);
        }
 
+       private static void run17(String[] args) throws Exception {
+               int rows = Integer.parseInt(args[1]);
+               int cols = Integer.parseInt(args[2]);
+               double spar = Double.parseDouble(args[3]);
+               int k = Integer.parseInt(args[4]);
+               MatrixBlock mb = 
TestUtils.ceil(TestUtils.generateTestMatrixBlock(rows, cols, 0, 100, spar, rows 
+ 1));
+               IGenerate<MatrixBlock> g = new ConstMatrix(mb);
+               new MatrixReplacePerf(100, g, k).run();
+       }
+
        private static void run1000(String[] args) {
                MatrixMulPerformance perf;
                if (args.length < 3) {
diff --git 
a/src/test/java/org/apache/sysds/performance/matrix/MatrixReplacePerf.java 
b/src/test/java/org/apache/sysds/performance/matrix/MatrixReplacePerf.java
new file mode 100644
index 0000000000..17de8d53fe
--- /dev/null
+++ b/src/test/java/org/apache/sysds/performance/matrix/MatrixReplacePerf.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.performance.matrix;
+
+import org.apache.sysds.performance.compression.APerfTest;
+import org.apache.sysds.performance.generators.IGenerate;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReplace;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class MatrixReplacePerf extends APerfTest<Object, MatrixBlock> {
+
+       /** Degree of parallelism passed to every replace call. */
+       private final int k;
+
+       public MatrixReplacePerf(int N, IGenerate<MatrixBlock> gen, int k) {
+               super(N, gen);
+               this.k = k;
+       }
+
+       /**
+        * Warm up with the zero-replacement task, then time replacing zeros, ones and NaNs.
+        *
+        * @throws Exception if warmup or execute fail
+        */
+       public void run() throws Exception {
+               warmup(() -> replaceZeroTask(k), 10);
+               execute(() -> replaceZeroTask(k), "replaceZero");
+               execute(() -> replaceOneTask(k), "replaceOne");
+               execute(() -> replaceNaNTask(k), "replaceNaN");
+       }
+
+       /** Replace 0 with 1. */
+       private void replaceZeroTask(int par) {
+               runReplace(0, 1, par);
+       }
+
+       /** Replace 1 with 2. */
+       private void replaceOneTask(int par) {
+               runReplace(1, 2, par);
+       }
+
+       /** Replace NaN with 2. */
+       private void replaceNaNTask(int par) {
+               runReplace(Double.NaN, 2, par);
+       }
+
+       /** Take the next input from the generator, run one replace, and record a placeholder result. */
+       private void runReplace(double pattern, double replacement, int par) {
+               final MatrixBlock input = gen.take();
+               LibMatrixReplace.replaceOperations(input, null, pattern, replacement, par);
+               ret.add(null);
+       }
+
+       @Override
+       protected String makeResString() {
+               return "";
+       }
+
+}

Reply via email to