This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0263279673 [SYSTEMDS-3895] Add OOC row and column aggregations with tests
0263279673 is described below
commit 0263279673aed3403f73ebc0c8702107d8511a5c
Author: Jessica Priebe <[email protected]>
AuthorDate: Wed Aug 20 13:10:24 2025 +0200
[SYSTEMDS-3895] Add OOC row and column aggregations with tests
Closes #2309.
---
.../ooc/AggregateUnaryOOCInstruction.java | 130 +++++++++++++++++----
.../test/functions/ooc/ColAggregationTest.java | 113 ++++++++++++++++++
.../test/functions/ooc/RowAggregationTest.java | 113 ++++++++++++++++++
.../scripts/functions/ooc/ColAggregationTest.dml | 24 ++++
.../scripts/functions/ooc/RowAggregationTest.dml | 24 ++++
5 files changed, 384 insertions(+), 20 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index c333088239..a656cd337c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -30,10 +30,15 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import java.util.HashMap;
+import java.util.concurrent.ExecutorService;
public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
private AggregateOperator _aop = null;
@@ -61,34 +66,119 @@ public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
@Override
public void processInstruction( ExecutionContext ec ) {
- //TODO support all types of aggregations, currently only full
aggregation
+ //TODO support all types of aggregations, currently only full
aggregation, row aggregation and column aggregation
//setup operators and input queue
AggregateUnaryOperator aggun = (AggregateUnaryOperator)
getOperator();
MatrixObject min = ec.getMatrixObject(input1);
LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
- IndexedMatrixValue tmp = null;
int blen = ConfigurationManager.getBlocksize();
-
- //read blocks and aggregate immediately into result
- int extra = _aop.correction.getNumRemovedRowsColumns();
- MatrixBlock ret = new MatrixBlock(1,1+extra,false);
- MatrixBlock corr = new MatrixBlock(1,1+extra,false);
- try {
- while((tmp = q.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
- //block aggregation
- MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock)
tmp.getValue())
- .aggregateUnaryOperations(aggun, new
MatrixBlock(), blen, tmp.getIndexes());
- //accumulation into final result
- OperationsOnMatrixValues.incrementalAggregation(
- ret, _aop.existsCorrection() ? corr :
null, ltmp, _aop, true);
+
+ if (aggun.isRowAggregate() || aggun.isColAggregate()) {
+ // intermediate state per aggregation index
+ HashMap<Long, MatrixBlock> aggs = new HashMap<>(); //
partial aggregates
+ HashMap<Long, MatrixBlock> corrs = new HashMap<>(); //
correction blocks
+ HashMap<Long, Integer> cnt = new HashMap<>(); //
processed block count per agg idx
+
+ DataCharacteristics chars =
ec.getDataCharacteristics(input1.getName());
+ // number of blocks to process per aggregation idx (row
or column dim)
+ long nBlocks = aggun.isRowAggregate()?
chars.getNumColBlocks() : chars.getNumRowBlocks();
+
+ LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
+ ec.getMatrixObject(output).setStreamHandle(qOut);
+ ExecutorService pool = CommonThreadPool.get();
+ try {
+ pool.submit(() -> {
+ IndexedMatrixValue tmp = null;
+ try {
+ while((tmp = q.dequeueTask())
!= LocalTaskQueue.NO_MORE_TASKS) {
+ long idx =
aggun.isRowAggregate() ?
+
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
+
if(aggs.containsKey(idx)) {
+ // update
existing partial aggregate for this idx
+ MatrixBlock ret
= aggs.get(idx);
+ MatrixBlock
corr = corrs.get(idx);
+
+ // aggregation
+ MatrixBlock
ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
+
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
+
OperationsOnMatrixValues.incrementalAggregation(ret,
+
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
+
+
aggs.replace(idx, ret);
+
corrs.replace(idx, corr);
+
cnt.replace(idx, cnt.get(idx) + 1);
+ }
+ else {
+ // first block
for this idx - init aggregate and correction
+ // TODO avoid
corr block for inplace incremental aggregation
+ int rows =
tmp.getValue().getNumRows();
+ int cols =
tmp.getValue().getNumColumns();
+ int extra =
_aop.correction.getNumRemovedRowsColumns();
+ MatrixBlock ret
= aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new
MatrixBlock(1 + extra, cols, false);
+ MatrixBlock
corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new
MatrixBlock(1 + extra, cols, false);
+
+ // aggregation
+ MatrixBlock
ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()).aggregateUnaryOperations(
+ aggun,
new MatrixBlock(), blen, tmp.getIndexes());
+
OperationsOnMatrixValues.incrementalAggregation(ret,
+
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
+
+ aggs.put(idx,
ret);
+ corrs.put(idx,
corr);
+ cnt.put(idx, 1);
+ }
+
+ if(cnt.get(idx) ==
nBlocks) {
+ // all input
blocks for this idx processed - emit aggregated block
+ MatrixBlock ret
= aggs.get(idx);
+ // drop
correction row/col
+
ret.dropLastRowsOrColumns(_aop.correction);
+ MatrixIndexes
midx = aggun.isRowAggregate()? new
MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1,
tmp.getIndexes().getColumnIndex());
+
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
+
+
qOut.enqueueTask(tmpOut);
+ // drop
intermediate states
+
aggs.remove(idx);
+
corrs.remove(idx);
+ cnt.remove(idx);
+ }
+ }
+ qOut.closeInput();
+ }
+ catch(Exception ex) {
+ throw new
DMLRuntimeException(ex);
+ }
+ });
+ } catch (Exception ex) {
+ throw new DMLRuntimeException(ex);
+ } finally {
+ pool.shutdown();
}
}
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
+ // full aggregation
+ else {
+ IndexedMatrixValue tmp = null;
+ //read blocks and aggregate immediately into result
+ int extra = _aop.correction.getNumRemovedRowsColumns();
+ MatrixBlock ret = new MatrixBlock(1,1+extra,false);
+ MatrixBlock corr = new MatrixBlock(1,1+extra,false);
+ try {
+ while((tmp = q.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ //block aggregation
+ MatrixBlock ltmp = (MatrixBlock)
((MatrixBlock) tmp.getValue())
+
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
+ //accumulation into final result
+
OperationsOnMatrixValues.incrementalAggregation(
+ ret, _aop.existsCorrection() ?
corr : null, ltmp, _aop, true);
+ }
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+
+ //create scalar output
+ ec.setScalarOutput(output.getName(), new
DoubleObject(ret.get(0, 0)));
}
-
- //create scalar output
- ec.setScalarOutput(output.getName(), new
DoubleObject(ret.get(0, 0)));
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/ColAggregationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/ColAggregationTest.java
new file mode 100644
index 0000000000..ff430039e2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/ColAggregationTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ColAggregationTest extends AutomatedTestBase{
+ private static final String TEST_NAME = "ColAggregationTest";
+ private static final String TEST_DIR = "functions/ooc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
ColAggregationTest.class.getSimpleName() + "/";
+ private static final String INPUT_NAME = "X";
+ private static final String OUTPUT_NAME = "res";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+ addTestConfiguration(TEST_NAME, config);
+ }
+
+ @Test
+ public void testColAggregationNoRewrite() {
+ testColAggregation(false);
+ }
+
+ /**
+ * Test the col aggregation, "colSums(X)", with OOC backend.
+ */
+ @Test
+ public void testColAggregationRewrite() {
+ testColAggregation(true);
+ }
+
+ public void testColAggregation(boolean rewrite)
+ {
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+ boolean oldRewrite =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-explain", "-stats",
"-ooc",
+ "-args", input(INPUT_NAME),
output(OUTPUT_NAME)};
+
+ int rows = 4200, cols = 2700;
+ MatrixBlock mb = MatrixBlock.randOperations(rows, cols,
1.0, -1, 1, "uniform", 7);
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+ writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows,
cols, 1000, rows*cols);
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"),
Types.ValueType.FP64,
+ new
MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
+
+ runTest(true, false, null, -1);
+
+ double[][] res =
DataConverter.convertToDoubleMatrix(DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
Types.FileFormat.BINARY, 1, cols, 1000, 1000));
+ for(int j = 0; j < cols; j++) {
+ double expected = 0.0;
+ for(int i = 0; i < rows; i++) {
+ expected += mb.get(i, j);
+ }
+ Assert.assertEquals(expected, res[0][j], 1e-10);
+ }
+
+ String prefix = Instruction.OOC_INST_PREFIX;
+ Assert.assertTrue("OOC wasn't used for RBLK",
+ heavyHittersContainsString(prefix +
Opcodes.RBLK));
+ // uack+
+ Assert.assertTrue("OOC wasn't used for COLSUMS",
+ heavyHittersContainsString(prefix +
Opcodes.UACKP));
+ }
+ catch(Exception ex) {
+ Assert.fail(ex.getMessage());
+ }
+ finally {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
oldRewrite;
+ resetExecMode(platformOld);
+ }
+ }
+
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/RowAggregationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/RowAggregationTest.java
new file mode 100644
index 0000000000..920573bb64
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/RowAggregationTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RowAggregationTest extends AutomatedTestBase{
+ private static final String TEST_NAME = "RowAggregationTest";
+ private static final String TEST_DIR = "functions/ooc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RowAggregationTest.class.getSimpleName() + "/";
+ private static final String INPUT_NAME = "X";
+ private static final String OUTPUT_NAME = "res";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+ addTestConfiguration(TEST_NAME, config);
+ }
+
+ @Test
+ public void testRowAggregationNoRewrite() {
+ testRowAggregation(false);
+ }
+
+ /**
+ * Test the row aggregation, "rowSums(X)", with OOC backend.
+ */
+ @Test
+ public void testRowAggregationRewrite() {
+ testRowAggregation(true);
+ }
+
+ public void testRowAggregation(boolean rewrite)
+ {
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+ boolean oldRewrite =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-explain", "-stats",
"-ooc",
+ "-args", input(INPUT_NAME),
output(OUTPUT_NAME)};
+
+ int rows = 3900, cols = 1700;
+ MatrixBlock mb = MatrixBlock.randOperations(rows, cols,
1.0, -1, 1, "uniform", 7);
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+ writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows,
cols, 1000, rows*cols);
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"),
Types.ValueType.FP64,
+ new
MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
+
+ runTest(true, false, null, -1);
+
+ double[][] res =
DataConverter.convertToDoubleMatrix(DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
Types.FileFormat.BINARY, rows, 1, 1000, 1000));
+ for(int i = 0; i < rows; i++) {
+ double expected = 0.0;
+ for(int j = 0; j < cols; j++) {
+ expected += mb.get(i, j);
+ }
+ Assert.assertEquals(expected, res[i][0], 1e-10);
+ }
+
+ String prefix = Instruction.OOC_INST_PREFIX;
+ Assert.assertTrue("OOC wasn't used for RBLK",
+ heavyHittersContainsString(prefix +
Opcodes.RBLK));
+ // uark+
+ Assert.assertTrue("OOC wasn't used for ROWSUMS",
+ heavyHittersContainsString(prefix +
Opcodes.UARKP));
+ }
+ catch(Exception ex) {
+ Assert.fail(ex.getMessage());
+ }
+ finally {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
oldRewrite;
+ resetExecMode(platformOld);
+ }
+ }
+
+}
diff --git a/src/test/scripts/functions/ooc/ColAggregationTest.dml b/src/test/scripts/functions/ooc/ColAggregationTest.dml
new file mode 100644
index 0000000000..41da91238b
--- /dev/null
+++ b/src/test/scripts/functions/ooc/ColAggregationTest.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+res = colSums(X)
+write(res, $2, format="binary");
diff --git a/src/test/scripts/functions/ooc/RowAggregationTest.dml b/src/test/scripts/functions/ooc/RowAggregationTest.dml
new file mode 100644
index 0000000000..1cc1272cac
--- /dev/null
+++ b/src/test/scripts/functions/ooc/RowAggregationTest.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+res = rowSums(X)
+write(res, $2, format="binary");