This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit a98560cb306012771dce215bda150a89dd9bf482 Author: baunsgaard <[email protected]> AuthorDate: Sun Jan 16 14:02:08 2022 +0100 [SYSTEMDS-3243] Compressed Matrix Multiplication part This commit follows the previous one by modifying the compression tests and the compression path for Matrix Multiplication to fit the design of the normal MatrixBlock. Closes #1480 --- .../runtime/compress/CompressedMatrixBlock.java | 105 +---------------- .../runtime/compress/lib/CLALibMatrixMult.java | 128 +++++++++++++++++++++ .../component/compress/CompressedTestBase.java | 3 +- .../test/component/estim/OpBindChainTest.java | 4 +- .../test/component/estim/OpElemWChainTest.java | 4 +- .../component/estim/SquaredProductChainTest.java | 2 +- 6 files changed, 140 insertions(+), 106 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 85cc23b..ec09226 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -51,8 +51,8 @@ import org.apache.sysds.runtime.compress.lib.CLALibCompAgg; import org.apache.sysds.runtime.compress.lib.CLALibDecompress; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; import org.apache.sysds.runtime.compress.lib.CLALibMMChain; +import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibReExpand; -import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy; import org.apache.sysds.runtime.compress.lib.CLALibScalar; import org.apache.sysds.runtime.compress.lib.CLALibSlice; import org.apache.sysds.runtime.compress.lib.CLALibSquash; @@ -61,13 +61,11 @@ import org.apache.sysds.runtime.compress.lib.CLALibUtils; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; -import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseRow; import org.apache.sysds.runtime.functionobjects.MinusMultiply; import org.apache.sysds.runtime.functionobjects.PlusMultiply; -import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; @@ -76,7 +74,6 @@ import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.CTableMap; import org.apache.sysds.runtime.matrix.data.IJV; import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen; -import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.LibMatrixTercell; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; @@ -471,105 +468,15 @@ public class CompressedMatrixBlock extends MatrixBlock { } @Override - public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, - AggregateBinaryOperator op) { - // create output matrix block - return aggregateBinaryOperations(m1, m2, ret, op, false, false); + public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) { + checkAggregateBinaryOperations(m1, m2, op); + return CLALibMatrixMult.matrixMultiply(m1, m2, ret, op.getNumThreads(), false, false); } public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op, boolean transposeLeft, boolean transposeRight) { - validateMatrixMult(m1, m2); - final int k = 
op.getNumThreads(); - final Timing time = LOG.isTraceEnabled() ? new Timing(true) : null; - - if(m1 instanceof CompressedMatrixBlock && m2 instanceof CompressedMatrixBlock) { - return doubleCompressedAggregateBinaryOperations((CompressedMatrixBlock) m1, (CompressedMatrixBlock) m2, ret, - op, transposeLeft, transposeRight); - } - boolean transposeOutput = false; - if(transposeLeft || transposeRight) { - - if((m1 instanceof CompressedMatrixBlock && transposeLeft) || - (m2 instanceof CompressedMatrixBlock && transposeRight)) { - // change operation from m1 %*% m2 -> t( t(m2) %*% t(m1) ) - transposeOutput = true; - MatrixBlock tmp = m1; - m1 = m2; - m2 = tmp; - boolean tmpLeft = transposeLeft; - transposeLeft = !transposeRight; - transposeRight = !tmpLeft; - - } - - if(!(m1 instanceof CompressedMatrixBlock) && transposeLeft) { - m1 = LibMatrixReorg.transpose(m1, k); - transposeLeft = false; - } - else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) { - m2 = LibMatrixReorg.transpose(m2, k); - transposeRight = false; - } - } - - final boolean right = (m1 == this); - final MatrixBlock that = right ? 
m2 : m1; - - // create output matrix block - if(right) - ret = CLALibRightMultBy.rightMultByMatrix(this, that, ret, op.getNumThreads()); - else - ret = CLALibLeftMultBy.leftMultByMatrix(this, that, ret, op.getNumThreads()); - - if(LOG.isTraceEnabled()) - LOG.trace("MM: Time block w/ sharedDim: " + m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:" - + m2.getNumColumns() + " in " + time.stop() + "ms."); - - if(transposeOutput) { - if(ret instanceof CompressedMatrixBlock) { - LOG.warn("Transposing decompression"); - ret = ((CompressedMatrixBlock) ret).decompress(k); - } - ret = LibMatrixReorg.transpose(ret, k); - } - - return ret; - } - - private void validateMatrixMult(MatrixBlock m1, MatrixBlock m2) { - if(!(m1 == this || m2 == this)) - throw new DMLRuntimeException("Invalid aggregateBinaryOperation One of either input should be this"); - } - - private MatrixBlock doubleCompressedAggregateBinaryOperations(CompressedMatrixBlock m1, CompressedMatrixBlock m2, - MatrixBlock ret, AggregateBinaryOperator op, boolean transposeLeft, boolean transposeRight) { - if(!transposeLeft && !transposeRight) { - // If both are not transposed, decompress the right hand side. to enable - // compressed overlapping output. 
- LOG.warn("Matrix decompression from multiplying two compressed matrices."); - return aggregateBinaryOperations(m1, getUncompressed(m2), ret, op, transposeLeft, transposeRight); - } - else if(transposeLeft && !transposeRight) { - if(m1.getNumColumns() > m2.getNumColumns()) { - ret = CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, op.getNumThreads()); - ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads()); - return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0); - } - else - return CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, op.getNumThreads()); - - } - else if(!transposeLeft && transposeRight) { - throw new DMLCompressionException("Not Implemented compressed Matrix Mult, to produce larger matrix"); - // worst situation since it blows up the result matrix in number of rows in - // either compressed matrix. - } - else { - ret = aggregateBinaryOperations(m2, m1, ret, op); - ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads()); - return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0); - } + checkAggregateBinaryOperations(m1, m2, op, transposeLeft, transposeRight); + return CLALibMatrixMult.matrixMultiply(m1, m2, ret, op.getNumThreads(), transposeLeft, transposeRight); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java new file mode 100644 index 0000000..941338d --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.lib; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; + +public class CLALibMatrixMult { + private static final Log LOG = LogFactory.getLog(CLALibMatrixMult.class.getName()); + + public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) { + return matrixMultiply(m1, m2, ret, k, false, false); + } + + public static MatrixBlock matrixMultiply(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, + int k, boolean transposeLeft, boolean transposeRight) { + final Timing time = LOG.isTraceEnabled() ? 
new Timing(true) : null; + + if(m1 instanceof CompressedMatrixBlock && m2 instanceof CompressedMatrixBlock) { + return doubleCompressedMatrixMultiply((CompressedMatrixBlock) m1, (CompressedMatrixBlock) m2, ret, + k, transposeLeft, transposeRight); + } + + boolean transposeOutput = false; + if(transposeLeft || transposeRight) { + + if((m1 instanceof CompressedMatrixBlock && transposeLeft) || + (m2 instanceof CompressedMatrixBlock && transposeRight)) { + // change operation from m1 %*% m2 -> t( t(m2) %*% t(m1) ) + transposeOutput = true; + MatrixBlock tmp = m1; + m1 = m2; + m2 = tmp; + boolean tmpLeft = transposeLeft; + transposeLeft = !transposeRight; + transposeRight = !tmpLeft; + } + + if(!(m1 instanceof CompressedMatrixBlock) && transposeLeft) { + m1 = LibMatrixReorg.transpose(m1, k); + transposeLeft = false; + } + else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) { + m2 = LibMatrixReorg.transpose(m2, k); + transposeRight = false; + } + } + + final boolean right = (m1 instanceof CompressedMatrixBlock); + final CompressedMatrixBlock c =(CompressedMatrixBlock) (right ? m1 : m2); + final MatrixBlock that = right ? 
m2 : m1; + + // create output matrix block + if(right) + ret = CLALibRightMultBy.rightMultByMatrix(c, that, ret, k); + else + ret = CLALibLeftMultBy.leftMultByMatrix(c, that, ret, k); + + if(LOG.isTraceEnabled()) + LOG.trace("MM: Time block w/ sharedDim: " + m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:" + + m2.getNumColumns() + " in " + time.stop() + "ms."); + + if(transposeOutput) { + if(ret instanceof CompressedMatrixBlock) { + LOG.warn("Transposing decompression"); + ret = ((CompressedMatrixBlock) ret).decompress(k); + } + ret = LibMatrixReorg.transpose(ret, k); + } + + return ret; + } + + private static MatrixBlock doubleCompressedMatrixMultiply(CompressedMatrixBlock m1, CompressedMatrixBlock m2, + MatrixBlock ret, int k, boolean transposeLeft, boolean transposeRight) { + if(!transposeLeft && !transposeRight) { + // If both are not transposed, decompress the right hand side. to enable + // compressed overlapping output. + LOG.warn("Matrix decompression from multiplying two compressed matrices."); + return matrixMultiply(m1, CompressedMatrixBlock.getUncompressed(m2), ret, k, transposeLeft, transposeRight); + } + else if(transposeLeft && !transposeRight) { + if(m1.getNumColumns() > m2.getNumColumns()) { + ret = CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, k); + ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k); + return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0); + } + else + return CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, k); + + } + else if(!transposeLeft && transposeRight) { + throw new DMLCompressionException("Not Implemented compressed Matrix Mult, to produce larger matrix"); + // worst situation since it blows up the result matrix in number of rows in + // either compressed matrix. 
+ } + else { + ret = CLALibMatrixMult.matrixMult(m2, m1, ret, k); + ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k); + return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0); + } + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java index d310e27..5d201c0 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java @@ -621,7 +621,6 @@ public abstract class CompressedTestBase extends TestBase { return; // Early termination since the test does not test what we wanted. // Make Operator - AggregateBinaryOperator abop = InstructionUtils.getMatMultOperator(_k); AggregateBinaryOperator abopSingle = InstructionUtils.getMatMultOperator(1); // vector-matrix uncompressed @@ -633,7 +632,7 @@ public abstract class CompressedTestBase extends TestBase { ucRet = right.aggregateBinaryOperations(left, right, ucRet, abopSingle); MatrixBlock ret2 = ((CompressedMatrixBlock) cmb).aggregateBinaryOperations(compMatrix, cmb, new MatrixBlock(), - abop, transposeLeft, transposeRight); + abopSingle, transposeLeft, transposeRight); compareResultMatrices(ucRet, ret2, 100); } diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java index f75ba17..12c66cf 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java @@ -136,7 +136,7 @@ public class OpBindChainTest extends AutomatedTestBase m2 = MatrixBlock.randOperations(n, k, sp[1], 1, 1, "uniform", 7); m1.append(m2, m3, false); m4 = MatrixBlock.randOperations(k, m, sp[1], 1, 1, "uniform", 5); - m5 = m1.aggregateBinaryOperations(m3, m4, + m5 = 
m3.aggregateBinaryOperations(m3, m4, new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity(); //System.out.println(est); @@ -147,7 +147,7 @@ public class OpBindChainTest extends AutomatedTestBase m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 7); m1.append(m2, m3, true); m4 = MatrixBlock.randOperations(k+n, m, sp[1], 1, 1, "uniform", 5); - m5 = m1.aggregateBinaryOperations(m3, m4, + m5 = m3.aggregateBinaryOperations(m3, m4, new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity(); //System.out.println(est); diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java index f6410e8..7a76d7a 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java @@ -129,7 +129,7 @@ public class OpElemWChainTest extends AutomatedTestBase case MULT: bOp = new BinaryOperator(Multiply.getMultiplyFnObject()); m1.binaryOperations(bOp, m2, m4); - m5 = m1.aggregateBinaryOperations(m4, m3, + m5 = m4.aggregateBinaryOperations(m4, m3, new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity(); // System.out.println(m5.getSparsity()); @@ -138,7 +138,7 @@ public class OpElemWChainTest extends AutomatedTestBase case PLUS: bOp = new BinaryOperator(Plus.getPlusFnObject()); m1.binaryOperations(bOp, m2, m4); - m5 = m1.aggregateBinaryOperations(m4, m3, + m5 = m4.aggregateBinaryOperations(m4, m3, new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new 
MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity(); // System.out.println(m5.getSparsity()); diff --git a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java index e35bc57..25cd99e 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java @@ -132,7 +132,7 @@ public class SquaredProductChainTest extends AutomatedTestBase MatrixBlock m3 = MatrixBlock.randOperations(n, n2, sp[2], 1, 1, "uniform", 3); MatrixBlock m4 = m1.aggregateBinaryOperations(m1, m2, new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); - MatrixBlock m5 = m1.aggregateBinaryOperations(m4, m3, + MatrixBlock m5 = m4.aggregateBinaryOperations(m4, m3, new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); //compare estimated and real sparsity
