This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit c42a629d1d9d3896fa22ec9793bb6e21df3a76b3
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Dec 29 22:06:42 2024 +0100

    [MINOR] Fused decompression in CLALibScalar
    
    Closes #2169
---
 .../sysds/runtime/compress/lib/CLALibScalar.java   | 84 +++++++++++++++++++++-
 1 file changed, 81 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
index a8b1a8ad22..5588a538aa 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
@@ -31,12 +31,14 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.functionobjects.Divide;
 import org.apache.sysds.runtime.functionobjects.Minus;
 import org.apache.sysds.runtime.functionobjects.Multiply;
@@ -57,10 +59,13 @@ public final class CLALibScalar {
        }
 
        public static MatrixBlock scalarOperations(ScalarOperator sop, 
CompressedMatrixBlock m1, MatrixValue result) {
+               // Timing time = new Timing(true);
                if(isInvalidForCompressedOutput(m1, sop)) {
                        LOG.warn("scalar overlapping not supported for op: " + 
sop.fn.getClass().getSimpleName());
-                       MatrixBlock m1d = m1.decompress(sop.getNumThreads());
-                       return m1d.scalarOperations(sop, result);
+
+                       return fusedScalarAndDecompress(m1, sop);
+                       // MatrixBlock m1d = m1.decompress(sop.getNumThreads());
+                       // return m1d.scalarOperations(sop, result);
                }
                CompressedMatrixBlock ret = setupRet(m1, result);
 
@@ -89,11 +94,84 @@ public final class CLALibScalar {
                        ret.setOverlapping(m1.isOverlapping());
                }
 
-               ret.recomputeNonZeros();
+               if(sop.fn instanceof Divide) {
+                       ret.setNonZeros(m1.getNonZeros());
+               }
+               else {
+                       ret.recomputeNonZeros();
+               }
 
+               // System.out.println("CLA Scalar: " + sop + " " + 
m1.getNumRows() + ", " + m1.getNumColumns() + ", " +
+               // m1.getColGroups().size()
+               // + " -- " + "\t\t" + time.stop());
                return ret;
        }
 
+       private static MatrixBlock 
fusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) { // Builds a new MatrixBlock from 'in' by decompressing and applying 'sop' per cell, using row-range tasks on a thread pool.
+               int k = sop.getNumThreads(); // thread count comes from the operator
+               ExecutorService pool = CommonThreadPool.get(k);
+               try {
+                       final int nRow  = in.getNumRows();
+                       final int nCol = in.getNumColumns();
+                       final MatrixBlock out = new MatrixBlock(nRow, nCol, 
false); // NOTE(review): third argument presumably requests a non-sparse block - confirm against MatrixBlock
+                       final List<AColGroup> groups = in.getColGroups();
+                       out.allocateDenseBlock();
+                       final DenseBlock db = out.getDenseBlock(); // all tasks write into this one shared block, each on its own row range
+                       final int blkz = Math.max((int)(Math.ceil((double)nRow 
/ k)), 256); // rows per task: ceil(nRow / k), but never fewer than 256
+                       final List<Future<Long>> tasks = new ArrayList<>();
+                       for(int i = 0; i < nRow; i += blkz) {
+                               final int start = i;
+                               final int end = Math.min(i + blkz, nRow); // clamp the last range to nRow
+                               tasks.add(pool.submit(() -> 
fusedDecompressAndScalar(groups, nCol, start, end, db, sop))); // task covers rows [start, end) and returns its non-zero count
+                       }
+                       long nnz = 0;
+                       for(Future<Long> t : tasks) {
+                               nnz += t.get(); // waits for the task and adds its non-zero count
+                       }
+                       out.setNonZeros(nnz);
+                       out.examSparsity(true, k); // NOTE(review): presumably re-evaluates dense vs. sparse layout for the result - confirm
+                       return out;
+               }
+               catch(Exception e) { // any failure in setup or in a task (surfaced by t.get()) is wrapped and rethrown below
+                       throw new DMLCompressionException("failed fused scalar 
operation", e);
+               }
+               finally {
+                       pool.shutdown(); // executed on both the success and the failure path
+               }
+               // Earlier decompress-then-apply variant, left commented out (m1 and result are not in scope in this method):
+               // MatrixBlock m1d = m1.decompress(sop.getNumThreads());
+               // return m1d.scalarOperations(sop, result);
+       }
+
+       private static long fusedDecompressAndScalar(final List<AColGroup> 
groups, int nCol, int start, int end,
+               DenseBlock db, ScalarOperator sop) { // Handles rows [start, end) in sub-blocks of at most 32 rows; returns the summed non-zero count.
+               long nnz = 0;
+               for(int b = start; b < end; b += 32) { // advance 32 rows per iteration
+                       int bs = b;
+                       int be = Math.min(b + 32, end); // final sub-block may hold fewer than 32 rows
+                       nnz += fusedDecompressAndScalarBlock(groups, nCol, bs, 
be, db, sop); // add the non-zero count reported for rows [bs, be)
+               }
+               return nnz;
+       }
+
+       private static long fusedDecompressAndScalarBlock(final List<AColGroup> 
groups, int nCol, int bs, int be,
+               DenseBlock db, ScalarOperator sop) { // For rows [bs, be): decompress every column group into db, then apply sop in place and count non-zero results.
+               long nnz = 0;
+               for(AColGroup g : groups) {
+                       // main block to optimize is decompression speed since 
it is most likely an overlapping input
+                       g.decompressToDenseBlock(db, bs, be); // NOTE(review): presumably adds this group's values for rows [bs, be) into db, so overlapping groups sum up - confirm
+               }
+               for(int r = bs; r < be; r++) {
+                       double[] vals = db.values(r); // array that holds row r
+                       int off = db.pos(r); // index in vals at which row r begins; the loop below covers nCol cells from here
+                       for(int c = off; c < nCol + off; c++) {
+                               vals[c] = sop.executeScalar(vals[c]); // overwrite the cell with the operator result
+                               nnz += vals[c] == 0 ? 0 : 1; // counts every result that is not equal to 0 (NaN is counted, -0.0 is not)
+                       }
+               }
+               return nnz;
+       }
+
        private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, 
MatrixValue result) {
                CompressedMatrixBlock ret;
                if(result == null || !(result instanceof CompressedMatrixBlock))

Reply via email to