This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new fb60577586 [MINOR] Parallel Compressed LMM
fb60577586 is described below
commit fb605775865d2ec0fbcc3aff81975576f8baa5e1
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Oct 30 15:05:17 2023 +0100
[MINOR] Parallel Compressed LMM
---
.../runtime/compress/lib/CLALibLeftMultBy.java | 96 ++++++++++++++++++++--
.../sysds/runtime/compress/lib/CLALibMMChain.java | 42 ++++++++++
.../runtime/compress/lib/CLALibRightMultBy.java | 4 +-
3 files changed, 133 insertions(+), 9 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 6029a87d46..30c1109d3a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -32,11 +32,14 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.APreAgg;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -45,7 +48,7 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
public final class CLALibLeftMultBy {
private static final Log LOG =
LogFactory.getLog(CLALibLeftMultBy.class.getName());
- private CLALibLeftMultBy(){
+ private CLALibLeftMultBy() { // prevents instantiation of this final helper class
// private constructor
}
@@ -139,7 +142,15 @@ public final class CLALibLeftMultBy {
}
private static MatrixBlock
leftMultByCompressedTransposedMatrix(CompressedMatrixBlock right,
- CompressedMatrixBlock left, MatrixBlock ret, int k) {
+ CompressedMatrixBlock left, final MatrixBlock ret, int k) { // chooses between the parallel and the single-threaded kernel
+ if(k > 1 && ret.getInMemorySize() < 1000000) // parallel only for small outputs, because every parallel task allocates its own dense block with ret's dimensions. NOTE(review): magic threshold (presumably ~1MB in bytes) - consider a named constant
+ return
leftMultByCompressedTransposedMatrixParallel(right, left, ret, k);
+ else // otherwise accumulate directly into ret on the calling thread
+ return
leftMultByCompressedTransposedMatrixSingleThread(right, left, ret);
+ }
+
+ private static MatrixBlock
leftMultByCompressedTransposedMatrixParallel(CompressedMatrixBlock right,
+ CompressedMatrixBlock left, final MatrixBlock ret, int k) { // parallelized over the filtered left column groups
final int sd = right.getNumRows(); // shared dim
final int cr = right.getNumColumns();
@@ -149,18 +160,88 @@ public final class CLALibLeftMultBy {
final List<AColGroup> leftCG = left.getColGroups();
final boolean containsRight =
CLALibUtils.shouldPreFilter(rightCG);
- double[] cR = containsRight ? new double[cr] : null;
+ final double[] cR = containsRight ? new double[cr] : null;
final List<AColGroup> fRight =
CLALibUtils.filterGroups(rightCG, cR);
final boolean containsLeft =
CLALibUtils.shouldPreFilter(leftCG);
- double[] cL = containsLeft ? new double[rl] : null;
+ final double[] cL = containsLeft ? new double[rl] : null;
final List<AColGroup> fLeft = CLALibUtils.filterGroups(leftCG,
cL);
+ // Force dense output
+ ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns());
+ ret.allocateDenseBlock();
+
+ final ExecutorService ex = CommonThreadPool.get(k);
+ final List<Future<MatrixBlock>> t = new ArrayList<>();
+
+ for(int j = 0; j < fLeft.size(); j++) { // one task per filtered left column group; a single left group yields no parallelism
+ final int jj = j;
+ t.add(ex.submit(() -> {
+ MatrixBlock retT = new
MatrixBlock(ret.getNumRows(), ret.getNumColumns(), false); // NOTE(review): private dense buffer per task, all retained until merged below, so peak memory grows with fLeft.size(); consider chunking fLeft into ~k partitions
+ retT.allocateDenseBlock();
+ for(int i = 0; i < fRight.size(); i++) {
+
fRight.get(i).leftMultByAColGroup(fLeft.get(jj), retT, sd);
+ }
+ retT.examSparsity(true); // may switch retT to a sparse representation; the merge below handles both formats
+ return retT;
+ }));
+ }
+
+ try {
+ final double[] retV = ret.getDenseBlockValues(); // safe while tasks run: they only write their own retT
+ if(containsLeft && containsRight)
+ // if both -- multiply the left and right
vectors scaling by number of shared dim
+ outerProductWithScaling(cL, cR, sd, retV);
+ if(containsLeft) // if left -- multiply left with right
sum
+ outerProduct(cL, CLALibUtils.getColSum(fRight,
cr, sd), retV);
+ if(containsRight)// if right -- multiply right with
left sum
+ outerProduct(CLALibUtils.getColSum(fLeft, rl,
sd), cR, retV);
+ for(Future<MatrixBlock> f : t) { // merge the per-group partial results into ret
+ MatrixBlock mb = f.get();
+ if(!mb.isEmpty()) {
+ if(mb.isInSparseFormat())
+
LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new
BinaryOperator(Plus.getPlusFnObject()));
+ else
if(mb.getDenseBlock().isContiguous())
+
LibMatrixMult.vectAdd(mb.getDenseBlockValues(), retV, 0, 0, retV.length); // fast path: element-wise add of the contiguous dense array
+ else
+
LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new
BinaryOperator(Plus.getPlusFnObject()));
+ }
+ }
+ ret.recomputeNonZeros(k);
+ }
+ catch(Exception e) { // NOTE(review): also catches InterruptedException without restoring the thread's interrupt flag
+ throw new DMLCompressionException("Failed parallel Left
Compressed Mult", e);
+ }
+ finally {
+ ex.shutdown();
+ }
+ return ret;
+ }
+
+ private static MatrixBlock
leftMultByCompressedTransposedMatrixSingleThread(CompressedMatrixBlock right,
+ CompressedMatrixBlock left, final MatrixBlock ret) { // sequential variant: all group products accumulate directly into ret
+ final int sd = right.getNumRows(); // shared dim
+ final int cr = right.getNumColumns();
+ final int rl = left.getNumColumns();
+
+ final List<AColGroup> rightCG = right.getColGroups(); // NOTE(review): this pre-filter setup duplicates the parallel variant; consider a shared helper
+ final List<AColGroup> leftCG = left.getColGroups();
+
+ final boolean containsRight =
CLALibUtils.shouldPreFilter(rightCG);
+ final double[] cR = containsRight ? new double[cr] : null;
+ final List<AColGroup> fRight =
CLALibUtils.filterGroups(rightCG, cR);
+
+ final boolean containsLeft =
CLALibUtils.shouldPreFilter(leftCG);
+ final double[] cL = containsLeft ? new double[rl] : null;
+ final List<AColGroup> fLeft = CLALibUtils.filterGroups(leftCG,
cL);
+
+ // Force dense output
+ ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns());
+ ret.allocateDenseBlock();
for(int j = 0; j < fLeft.size(); j++)
for(int i = 0; i < fRight.size(); i++)
fRight.get(i).leftMultByAColGroup(fLeft.get(j),
ret, sd);
-
- double[] retV = ret.getDenseBlockValues();
+ final double[] retV = ret.getDenseBlockValues(); // then add the contributions of the vectors cL/cR extracted by filterGroups
if(containsLeft && containsRight)
// if both -- multiply the left and right vectors
scaling by number of shared dim
outerProductWithScaling(cL, cR, sd, retV);
@@ -169,7 +250,6 @@ public final class CLALibLeftMultBy {
if(containsRight)// if right -- multiply right with left sum
outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR,
retV);
ret.recomputeNonZeros();
-
return ret;
}
@@ -218,7 +298,7 @@ public final class CLALibLeftMultBy {
LMMParallel(noPreAggGroups, preAggGroups, that,
ret, null, overlapping, k);
}
- ret.recomputeNonZeros();
+ ret.recomputeNonZeros(k);
ret.examSparsity();
return ret;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
index bc164a5e91..060c736871 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
@@ -35,6 +35,21 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+/**
+ * Support compressed MM chain operations that fuse the following cases:
+ *
+ * <p>
+ * XtXv == (t(X) %*% (X %*% v))
+ * </p>
+ *
+ * <p>
+ * XtwXv == (t(X) %*% (w * (X %*% v)))
+ * </p>
+ *
+ * <p>
+ * XtXvy == (t(X) %*% ((X %*% v) - y))
+ * </p>
+ */
public final class CLALibMMChain {
static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());
@@ -42,6 +57,33 @@ public final class CLALibMMChain {
// private constructor
}
+ /**
+ * Support compressed MM chain operations that fuse the following cases:
+ *
+ * <p>
+ * XtXv == (t(X) %*% (X %*% v))
+ * </p>
+ *
+ * <p>
+ * XtwXv == (t(X) %*% (w * (X %*% v)))
+ * </p>
+ *
+ * <p>
+ * XtXvy == (t(X) %*% ((X %*% v) - y))
+ * </p>
+ *
+ * Note: the point of this optimization is that v and w (when present) are always vectors. This means that, in practice,
+ * the whole computation is faster if the intermediates are exploited.
+ *
+ *
+ * @param x Is the X part of the chain optimized kernel
+ * @param v Is the mandatory v part of the chain
+ * @param w Is the optional w part of the chain
+ * @param out The output to put the result into. Can also be returned
and in some cases will not be used.
+ * @param ctype either XtwXv, XtXv or XtXvy
+ * @param k the parallelization degree
+ * @return The result either in the given output or a new allocation
+ */
public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock
v, MatrixBlock w, MatrixBlock out,
ChainType ctype, int k) {
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
index 39468b0cab..2eef5f9f3f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
@@ -243,7 +243,9 @@ public final class CLALibRightMultBy {
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
- pool.shutdown();
+ finally{
+ pool.shutdown();
+ }
return containsNull;
}