This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 12367cb9f4 [SYSTEMDS-3828] Parallel Compressed Replace
12367cb9f4 is described below

commit 12367cb9f4ba54d174779945b739a6af2a4968da
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Feb 3 15:03:07 2025 +0100

    [SYSTEMDS-3828] Parallel Compressed Replace
    
    This commit adds the parallel kernel for compressed
    replace of values.
    
    Closes #2209
---
 .../runtime/compress/CompressedMatrixBlock.java    |  52 +++-------
 .../sysds/runtime/compress/lib/CLALibReplace.java  | 108 +++++++++++++++++++++
 .../cp/ParameterizedBuiltinCPInstruction.java      |   4 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |   8 +-
 .../component/compress/CompressedCustomTests.java  |  15 ++-
 .../component/compress/CompressedMatrixTest.java   |  32 ------
 .../component/compress/CompressedTestBase.java     |  40 +++++++-
 7 files changed, 185 insertions(+), 74 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index a05c076b36..bee86addf2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -58,6 +58,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
 import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
 import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
 import org.apache.sysds.runtime.compress.lib.CLALibMerge;
+import org.apache.sysds.runtime.compress.lib.CLALibReplace;
 import org.apache.sysds.runtime.compress.lib.CLALibReshape;
 import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
 import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -307,7 +308,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
         * @return The cached decompressed matrix, if it does not exist return 
null
         */
        public MatrixBlock getCachedDecompressed() {
-               if( allowCachingUncompressed && decompressedVersion != null) {
+               if(allowCachingUncompressed && decompressedVersion != null) {
                        final MatrixBlock mb = decompressedVersion.get();
                        if(mb != null) {
                                
DMLCompressionStatistics.addDecompressCacheCount();
@@ -401,8 +402,8 @@ public class CompressedMatrixBlock extends MatrixBlock {
                        long total = baseSizeInMemory();
                        // take into consideration duplicate dictionaries
                        Set<IDictionary> dicts = new HashSet<>();
-                       for(AColGroup grp : _colGroups){
-                               if(grp instanceof ADictBasedColGroup){
+                       for(AColGroup grp : _colGroups) {
+                               if(grp instanceof ADictBasedColGroup) {
                                        IDictionary dg = ((ADictBasedColGroup) 
grp).getDictionary();
                                        if(dicts.contains(dg))
                                                total -= dg.getInMemorySize();
@@ -576,8 +577,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
        }
 
        @Override
-       public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock 
w, MatrixBlock out, ChainType ctype,
-               int k) {
+       public MatrixBlock chainMatrixMultOperations(MatrixBlock v, MatrixBlock 
w, MatrixBlock out, ChainType ctype, int k) {
 
                checkMMChain(ctype, v, w);
                // multi-threaded MMChain of single uncompressed ColGroup
@@ -629,27 +629,8 @@ public class CompressedMatrixBlock extends MatrixBlock {
        }
 
        @Override
-       public MatrixBlock replaceOperations(MatrixValue result, double 
pattern, double replacement) {
-               if(Double.isInfinite(pattern)) {
-                       LOG.info("Ignoring replace infinite in compression 
since it does not contain this value");
-                       return this;
-               }
-               else if(isOverlapping()) {
-                       final String message = "replaceOperations " + pattern + 
" -> " + replacement;
-                       return 
getUncompressed(message).replaceOperations(result, pattern, replacement);
-               }
-               else {
-
-                       CompressedMatrixBlock ret = new 
CompressedMatrixBlock(getNumRows(), getNumColumns());
-                       final List<AColGroup> prev = getColGroups();
-                       final int colGroupsLength = prev.size();
-                       final List<AColGroup> retList = new 
ArrayList<>(colGroupsLength);
-                       for(int i = 0; i < colGroupsLength; i++)
-                               retList.add(prev.get(i).replace(pattern, 
replacement));
-                       ret.allocateColGroupList(retList);
-                       ret.recomputeNonZeros();
-                       return ret;
-               }
+       public MatrixBlock replaceOperations(MatrixValue result, double 
pattern, double replacement, int k) {
+               return CLALibReplace.replace(this, (MatrixBlock) result, 
pattern, replacement, k);
        }
 
        @Override
@@ -710,10 +691,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
                        return false;
                }
        }
-       
+
        @Override
        public boolean containsValue(double pattern, int k) {
-               //TODO parallel contains value
+               // TODO parallel contains value
                return containsValue(pattern);
        }
 
@@ -775,8 +756,8 @@ public class CompressedMatrixBlock extends MatrixBlock {
                        return false;
                else if(_colGroups == null || nonZeros == 0)
                        return true;
-               else{
-                       if(nonZeros == -1){
+               else {
+                       if(nonZeros == -1) {
                                // try to use column groups
                                for(AColGroup g : _colGroups)
                                        if(!g.isEmpty())
@@ -1177,8 +1158,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
        }
 
        @Override
-       public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, 
int rowoffset, int coloffset,
-               boolean deep) {
+       public void appendRowToSparse(SparseBlock dest, MatrixBlock src, int i, 
int rowoffset, int coloffset, boolean deep) {
                throw new DMLCompressionException("Can't append row to 
compressed Matrix");
        }
 
@@ -1238,7 +1218,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
        }
 
        @Override
-       public void denseToSparse(boolean allowCSR, int k){
+       public void denseToSparse(boolean allowCSR, int k) {
                // do nothing
        }
 
@@ -1327,13 +1307,13 @@ public class CompressedMatrixBlock extends MatrixBlock {
                throw new DMLCompressionException("Invalid to allocate block on 
a compressed MatrixBlock");
        }
 
-       @Override 
+       @Override
        public MatrixBlock transpose(int k) {
                return getUncompressed().transpose(k);
        }
 
-       @Override 
-       public MatrixBlock reshape(int rows,int cols, boolean byRow){
+       @Override
+       public MatrixBlock reshape(int rows, int cols, boolean byRow) {
                return CLALibReshape.reshape(this, rows, cols, byRow);
        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java
new file mode 100644
index 0000000000..d86026d663
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReplace.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+public class CLALibReplace {
+       private static final Log LOG = 
LogFactory.getLog(CLALibReplace.class.getName());
+
+       private CLALibReplace(){
+               // private constructor
+       }
+
+       public static MatrixBlock replace(CompressedMatrixBlock in, MatrixBlock 
out, double pattern, double replacement,
+               int k) {
+               try {
+
+                       if(Double.isInfinite(pattern)) {
+                               LOG.info("Ignoring replace infinite in 
compression since it does not contain this value");
+                               return in;
+                       }
+                       else if(in.isOverlapping()) {
+                               final String message = "replaceOperations " + 
pattern + " -> " + replacement;
+                               return 
in.getUncompressed(message).replaceOperations(out, pattern, replacement);
+                       }
+                       else
+                               return replaceNormal(in, out, pattern, 
replacement, k);
+               }
+               catch(Exception e) {
+                       throw new RuntimeException("Failed replace pattern: " + 
pattern + " replacement: " + replacement, e);
+               }
+       }
+
+       private static MatrixBlock replaceNormal(CompressedMatrixBlock in, 
MatrixBlock out, double pattern,
+               double replacement, int k) throws Exception {
+               CompressedMatrixBlock ret = new 
CompressedMatrixBlock(in.getNumRows(), in.getNumColumns());
+               final List<AColGroup> prev = in.getColGroups();
+               final int colGroupsLength = prev.size();
+               final List<AColGroup> retList = new 
ArrayList<>(colGroupsLength);
+
+               if(k <= 1)
+                       replaceSingleThread(pattern, replacement, prev, 
colGroupsLength, retList);
+               else
+                       replaceMultiThread(pattern, replacement, k, prev, 
colGroupsLength, retList);
+
+               ret.allocateColGroupList(retList);
+               if(replacement == 0) // have to recompute!
+                       ret.recomputeNonZeros();
+               else if(pattern == 0) // always fully dense.
+                       ret.setNonZeros(((long) in.getNumRows()) * 
in.getNumColumns());
+               else // same nonzeros as input
+                       ret.setNonZeros(in.getNonZeros());
+               return ret;
+       }
+
+       private static void replaceMultiThread(double pattern, double 
replacement, int k, final List<AColGroup> prev,
+               final int colGroupsLength, final List<AColGroup> retList) 
throws InterruptedException, ExecutionException {
+               ExecutorService pool = CommonThreadPool.get(k);
+
+               try {
+                       List<Future<AColGroup>> tasks = new 
ArrayList<>(colGroupsLength);
+                       for(int i = 0; i < colGroupsLength; i++) {
+                               final int j = i;
+                               tasks.add(pool.submit(() -> 
prev.get(j).replace(pattern, replacement)));
+                       }
+                       for(int i = 0; i < colGroupsLength; i++) {
+                               retList.add(tasks.get(i).get());
+                       }
+               }
+               finally {
+                       pool.shutdown();
+               }
+       }
+
+       private static void replaceSingleThread(double pattern, double 
replacement, final List<AColGroup> prev,
+               final int colGroupsLength, final List<AColGroup> retList) {
+               for(int i = 0; i < colGroupsLength; i++)
+                       retList.add(prev.get(i).replace(pattern, replacement));
+       }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 2fb64b170d..119589a303 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -66,6 +66,7 @@ import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
 import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
 import org.apache.sysds.runtime.util.AutoDiff;
 import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
 
 public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction {
        private static final Log LOG = 
LogFactory.getLog(ParameterizedBuiltinCPInstruction.class.getName());
@@ -276,7 +277,8 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
                                MatrixBlock target = targetObj.acquireRead();
                                double pattern = 
Double.parseDouble(params.get("pattern"));
                                double replacement = 
Double.parseDouble(params.get("replacement"));
-                               MatrixBlock ret = target.replaceOperations(new 
MatrixBlock(), pattern, replacement);
+                               MatrixBlock ret = target.replaceOperations(new 
MatrixBlock(), pattern, replacement, 
+                                       
InfrastructureAnalyzer.getLocalParallelism());
                                if( ret == target ) //shallow copy (avoid 
bufferpool pollution)
                                        ec.setVariable(output.getName(), 
targetObj);
                                else
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index c9086778f0..057811d2db 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5157,9 +5157,13 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
        
        
        @Override
-       public MatrixBlock replaceOperations(MatrixValue result, double 
pattern, double replacement) {
+       public final MatrixBlock replaceOperations(MatrixValue result, double 
pattern, double replacement) {
+               return replaceOperations(result, pattern, replacement, 1);
+       }
+
+       public MatrixBlock replaceOperations(MatrixValue result, double 
pattern, double replacement, int k) {
                MatrixBlock ret = checkType(result);
-               return LibMatrixReplace.replaceOperations(this, ret, pattern, 
replacement);
+               return LibMatrixReplace.replaceOperations(this, ret, pattern, 
replacement, k);
        }
        
        public MatrixBlock extractTriangular(MatrixBlock ret, boolean lower, 
boolean diag, boolean values) {
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
index 32d62fb16c..886198bb22 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.test.component.compress;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -38,6 +39,7 @@ import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
 import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
 import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
 import org.apache.sysds.runtime.compress.lib.CLALibCBind;
+import org.apache.sysds.runtime.compress.lib.CLALibReplace;
 import org.apache.sysds.runtime.compress.workload.WTreeRoot;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.test.TestUtils;
@@ -397,9 +399,18 @@ public class CompressedCustomTests {
                TestUtils.compareMatricesBitAvgDistance(m1, m2, 0, 0, "no");
        }
 
+       @Test(expected = Exception.class)
+       public void cbindWithError() {
+               CLALibCBind.cbind(null, new MatrixBlock[] {null}, 0);
+       }
 
        @Test(expected = Exception.class)
-       public void cbindWithError(){
-               CLALibCBind.cbind(null, new MatrixBlock[]{null}, 0);
+       public void replaceWithError() {
+               CLALibReplace.replace(null, null, 0, 0, 10);
+       }
+
+       @Test
+       public void replaceInf() {
+               assertNull(CLALibReplace.replace(null, null, 
Double.POSITIVE_INFINITY, 0, 10));
        }
 }
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
index 5de4967517..d36c6167cf 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java
@@ -329,38 +329,6 @@ public class CompressedMatrixTest extends AbstractCompressedUnaryTests {
                }
        }
 
-       @Test
-       public void testReplaceNotContainedValue() {
-               double v = min - 1;
-               if(v != 0)
-                       testReplace(v);
-       }
-
-       @Test
-       public void testReplace() {
-               if(min != 0)
-                       testReplace(min);
-       }
-
-       @Test
-       public void testReplaceZero() {
-               testReplace(0);
-       }
-
-       private void testReplace(double value) {
-               try {
-                       if(!(cmb instanceof CompressedMatrixBlock) || rows * 
cols > 10000)
-                               return;
-                       ucRet = mb.replaceOperations(ucRet, value, 1425);
-                       MatrixBlock ret2 = cmb.replaceOperations(new 
MatrixBlock(), value, 1425);
-                       compareResultMatrices(ucRet, ret2, 1);
-               }
-               catch(Exception e) {
-                       e.printStackTrace();
-                       throw new DMLRuntimeException(e);
-               }
-       }
-
        @Test
        public void testCompressedMatrixConstruction() {
                try {
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
index 507a2fc663..8692f56b69 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
@@ -1173,7 +1173,7 @@ public abstract class CompressedTestBase extends TestBase {
                }
                catch(AssertionError e) {
                        e.printStackTrace();
-                       fail("failed Cbind: " + cmb.toString() );
+                       fail("failed Cbind: " + cmb.toString());
                }
        }
 
@@ -1299,4 +1299,42 @@ public abstract class CompressedTestBase extends TestBase {
                return new 
CompressionSettingsBuilder().setSeed(compressionSeed).setMinimumSampleSize(100);
        }
 
+       @Test
+       public void testReplaceNotContainedValue() {
+               double v = min - 1;
+               if(v != 0)
+                       testReplace(v, 132);
+       }
+
+       @Test
+       public void testReplace() {
+               if(min != 0)
+                       testReplace(min, 323);
+       }
+
+       @Test
+       public void testReplaceWithZero() {
+               if(min != 0)
+                       testReplace(min, 0);
+       }
+
+       @Test
+       public void testReplaceZero() {
+               testReplace(0, 3232);
+       }
+
+       private void testReplace(double value, double replacements) {
+               try {
+                       if(!(cmb instanceof CompressedMatrixBlock) || rows * 
cols > 10000)
+                               return;
+                       ucRet = mb.replaceOperations(ucRet, value, 
replacements, _k);
+                       MatrixBlock ret2 = cmb.replaceOperations(new 
MatrixBlock(), value, replacements, _k);
+                       compareResultMatrices(ucRet, ret2, 1);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       throw new DMLRuntimeException(e);
+               }
+       }
+
 }

Reply via email to