This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit a98560cb306012771dce215bda150a89dd9bf482
Author: baunsgaard <[email protected]>
AuthorDate: Sun Jan 16 14:02:08 2022 +0100

    [SYSTEMDS-3243] Compressed Matrix Multiplication part
    
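    This commit follows the previous one by modifying the compression tests
    and the compression path for matrix multiplication to fit the design of
    the normal MatrixBlock. The compressed matrix-multiplication logic is
    moved out of CompressedMatrixBlock into the new CLALibMatrixMult class.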
    
    Closes #1480
---
 .../runtime/compress/CompressedMatrixBlock.java    | 105 +----------------
 .../runtime/compress/lib/CLALibMatrixMult.java     | 128 +++++++++++++++++++++
 .../component/compress/CompressedTestBase.java     |   3 +-
 .../test/component/estim/OpBindChainTest.java      |   4 +-
 .../test/component/estim/OpElemWChainTest.java     |   4 +-
 .../component/estim/SquaredProductChainTest.java   |   2 +-
 6 files changed, 140 insertions(+), 106 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 85cc23b..ec09226 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -51,8 +51,8 @@ import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
 import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
 import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
 import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
+import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
 import org.apache.sysds.runtime.compress.lib.CLALibReExpand;
-import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
 import org.apache.sysds.runtime.compress.lib.CLALibScalar;
 import org.apache.sysds.runtime.compress.lib.CLALibSlice;
 import org.apache.sysds.runtime.compress.lib.CLALibSquash;
@@ -61,13 +61,11 @@ import org.apache.sysds.runtime.compress.lib.CLALibUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseRow;
 import org.apache.sysds.runtime.functionobjects.MinusMultiply;
 import org.apache.sysds.runtime.functionobjects.PlusMultiply;
-import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import 
org.apache.sysds.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
@@ -76,7 +74,6 @@ import 
org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.matrix.data.CTableMap;
 import org.apache.sysds.runtime.matrix.data.IJV;
 import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
-import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.LibMatrixTercell;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
@@ -471,105 +468,15 @@ public class CompressedMatrixBlock extends MatrixBlock {
        }
 
        @Override
-       public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret,
-               AggregateBinaryOperator op) {
-               // create output matrix block
-               return aggregateBinaryOperations(m1, m2, ret, op, false, false);
+       public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+               checkAggregateBinaryOperations(m1, m2, op);
+               return CLALibMatrixMult.matrixMultiply(m1, m2, ret, 
op.getNumThreads(), false, false);
        }
 
        public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret,
                AggregateBinaryOperator op, boolean transposeLeft, boolean 
transposeRight) {
-               validateMatrixMult(m1, m2);
-               final int k = op.getNumThreads();
-               final Timing time = LOG.isTraceEnabled() ? new Timing(true) : 
null;
-
-               if(m1 instanceof CompressedMatrixBlock && m2 instanceof 
CompressedMatrixBlock) {
-                       return 
doubleCompressedAggregateBinaryOperations((CompressedMatrixBlock) m1, 
(CompressedMatrixBlock) m2, ret,
-                               op, transposeLeft, transposeRight);
-               }
-               boolean transposeOutput = false;
-               if(transposeLeft || transposeRight) {
-
-                       if((m1 instanceof CompressedMatrixBlock && 
transposeLeft) ||
-                               (m2 instanceof CompressedMatrixBlock && 
transposeRight)) {
-                               // change operation from m1 %*% m2 -> t( t(m2) 
%*% t(m1) )
-                               transposeOutput = true;
-                               MatrixBlock tmp = m1;
-                               m1 = m2;
-                               m2 = tmp;
-                               boolean tmpLeft = transposeLeft;
-                               transposeLeft = !transposeRight;
-                               transposeRight = !tmpLeft;
-
-                       }
-
-                       if(!(m1 instanceof CompressedMatrixBlock) && 
transposeLeft) {
-                               m1 = LibMatrixReorg.transpose(m1, k);
-                               transposeLeft = false;
-                       }
-                       else if(!(m2 instanceof CompressedMatrixBlock) && 
transposeRight) {
-                               m2 = LibMatrixReorg.transpose(m2, k);
-                               transposeRight = false;
-                       }
-               }
-
-               final boolean right = (m1 == this);
-               final MatrixBlock that = right ? m2 : m1;
-
-               // create output matrix block
-               if(right)
-                       ret = CLALibRightMultBy.rightMultByMatrix(this, that, 
ret, op.getNumThreads());
-               else
-                       ret = CLALibLeftMultBy.leftMultByMatrix(this, that, 
ret, op.getNumThreads());
-
-               if(LOG.isTraceEnabled())
-                       LOG.trace("MM: Time block w/ sharedDim: " + 
m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:"
-                               + m2.getNumColumns() + " in " + time.stop() + 
"ms.");
-
-               if(transposeOutput) {
-                       if(ret instanceof CompressedMatrixBlock) {
-                               LOG.warn("Transposing decompression");
-                               ret = ((CompressedMatrixBlock) 
ret).decompress(k);
-                       }
-                       ret = LibMatrixReorg.transpose(ret, k);
-               }
-
-               return ret;
-       }
-
-       private void validateMatrixMult(MatrixBlock m1, MatrixBlock m2) {
-               if(!(m1 == this || m2 == this))
-                       throw new DMLRuntimeException("Invalid 
aggregateBinaryOperation One of either input should be this");
-       }
-
-       private MatrixBlock 
doubleCompressedAggregateBinaryOperations(CompressedMatrixBlock m1, 
CompressedMatrixBlock m2,
-               MatrixBlock ret, AggregateBinaryOperator op, boolean 
transposeLeft, boolean transposeRight) {
-               if(!transposeLeft && !transposeRight) {
-                       // If both are not transposed, decompress the right 
hand side. to enable
-                       // compressed overlapping output.
-                       LOG.warn("Matrix decompression from multiplying two 
compressed matrices.");
-                       return aggregateBinaryOperations(m1, 
getUncompressed(m2), ret, op, transposeLeft, transposeRight);
-               }
-               else if(transposeLeft && !transposeRight) {
-                       if(m1.getNumColumns() > m2.getNumColumns()) {
-                               ret = 
CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, op.getNumThreads());
-                               ReorgOperator r_op = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
-                               return ret.reorgOperations(r_op, new 
MatrixBlock(), 0, 0, 0);
-                       }
-                       else
-                               return 
CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, op.getNumThreads());
-
-               }
-               else if(!transposeLeft && transposeRight) {
-                       throw new DMLCompressionException("Not Implemented 
compressed Matrix Mult, to produce larger matrix");
-                       // worst situation since it blows up the result matrix 
in number of rows in
-                       // either compressed matrix.
-               }
-               else {
-                       ret = aggregateBinaryOperations(m2, m1, ret, op);
-                       ReorgOperator r_op = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
-                       return ret.reorgOperations(r_op, new MatrixBlock(), 0, 
0, 0);
-               }
+               checkAggregateBinaryOperations(m1, m2, op, transposeLeft, 
transposeRight);
+               return CLALibMatrixMult.matrixMultiply(m1, m2, ret, 
op.getNumThreads(), transposeLeft, transposeRight);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java
new file mode 100644
index 0000000..941338d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+public class CLALibMatrixMult {
+       private static final Log LOG = 
LogFactory.getLog(CLALibMatrixMult.class.getName());
+
+       public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, int k) {
+               return matrixMultiply(m1, m2, ret, k, false, false);
+       }
+
+       public static MatrixBlock matrixMultiply(MatrixBlock m1, MatrixBlock 
m2, MatrixBlock ret,
+               int k, boolean transposeLeft, boolean transposeRight) {
+               final Timing time = LOG.isTraceEnabled() ? new Timing(true) : 
null;
+
+               if(m1 instanceof CompressedMatrixBlock && m2 instanceof 
CompressedMatrixBlock) {
+                       return 
doubleCompressedMatrixMultiply((CompressedMatrixBlock) m1, 
(CompressedMatrixBlock) m2, ret,
+                               k, transposeLeft, transposeRight);
+               }
+
+               boolean transposeOutput = false;
+               if(transposeLeft || transposeRight) {
+
+                       if((m1 instanceof CompressedMatrixBlock && 
transposeLeft) ||
+                               (m2 instanceof CompressedMatrixBlock && 
transposeRight)) {
+                               // change operation from m1 %*% m2 -> t( t(m2) 
%*% t(m1) )
+                               transposeOutput = true;
+                               MatrixBlock tmp = m1;
+                               m1 = m2;
+                               m2 = tmp;
+                               boolean tmpLeft = transposeLeft;
+                               transposeLeft = !transposeRight;
+                               transposeRight = !tmpLeft;
+                       }
+
+                       if(!(m1 instanceof CompressedMatrixBlock) && 
transposeLeft) {
+                               m1 = LibMatrixReorg.transpose(m1, k);
+                               transposeLeft = false;
+                       }
+                       else if(!(m2 instanceof CompressedMatrixBlock) && 
transposeRight) {
+                               m2 = LibMatrixReorg.transpose(m2, k);
+                               transposeRight = false;
+                       }
+               }
+
+               final boolean right = (m1 instanceof CompressedMatrixBlock);
+               final CompressedMatrixBlock c =(CompressedMatrixBlock) (right ? 
m1 : m2);
+               final MatrixBlock that = right ? m2 : m1;
+
+               // create output matrix block
+               if(right)
+                       ret = CLALibRightMultBy.rightMultByMatrix(c, that, ret, 
k);
+               else
+                       ret = CLALibLeftMultBy.leftMultByMatrix(c, that, ret, 
k);
+
+               if(LOG.isTraceEnabled())
+                       LOG.trace("MM: Time block w/ sharedDim: " + 
m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:"
+                               + m2.getNumColumns() + " in " + time.stop() + 
"ms.");
+
+               if(transposeOutput) {
+                       if(ret instanceof CompressedMatrixBlock) {
+                               LOG.warn("Transposing decompression");
+                               ret = ((CompressedMatrixBlock) 
ret).decompress(k);
+                       }
+                       ret = LibMatrixReorg.transpose(ret, k);
+               }
+
+               return ret;
+       }
+
+       private static MatrixBlock 
doubleCompressedMatrixMultiply(CompressedMatrixBlock m1, CompressedMatrixBlock 
m2,
+               MatrixBlock ret, int k, boolean transposeLeft, boolean 
transposeRight) {
+               if(!transposeLeft && !transposeRight) {
+                       // If both are not transposed, decompress the right 
hand side. to enable
+                       // compressed overlapping output.
+                       LOG.warn("Matrix decompression from multiplying two 
compressed matrices.");
+                       return matrixMultiply(m1, 
CompressedMatrixBlock.getUncompressed(m2), ret, k, transposeLeft, 
transposeRight);
+               }
+               else if(transposeLeft && !transposeRight) {
+                       if(m1.getNumColumns() > m2.getNumColumns()) {
+                               ret = 
CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, k);
+                               ReorgOperator r_op = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
+                               return ret.reorgOperations(r_op, new 
MatrixBlock(), 0, 0, 0);
+                       }
+                       else
+                               return 
CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, k);
+
+               }
+               else if(!transposeLeft && transposeRight) {
+                       throw new DMLCompressionException("Not Implemented 
compressed Matrix Mult, to produce larger matrix");
+                       // worst situation since it blows up the result matrix 
in number of rows in
+                       // either compressed matrix.
+               }
+               else {
+                       ret = CLALibMatrixMult.matrixMult(m2, m1, ret, k);
+                       ReorgOperator r_op = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
+                       return ret.reorgOperations(r_op, new MatrixBlock(), 0, 
0, 0);
+               }
+       }
+
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
 
b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
index d310e27..5d201c0 100644
--- 
a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
@@ -621,7 +621,6 @@ public abstract class CompressedTestBase extends TestBase {
                                return; // Early termination since the test 
does not test what we wanted.
 
                        // Make Operator
-                       AggregateBinaryOperator abop = 
InstructionUtils.getMatMultOperator(_k);
                        AggregateBinaryOperator abopSingle = 
InstructionUtils.getMatMultOperator(1);
 
                        // vector-matrix uncompressed
@@ -633,7 +632,7 @@ public abstract class CompressedTestBase extends TestBase {
                        ucRet = right.aggregateBinaryOperations(left, right, 
ucRet, abopSingle);
 
                        MatrixBlock ret2 = ((CompressedMatrixBlock) 
cmb).aggregateBinaryOperations(compMatrix, cmb, new MatrixBlock(),
-                               abop, transposeLeft, transposeRight);
+                       abopSingle, transposeLeft, transposeRight);
 
                        compareResultMatrices(ucRet, ret2, 100);
                }
diff --git 
a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java 
b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
index f75ba17..12c66cf 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
@@ -136,7 +136,7 @@ public class OpBindChainTest extends AutomatedTestBase
                                m2 = MatrixBlock.randOperations(n, k, sp[1], 1, 
1, "uniform", 7);
                                m1.append(m2, m3, false);
                                m4 = MatrixBlock.randOperations(k, m, sp[1], 1, 
1, "uniform", 5);
-                               m5 = m1.aggregateBinaryOperations(m3, m4, 
+                               m5 = m3.aggregateBinaryOperations(m3, m4, 
                                                new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
                                est = estim.estim(new MMNode(new MMNode(new 
MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
                                //System.out.println(est);
@@ -147,7 +147,7 @@ public class OpBindChainTest extends AutomatedTestBase
                                m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 
1, "uniform", 7);
                                m1.append(m2, m3, true);
                                m4 = MatrixBlock.randOperations(k+n, m, sp[1], 
1, 1, "uniform", 5);
-                               m5 = m1.aggregateBinaryOperations(m3, m4, 
+                               m5 = m3.aggregateBinaryOperations(m3, m4, 
                                                new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
                                est = estim.estim(new MMNode(new MMNode(new 
MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
                                //System.out.println(est);
diff --git 
a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java 
b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
index f6410e8..7a76d7a 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
@@ -129,7 +129,7 @@ public class OpElemWChainTest extends AutomatedTestBase
                        case MULT:
                                bOp = new 
BinaryOperator(Multiply.getMultiplyFnObject());
                                m1.binaryOperations(bOp, m2, m4);
-                               m5 = m1.aggregateBinaryOperations(m4, m3, 
+                               m5 = m4.aggregateBinaryOperations(m4, m3, 
                                                new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
                                est = estim.estim(new MMNode(new MMNode(new 
MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
                                // System.out.println(m5.getSparsity());
@@ -138,7 +138,7 @@ public class OpElemWChainTest extends AutomatedTestBase
                        case PLUS:
                                bOp = new 
BinaryOperator(Plus.getPlusFnObject());
                                m1.binaryOperations(bOp, m2, m4);
-                               m5 = m1.aggregateBinaryOperations(m4, m3, 
+                               m5 = m4.aggregateBinaryOperations(m4, m3, 
                                                new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
                                est = estim.estim(new MMNode(new MMNode(new 
MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
                                // System.out.println(m5.getSparsity());
diff --git 
a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
 
b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
index e35bc57..25cd99e 100644
--- 
a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
@@ -132,7 +132,7 @@ public class SquaredProductChainTest extends 
AutomatedTestBase
                MatrixBlock m3 = MatrixBlock.randOperations(n, n2, sp[2], 1, 1, 
"uniform", 3);
                MatrixBlock m4 = m1.aggregateBinaryOperations(m1, m2, 
                        new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
-               MatrixBlock m5 = m1.aggregateBinaryOperations(m4, m3, 
+               MatrixBlock m5 = m4.aggregateBinaryOperations(m4, m3, 
                        new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
                
                //compare estimated and real sparsity

Reply via email to