This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new ca8d20916c [SYSTEMDS-3894] New out-of-core binary scalar-matrix
operations
ca8d20916c is described below
commit ca8d20916c2f6a5073f0e2026511908f53bb0904
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Jul 15 17:58:05 2025 +0200
[SYSTEMDS-3894] New out-of-core binary scalar-matrix operations
This patch completes the selected example operations for the new
out-of-core backend and the related test.
---
src/main/java/org/apache/sysds/hops/BinaryOp.java | 3 -
.../runtime/instructions/OOCInstructionParser.java | 6 +-
.../instructions/ooc/BinaryOOCInstruction.java | 95 ++++++++++++++++++++++
.../functions/ooc/SumScalarMultiplicationTest.java | 29 +++++--
4 files changed, 121 insertions(+), 12 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index f433931a52..a3ddb45ea6 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -854,9 +854,6 @@ public class BinaryOp extends MultiThreadedHop {
_etype = ExecType.CP;
}
- if( _etype == ExecType.OOC ) //TODO
- setExecType(ExecType.CP);
-
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index c437684d3b..0e5b3f1f51 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -24,6 +24,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.InstructionType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
@@ -50,10 +51,9 @@ public class OOCInstructionParser extends InstructionParser {
return
ReblockOOCInstruction.parseInstruction(str);
case AggregateUnary:
return
AggregateUnaryOOCInstruction.parseInstruction(str);
-
- // TODO:
case Binary:
-
+ return
BinaryOOCInstruction.parseInstruction(str);
+
default:
throw new DMLRuntimeException("Invalid OOC
Instruction Type: " + ooctype);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
new file mode 100644
index 0000000000..fe76e60b9e
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.concurrent.ExecutorService;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+/**
+ * Out-of-core (OOC) binary instruction. Currently only matrix-scalar
+ * operations are supported: the blocks of the matrix input are read from its
+ * stream handle, the scalar operator is applied to every block, and the
+ * results are written to a new stream attached to the output matrix object.
+ */
+public class BinaryOOCInstruction extends ComputationOOCInstruction {
+
+	/**
+	 * Creates a binary OOC instruction; use {@link #parseInstruction(String)}
+	 * to obtain instances from their string representation.
+	 *
+	 * @param type   OOC instruction type
+	 * @param bop    parsed binary operator
+	 * @param in1    first input operand
+	 * @param in2    second input operand
+	 * @param out    output operand
+	 * @param opcode instruction opcode
+	 * @param istr   full instruction string
+	 */
+	protected BinaryOOCInstruction(OOCType type, Operator bop,
+		CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+		super(type, bop, in1, in2, out, opcode, istr);
+	}
+
+	/**
+	 * Parses a binary OOC instruction from its string representation. The
+	 * instruction parts are read as opcode, first input, second input, and
+	 * output (in that order).
+	 *
+	 * @param str instruction string
+	 * @return the parsed binary OOC instruction
+	 */
+	public static BinaryOOCInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 3);
+		String opcode = parts[0];
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		CPOperand out = new CPOperand(parts[3]);
+		Operator bop = InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2);
+
+		return new BinaryOOCInstruction(
+			OOCType.Binary, bop, in1, in2, out, opcode, str);
+	}
+
+	/**
+	 * Executes the matrix-scalar operation by submitting a task that consumes
+	 * the input block stream and produces the output block stream. This method
+	 * returns after submitting the task; it does not wait for its completion.
+	 *
+	 * @param ec execution context providing the scalar and matrix inputs
+	 */
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		//TODO support all types, currently only binary matrix-scalar
+
+		//determine matrix and scalar operand (the scalar may be on either side)
+		boolean matrixLeft = (input1.getDataType() == DataType.MATRIX);
+		CPOperand matrix = matrixLeft ? input1 : input2;
+		CPOperand scalar = matrixLeft ? input2 : input1;
+
+		//get operator and scalar
+		ScalarObject constant = ec.getScalarInput(scalar);
+		ScalarOperator sc_op = ((ScalarOperator)_optr).setConstant(constant.getDoubleValue());
+
+		//create thread and process binary operation
+		//(read from the matrix operand, not unconditionally from input1,
+		//which is the scalar for scalar-matrix operations)
+		MatrixObject min = ec.getMatrixObject(matrix);
+		LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
+		LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
+		ec.getMatrixObject(output).setStreamHandle(qOut);
+
+		ExecutorService pool = CommonThreadPool.get();
+		try {
+			//NOTE(review): the returned Future is discarded, so an exception
+			//thrown by the task is never observed here and qOut is not closed
+			//on that path -- confirm how consumers of qOut detect failures
+			pool.submit(() -> {
+				IndexedMatrixValue tmp = null;
+				try {
+					while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
+						IndexedMatrixValue tmpOut = new IndexedMatrixValue();
+						tmpOut.set(tmp.getIndexes(),
+							tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
+						qOut.enqueueTask(tmpOut);
+					}
+					//signal end of stream to consumers
+					qOut.closeInput();
+				}
+				catch(Exception ex) {
+					throw new DMLRuntimeException(ex);
+				}
+			});
+		}
+		finally {
+			//no further submissions
+			pool.shutdown();
+		}
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
index 2272588bab..f0d9228a53 100644
---
a/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ooc/SumScalarMultiplicationTest.java
@@ -23,6 +23,7 @@ import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
@@ -57,11 +58,26 @@ public class SumScalarMultiplicationTest extends
AutomatedTestBase {
* Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
*/
@Test
- public void testSumScalarMult() {
-
+ public void testSumScalarMultNoRewrite() {
+ testSumScalarMult(false);
+ }
+
+ /**
+ * Test the sum of scalar multiplication, "sum(X)*7", with OOC backend.
+ */
+ @Test
+ public void testSumScalarMultRewrite() {
+ testSumScalarMult(true);
+ }
+
+
+ public void testSumScalarMult(boolean rewrite)
+ {
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
-
+ boolean oldRewrite =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
try {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -92,16 +108,17 @@ public class SumScalarMultiplicationTest extends
AutomatedTestBase {
String prefix = Instruction.OOC_INST_PREFIX;
Assert.assertTrue("OOC wasn't used for RBLK",
heavyHittersContainsString(prefix +
Opcodes.RBLK));
+ if(!rewrite)
+ Assert.assertTrue("OOC wasn't used for MULT",
+ heavyHittersContainsString(prefix +
Opcodes.MULT));
Assert.assertTrue("OOC wasn't used for SUM",
heavyHittersContainsString(prefix +
Opcodes.UAKP));
-
-// boolean usedOOCMult =
Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
-// Assert.assertTrue("OOC wasn't used for MULT",
usedOOCMult);
}
catch(Exception ex) {
Assert.fail(ex.getMessage());
}
finally {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
oldRewrite;
resetExecMode(platformOld);
}
}