This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7c504d0d18 [SYSTEMDS-3914] New out-of-core transpose-self matmult
instruction
7c504d0d18 is described below
commit 7c504d0d184b4a1f3da9d433db062ff994e8fa6a
Author: Janardhan Pulivarthi <[email protected]>
AuthorDate: Wed Sep 3 22:11:03 2025 +0530
[SYSTEMDS-3914] New out-of-core transpose-self matmult instruction
Closes #2323.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 10 +-
.../runtime/instructions/OOCInstructionParser.java | 3 +
.../runtime/instructions/ooc/OOCInstruction.java | 2 +-
.../instructions/ooc/TSMMOOCInstruction.java | 96 +++++++++++++++++
.../test/functions/ooc/TransposeSelfMMTest.java | 117 +++++++++++++++++++++
src/test/scripts/functions/ooc/TSMM.dml | 28 +++++
6 files changed, 246 insertions(+), 10 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index fb20cc41d0..8685524a3f 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -179,7 +179,7 @@ public class AggBinaryOp extends MultiThreadedHop {
et = ExecType.CP;
}
- if (et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED) {
+ if (et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED || et == ExecType.OOC) {
//matrix mult operation selection part 3 (CP
type)
_method =
optFindMMultMethodCP(input1.getDim1(), input1.getDim2(),
input2.getDim1(),
input2.getDim2(), mmtsj, chain, _hasLeftPMInput);
@@ -240,14 +240,6 @@ public class AggBinaryOp extends MultiThreadedHop {
default:
throw new
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" +
_method + ") while constructing SPARK lops.");
}
- } else if (et == ExecType.OOC) {
- Lop in1 = getInput().get(0).constructLops();
- Lop in2 = getInput().get(1).constructLops();
- MatMultCP matmult = new MatMultCP(in1, in2,
getDataType(), getValueType(),
- et,
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
- setOutputDimensions(matmult);
- setLineNumbers(matmult);
- setLops(matmult);
}
} else
throw new HopsException(this.printErrorLocation() +
"Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ")
while constructing lops.");
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index 73b5ca0261..e898800aba 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -27,6 +27,7 @@ import
org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
import
org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction;
@@ -61,6 +62,8 @@ public class OOCInstructionParser extends InstructionParser {
case AggregateBinary:
case MAPMM:
return
MatrixVectorBinaryOOCInstruction.parseInstruction(str);
+ case MMTSJ:
+ return TSMMOOCInstruction.parseInstruction(str);
case Reorg:
return
TransposeOOCInstruction.parseInstruction(str);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 2e5e6f41eb..95bc188dc1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction {
protected static final Log LOG =
LogFactory.getLog(OOCInstruction.class.getName());
public enum OOCType {
- Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg,
AggregateBinary
+ Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg,
AggregateBinary, MMTSJ
}
protected final OOCInstruction.OOCType _ooctype;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
new file mode 100644
index 0000000000..b3f302c204
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+/**
+ * Out-of-core (OOC) transpose-self matrix multiplication (tsmm) instruction.
+ *
+ * The input matrix is consumed block by block from its stream handle. For
+ * every dequeued block a partial tsmm result is computed and added into a
+ * single dense result block, which is finally bound to the output variable.
+ *
+ * Only inputs whose output fits into one block are supported: for a LEFT
+ * tsmm the number of columns, and for a RIGHT tsmm the number of rows, must
+ * not exceed the block size of the input.
+ */
+public class TSMMOOCInstruction extends ComputationOOCInstruction {
+	/** Transpose-self type parsed from the instruction string. */
+	private final MMTSJType _type;
+
+	/**
+	 * Creates a new OOC tsmm instruction.
+	 *
+	 * @param type      OOC instruction type
+	 * @param op        operator handed to the super class
+	 * @param in1       input matrix operand (consumed as a stream of blocks)
+	 * @param out       output matrix operand
+	 * @param mmtsjType transpose-self type of the operation
+	 * @param opcode    instruction opcode
+	 * @param istr      full instruction string
+	 */
+	protected TSMMOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out,
+		MMTSJ.MMTSJType mmtsjType, String opcode, String istr)
+	{
+		super(type, op, in1, out, opcode, istr);
+		_type = mmtsjType;
+	}
+
+	/**
+	 * Parses an instruction string with exactly three operands (input, output,
+	 * tsmm type) into a {@link TSMMOOCInstruction}.
+	 *
+	 * @param str instruction string
+	 * @return the parsed instruction
+	 */
+	public static TSMMOOCInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 3);
+		String opcode = parts[0];
+		// the large (streamed) matrix; LEFT requires columns <= blocksize,
+		// RIGHT requires rows <= blocksize (validated in processInstruction)
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand out = new CPOperand(parts[2]);
+		MMTSJ.MMTSJType mmtsjType = MMTSJ.MMTSJType.valueOf(parts[3]);
+
+		AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+		AggregateBinaryOperator ba = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+
+		return new TSMMOOCInstruction(OOCType.MMTSJ, ba, in1, out, mmtsjType, opcode, str);
+	}
+
+	/**
+	 * Streams all input blocks, computes the tsmm of each block and sums the
+	 * partial results into the output matrix.
+	 *
+	 * @param ec execution context providing the input and receiving the output
+	 * @throws UnsupportedOperationException if the output does not fit into a single block
+	 * @throws DMLRuntimeException if consuming the stream or computing a partial result fails
+	 */
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject min = ec.getMatrixObject(input1);
+		// keep the dimensions as long for the validation below: narrowing to
+		// int first could wrap around for very large inputs and let an
+		// unsupported shape pass the block-size check
+		long nRows = min.getDataCharacteristics().getRows();
+		long nCols = min.getDataCharacteristics().getCols();
+		int bLen = min.getDataCharacteristics().getBlocksize();
+
+		LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
+		BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
+
+		//validation check TODO extend compiler to not create OOC otherwise
+		if( (_type.isLeft() && nCols > bLen)
+			|| (_type.isRight() && nRows > bLen) )
+		{
+			throw new UnsupportedOperationException("OOC tsmm (" + _type + ") requires "
+				+ (_type.isLeft() ? "columns" : "rows") + " <= blocksize, but found a "
+				+ nRows + "x" + nCols + " input with blocksize " + bLen + ".");
+		}
+
+		// for LEFT/RIGHT the selected dimension was validated against the
+		// block size above, so narrowing to int cannot lose information
+		int dim = (int) (_type.isLeft() ? nCols : nRows);
+		MatrixBlock resultBlock = new MatrixBlock(dim, dim, false);
+		try {
+			IndexedMatrixValue tmp = null;
+			// aggregate partial tsmm outputs into result as inputs stream in,
+			// until the end-of-stream marker is dequeued
+			while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
+				MatrixBlock partialResult = ((MatrixBlock) tmp.getValue())
+					.transposeSelfMatrixMultOperations(new MatrixBlock(), _type);
+				resultBlock.binaryOperationsInPlace(plus, partialResult);
+			}
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+
+		ec.setMatrixOutput(output.getName(), resultBlock);
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/TransposeSelfMMTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/TransposeSelfMMTest.java
new file mode 100644
index 0000000000..ed61038a71
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/TransposeSelfMMTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Tests the out-of-core transpose-self matrix multiplication t(X) %*% X by
+ * running the TSMM.dml script with the -ooc flag on a binary input and
+ * comparing its output against an in-memory LEFT tsmm of the same data.
+ */
+public class TransposeSelfMMTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "TSMM";
+	private final static String TEST_DIR = "functions/ooc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + TransposeSelfMMTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	private static final String INPUT_NAME = "X";
+	private static final String OUTPUT_NAME = "res";
+
+	private final static int rows = 2143;
+	private final static int cols = 123;
+	private final static double sparsity1 = 0.7;
+	private final static double sparsity2 = 0.1;
+	private final int k = 1;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+		addTestConfiguration(TEST_NAME1, config);
+	}
+
+	@Test
+	public void testTsmmDense() {
+		runTSMMTest(cols, false);
+	}
+
+	@Test
+	public void testTsmmSparse() {
+		// must request sparse data, otherwise this test only repeats the dense case
+		runTSMMTest(cols, true);
+	}
+
+	/**
+	 * Writes a random rows x cols matrix as binary input, runs the DML script
+	 * in single-node mode with OOC enabled, checks that the OOC tsmm
+	 * instruction was executed, and compares the script output with a
+	 * reference result computed in memory.
+	 *
+	 * @param cols   number of columns of the generated input matrix
+	 * @param sparse if true the input is generated with sparsity2, otherwise with sparsity1
+	 */
+	private void runTSMMTest(int cols, boolean sparse )
+	{
+		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try
+		{
+			getAndLoadTestConfiguration(TEST_NAME1);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+			programArgs = new String[]{"-explain", "-stats", "-ooc",
+				"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
+
+			// 1. Generate the data in-memory as a double array
+			double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 10);
+
+			// 2. Convert the double array to a MatrixBlock object
+			MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data);
+
+			// 3. Create a binary matrix writer
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+			// 4. Write matrix A and its metadata file in binary format
+			writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros());
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+			runTest(true, false, null, -1);
+
+			//check tsmm OOC
+			Assert.assertTrue("OOC wasn't used for TSMM",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.TSMM));
+
+			//compare results
+			MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, cols, cols, 1000, cols*cols);
+			// a LEFT tsmm t(X) %*% X yields a cols x cols result
+			MatrixBlock ret2 = new MatrixBlock(cols, cols, false);
+			A_mb.transposeSelfMatrixMultOperations(ret2, MMTSJ.MMTSJType.LEFT, k);
+			TestUtils.compareMatrices(ret1, ret2, eps);
+		}
+		catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/ooc/TSMM.dml
b/src/test/scripts/functions/ooc/TSMM.dml
new file mode 100644
index 0000000000..432d2d9daa
--- /dev/null
+++ b/src/test/scripts/functions/ooc/TSMM.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix X from the path given as the first command line argument
+X = read($1);
+
+# Operation under test: transpose-self matrix multiplication t(X) %*% X
+res = t(X) %*% X;
+
+# Write the result in binary format to the path given as the second argument
+write(res, $2, format="binary")