This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0c7e0468e9 [SYSTEMDS-3904] New OOC matrix-vector multiplication
0c7e0468e9 is described below
commit 0c7e0468e9a9ad43ebc8a80d16d42e71f03be2e0
Author: Janardhan Pulivarthi <[email protected]>
AuthorDate: Sun Aug 10 10:35:00 2025 +0200
[SYSTEMDS-3904] New OOC matrix-vector multiplication
Closes #2305.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 8 ++
.../runtime/instructions/OOCInstructionParser.java | 4 +
.../ooc/MatrixVectorBinaryOOCInstruction.java | 142 +++++++++++++++++++++
.../runtime/instructions/ooc/OOCInstruction.java | 2 +-
.../ooc/MatrixVectorBinaryMultiplicationTest.java | 132 +++++++++++++++++++
.../functions/ooc/MatrixVectorMultiplication.dml | 29 +++++
6 files changed, 316 insertions(+), 1 deletion(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 0be3143206..fb20cc41d0 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -240,6 +240,14 @@ public class AggBinaryOp extends MultiThreadedHop {
default:
throw new
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" +
_method + ") while constructing SPARK lops.");
}
+ } else if (et == ExecType.OOC) {
+ Lop in1 = getInput().get(0).constructLops();
+ Lop in2 = getInput().get(1).constructLops();
+ MatMultCP matmult = new MatMultCP(in1, in2,
getDataType(), getValueType(),
+ et,
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
+ setOutputDimensions(matmult);
+ setLineNumbers(matmult);
+ setLops(matmult);
}
} else
throw new HopsException(this.printErrorLocation() +
"Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ")
while constructing lops.");
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index a744b5d813..9b1165b819 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -28,6 +28,7 @@ import
org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction;
+import
org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction;
public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG =
LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -56,6 +57,9 @@ public class OOCInstructionParser extends InstructionParser {
return
UnaryOOCInstruction.parseInstruction(str);
case Binary:
return
BinaryOOCInstruction.parseInstruction(str);
+ case AggregateBinary:
+ case MAPMM:
+ return
MatrixVectorBinaryOOCInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid OOC
Instruction Type: " + ooctype);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
new file mode 100644
index 0000000000..a36dc7c885
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+public class MatrixVectorBinaryOOCInstruction extends
ComputationOOCInstruction {
+
+
+ protected MatrixVectorBinaryOOCInstruction(OOCType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+ super(type, op, in1, in2, out, opcode, istr);
+ }
+
+ public static MatrixVectorBinaryOOCInstruction parseInstruction(String
str) {
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ InstructionUtils.checkNumFields(parts, 4);
+ String opcode = parts[0];
+ CPOperand in1 = new CPOperand(parts[1]); // the larget matrix
(streamed)
+ CPOperand in2 = new CPOperand(parts[2]); // the small vector
(in-memory)
+ CPOperand out = new CPOperand(parts[3]);
+
+ AggregateOperator agg = new AggregateOperator(0,
Plus.getPlusFnObject());
+ AggregateBinaryOperator ba = new
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+
+ return new MatrixVectorBinaryOOCInstruction(OOCType.MAPMM, ba,
in1, in2, out, opcode, str);
+ }
+
+ @Override
+ public void processInstruction( ExecutionContext ec ) {
+ // 1. Identify the inputs
+ MatrixObject min = ec.getMatrixObject(input1); // big matrix
+ MatrixBlock vin = ec.getMatrixObject(input2)
+ .acquireReadAndRelease(); // in-memory vector
+
+ // 2. Pre-partition the in-memory vector into a hashmap
+ HashMap<Long, MatrixBlock> partitionedVector = new HashMap<>();
+ int blksize = vin.getDataCharacteristics().getBlocksize();
+ if (blksize < 0)
+ blksize = ConfigurationManager.getBlocksize();
+ for (int i=0; i<vin.getNumRows(); i+=blksize) {
+ long key = (long) (i/blksize) + 1; // the key starts at
1
+ int end_row = Math.min(i + blksize, vin.getNumRows());
+ MatrixBlock vectorSlice = vin.slice(i, end_row - 1);
+ partitionedVector.put(key, vectorSlice);
+ }
+
+ LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
+ LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
+ BinaryOperator plus =
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
+ ec.getMatrixObject(output).setStreamHandle(qOut);
+
+ ExecutorService pool = CommonThreadPool.get();
+ try {
+ // Core logic: background thread
+ pool.submit(() -> {
+ IndexedMatrixValue tmp = null;
+ try {
+ HashMap<Long, MatrixBlock>
partialResults = new HashMap<>();
+ while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ MatrixBlock matrixBlock =
(MatrixBlock) tmp.getValue();
+ long rowIndex =
tmp.getIndexes().getRowIndex();
+ long colIndex =
tmp.getIndexes().getColumnIndex();
+ MatrixBlock vectorSlice =
partitionedVector.get(colIndex);
+
+ // Now, call the operation with
the correct, specific operator.
+ MatrixBlock partialResult =
matrixBlock.aggregateBinaryOperations(
+ matrixBlock,
vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);
+
+ // for single column block, no
aggregation neeeded
+ if( min.getNumColumns() <=
min.getBlocksize() ) {
+ qOut.enqueueTask(new
IndexedMatrixValue(tmp.getIndexes(), partialResult));
+ }
+ else {
+ MatrixBlock currAgg =
partialResults.get(rowIndex);
+ if (currAgg == null)
+
partialResults.put(rowIndex, partialResult);
+ else
+
currAgg.binaryOperationsInPlace(plus, partialResult);
+ }
+ }
+
+ // emit aggregated blocks
+ if( min.getNumColumns() >
min.getBlocksize() ) {
+ for (Map.Entry<Long,
MatrixBlock> entry : partialResults.entrySet()) {
+ MatrixIndexes
outIndexes = new MatrixIndexes(entry.getKey(), 1L);
+ qOut.enqueueTask(new
IndexedMatrixValue(outIndexes, entry.getValue()));
+ }
+ }
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ finally {
+ qOut.closeInput();
+ }
+ });
+ } catch (Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ finally {
+ pool.shutdown();
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index db3d2da8b1..d3c2dfcbd7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction {
protected static final Log LOG =
LogFactory.getLog(OOCInstruction.class.getName());
public enum OOCType {
- Reblock, AggregateUnary, Binary, Unary
+ Reblock, AggregateUnary, Binary, Unary, MAPMM, AggregateBinary
}
protected final OOCInstruction.OOCType _ooctype;
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/MatrixVectorBinaryMultiplicationTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixVectorBinaryMultiplicationTest.java
new file mode 100644
index 0000000000..de4e7e9912
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/ooc/MatrixVectorBinaryMultiplicationTest.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class MatrixVectorBinaryMultiplicationTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "MatrixVectorMultiplication";
+ private final static String TEST_DIR = "functions/ooc/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
MatrixVectorBinaryMultiplicationTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-10;
+ private static final String INPUT_NAME = "X";
+ private static final String INPUT_NAME2 = "v";
+ private static final String OUTPUT_NAME = "res";
+
+ private final static int rows = 5000;
+ private final static int cols_wide = 2000;
+ private final static int cols_skinny = 500;
+
+ private final static double sparsity1 = 0.7;
+ private final static double sparsity2 = 0.1;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+ addTestConfiguration(TEST_NAME1, config);
+ }
+
+ @Test
+ public void testMVBinaryMultiplication1() {
+ runMatrixVectorMultiplicationTest(cols_wide, false);
+ }
+
+ @Test
+ public void testMVBinaryMultiplication2() {
+ runMatrixVectorMultiplicationTest(cols_skinny, false);
+ }
+
+ private void runMatrixVectorMultiplicationTest(int cols, boolean sparse
)
+ {
+ Types.ExecMode platformOld =
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+ try
+ {
+ getAndLoadTestConfiguration(TEST_NAME1);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ programArgs = new String[]{"-explain", "-stats", "-ooc",
+ "-args", input(INPUT_NAME),
input(INPUT_NAME2), output(OUTPUT_NAME)};
+
+ // 1. Generate the data in-memory as MatrixBlock objects
+ double[][] A_data = getRandomMatrix(rows, cols, 0, 1,
sparse?sparsity2:sparsity1, 10);
+ double[][] x_data = getRandomMatrix(cols, 1, 0, 1, 1.0,
10);
+
+ // 2. Convert the double arrays to MatrixBlock objects
+ MatrixBlock A_mb =
DataConverter.convertToMatrixBlock(A_data);
+ MatrixBlock x_mb =
DataConverter.convertToMatrixBlock(x_data);
+
+ // 3. Create a binary matrix writer
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+ // 4. Write matrix A to a binary SequenceFile
+ writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows,
cols, 1000, A_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"),
Types.ValueType.FP64,
+ new MatrixCharacteristics(rows, cols,
1000, A_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+ // 5. Write vector x to a binary SequenceFile
+ writer.writeMatrixToHDFS(x_mb, input(INPUT_NAME2),
cols, 1, 1000, x_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"),
Types.ValueType.FP64,
+ new MatrixCharacteristics(cols, 1,
1000, x_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+ boolean exceptionExpected = false;
+ runTest(true, exceptionExpected, null, -1);
+
+ double[][] C1 = readMatrix(output(OUTPUT_NAME),
Types.FileFormat.BINARY, rows, cols, 1000, 1000);
+ double result = 0.0;
+ for(int i = 0; i < rows; i++) { // verify the results
with Java
+ double expected = 0.0;
+ for(int j = 0; j < cols; j++) {
+ expected += A_mb.get(i, j) *
x_mb.get(j,0);
+ }
+ result = C1[i][0];
+ Assert.assertEquals(expected, result, eps);
+ }
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+
+ private static double[][] readMatrix(String fname, Types.FileFormat
fmt, long rows, long cols, int brows, int bcols )
+ throws IOException
+ {
+ MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt,
rows, cols, brows, bcols);
+ double[][] C = DataConverter.convertToDoubleMatrix(mb);
+ return C;
+ }
+}
diff --git a/src/test/scripts/functions/ooc/MatrixVectorMultiplication.dml
b/src/test/scripts/functions/ooc/MatrixVectorMultiplication.dml
new file mode 100644
index 0000000000..c72db07780
--- /dev/null
+++ b/src/test/scripts/functions/ooc/MatrixVectorMultiplication.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix X ($1) and the vector v ($2) given as command-line args
+X = read($1);
+v = read($2);
+
+# Operation under test: matrix-vector multiplication
+res = X %*% v;
+
+# Write the result to $3 in binary format
+write(res, $3, format="binary")