This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 15bef0394d [SYSTEMDS-3636] Fix new ultra-sparse tsmm right, and new tests
15bef0394d is described below
commit 15bef0394d71eefee98769a1f50eb3dadb336aaa
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Oct 27 14:28:01 2023 +0200
[SYSTEMDS-3636] Fix new ultra-sparse tsmm right, and new tests
---
.../sysds/runtime/matrix/data/LibMatrixMult.java | 21 ++++-----
.../FullMatrixMultiplicationTransposeSelfTest.java | 55 ++++++++++------------
2 files changed, 35 insertions(+), 41 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 0f08176fe1..0b8bd216f4 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -435,9 +435,7 @@ public class LibMatrixMult
//Timing time = new Timing(true);
//pre-processing
- double sp = m1.getSparsity();
- double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen,
m1.clen, m1.rlen, false);
- ret.sparse = !leftTranspose && m1.sparse && osp <
MatrixBlock.SPARSITY_TURN_POINT;
+ ret.sparse = isSparseOutputTSMM(m1, leftTranspose);
ret.allocateBlock();
//core tsmm operation
@@ -477,10 +475,7 @@ public class LibMatrixMult
//Timing time = new Timing(true);
//pre-processing (no need to check isThreadSafe)
- double sp = m1.getSparsity();
- ret.sparse = !leftTranspose && m1.sparse &&
MatrixBlock.SPARSITY_TURN_POINT >
- OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen,
m1.clen, m1.rlen, false);
-
+ ret.sparse = isSparseOutputTSMM(m1, leftTranspose);
ret.allocateBlock();
//core multi-threaded matrix mult computation
@@ -3549,7 +3544,7 @@ public class LibMatrixMult
double v = 0;
while( k<asize & k2<bsize ) {
int aixk = aix[k];
- int bixk = aix[k2];
+ int bixk = bix[k2];
if( aixk < bixk )
k++;
else if( aixk > bixk )
@@ -4203,9 +4198,7 @@ public class LibMatrixMult
MatrixBlock ret = m1;
final int rlen = m1.rlen;
final int clen = m1.clen;
- double sp = m1.getSparsity();
- double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen,
m1.clen, m1.rlen, false);
- boolean retSparse = !leftTranspose && m1.sparse && osp <
MatrixBlock.SPARSITY_TURN_POINT;
+ boolean retSparse = isSparseOutputTSMM(m1, leftTranspose);
if( !leftTranspose && !retSparse && m1.sparse && rlen > 1) {
//X%*%t(X) SPARSE MATRIX
//directly via LibMatrixReorg in order to prevent
sparsity change
@@ -4323,6 +4316,12 @@ public class LibMatrixMult
boolean sparseOut =
MatrixBlock.evalSparseFormatInMemory(m1.rlen, m2.clen, estNnz);
return m2.clen < 4*1024 && sparseOut;
}
+
+ public static boolean isSparseOutputTSMM(MatrixBlock m1, boolean
leftTranspose) {
+ double sp = m1.getSparsity();
+ double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen,
m1.clen, m1.rlen, false);
+ return !leftTranspose && m1.sparse && osp <
MatrixBlock.ULTRA_SPARSITY_TURN_POINT2;
+ }
public static boolean isOuterProductTSMM(int rlen, int clen, boolean
left) {
return left ? rlen == 1 & clen > 1 : rlen > 1 & clen == 1;
diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
index 1599bdac86..3d1b4c239f 100644
--- a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
@@ -22,11 +22,16 @@ package org.apache.sysds.test.functions.binary.matrix_full_other;
import java.util.HashMap;
import org.junit.AfterClass;
+import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -268,35 +273,25 @@ public class FullMatrixMultiplicationTransposeSelfTest extends AutomatedTestBase
private void runTransposeSelfUltraSparseTest( MMTSJType type )
{
- //rtplatform for MR
- ExecMode platformOld = rtplatform;
- rtplatform = ExecMode.SINGLE_NODE;
-
- try {
- loadTestConfiguration(getTestConfiguration(TEST_NAME2));
- int dim = 10000;
-
- String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + TEST_NAME2 + ".dml";
- programArgs = new String[]{"-stats","-args", input("A"),
- String.valueOf(dim), String.valueOf(dim),
output("B") };
- fullRScriptName = HOME + TEST_NAME2 + ".R";
- rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
-
- //generate actual dataset
- double[][] A = getRandomMatrix(dim, dim, 0, 1, 0.0002,
7);
- writeInputMatrix("A", A, true);
-
- runTest(true, false, null, -1);
- //runRScript(true);
-
- //compare matrices
- //HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("B");
- //HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir("B");
- //TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
- }
- finally {
- rtplatform = platformOld;
- }
+ //compare sparse tsmm and gemm directly to avoid unnecessary
overhead (e.g., R)
+ int dim = 10000;
+
+ MatrixBlock G = MatrixBlock.randOperations(dim, dim, 0.0002, 0,
1, "uniform", 7);
+ MatrixBlock Gt = LibMatrixReorg.transpose(G);
+ MatrixBlock Gtt = LibMatrixReorg.transpose(Gt);
+ TestUtils.compareMatrices(G, Gtt, 1e-16);
+
+ //single-threaded core operations
+ MatrixBlock R11 = G.transposeSelfMatrixMultOperations(new
MatrixBlock(), MMTSJType.RIGHT);
+ MatrixBlock R12 = LibMatrixMult.matrixMult(G, Gt);
+ Assert.assertEquals(R11.getNonZeros(), R12.getNonZeros());
+ TestUtils.compareMatrices(R11, R12, 1e-8);
+
+ //multi-threaded core operations
+ int k = InfrastructureAnalyzer.getLocalParallelism();
+ MatrixBlock R21 = G.transposeSelfMatrixMultOperations(new
MatrixBlock(), MMTSJType.RIGHT, k);
+ MatrixBlock R22 = LibMatrixMult.matrixMult(G, Gt, k);
+ Assert.assertEquals(R21.getNonZeros(), R22.getNonZeros());
+ TestUtils.compareMatrices(R21, R22, 1e-8);
}
}