This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0263279673 [SYSTEMDS-3895] Add OOC row and column aggregations with tests
0263279673 is described below

commit 0263279673aed3403f73ebc0c8702107d8511a5c
Author: Jessica Priebe <[email protected]>
AuthorDate: Wed Aug 20 13:10:24 2025 +0200

    [SYSTEMDS-3895] Add OOC row and column aggregations with tests
    
    Closes #2309.
---
 .../ooc/AggregateUnaryOOCInstruction.java          | 130 +++++++++++++++++----
 .../test/functions/ooc/ColAggregationTest.java     | 113 ++++++++++++++++++
 .../test/functions/ooc/RowAggregationTest.java     | 113 ++++++++++++++++++
 .../scripts/functions/ooc/ColAggregationTest.dml   |  24 ++++
 .../scripts/functions/ooc/RowAggregationTest.dml   |  24 ++++
 5 files changed, 384 insertions(+), 20 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index c333088239..a656cd337c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -30,10 +30,15 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
 import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 
+import java.util.HashMap;
+import java.util.concurrent.ExecutorService;
 
 public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
        private AggregateOperator _aop = null;
@@ -61,34 +66,138 @@ public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
 	
 	@Override
 	public void processInstruction( ExecutionContext ec ) {
-		//TODO support all types of aggregations, currently only full aggregation
+		//TODO support all types of aggregations, currently only full aggregation, row aggregation and column aggregation
 		
 		//setup operators and input queue
 		AggregateUnaryOperator aggun = (AggregateUnaryOperator) getOperator(); 
 		MatrixObject min = ec.getMatrixObject(input1);
 		LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
-		IndexedMatrixValue tmp = null;
 		int blen = ConfigurationManager.getBlocksize();
-		
-		//read blocks and aggregate immediately into result
-		int extra = _aop.correction.getNumRemovedRowsColumns();
-		MatrixBlock ret = new MatrixBlock(1,1+extra,false);
-		MatrixBlock corr = new MatrixBlock(1,1+extra,false);
-		try {
-			while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
-				//block aggregation
-				MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue())
-					.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
-				//accumulation into final result
-				OperationsOnMatrixValues.incrementalAggregation(
-					ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true);
-			}
-		}
-		catch(Exception ex) {
-			throw new DMLRuntimeException(ex);
-		}
-		
-		//create scalar output
-		ec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0)));
+
+		if( aggun.isRowAggregate() || aggun.isColAggregate() )
+			processRowColAggregation(ec, aggun, q, blen);
+		else //full aggregation
+			processFullAggregation(ec, aggun, q, blen);
+	}
+	
+	/**
+	 * Streams a row or column aggregation: partial aggregates are kept per
+	 * output block index, and each output block is enqueued into the output
+	 * stream once all input blocks of its block row (or block column) have
+	 * been aggregated.
+	 * 
+	 * NOTE(review): an exception inside the worker task is only captured in the
+	 * discarded Future and the output stream is then never closed, so consumers
+	 * of the output stream presumably block; TODO propagate such failures.
+	 */
+	private void processRowColAggregation(ExecutionContext ec, AggregateUnaryOperator aggun,
+		LocalTaskQueue<IndexedMatrixValue> q, int blen)
+	{
+		final boolean rowAgg = aggun.isRowAggregate();
+		DataCharacteristics chars = ec.getDataCharacteristics(input1.getName());
+		// number of input blocks per output block (row or column dim)
+		final long nBlocks = rowAgg ? chars.getNumColBlocks() : chars.getNumRowBlocks();
+		if( nBlocks <= 0 ) // unknown dims: no output block would ever be emitted
+			throw new DMLRuntimeException("OOC row/column aggregation requires "
+				+ "known input dimensions, but got: " + chars);
+
+		LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
+		ec.getMatrixObject(output).setStreamHandle(qOut);
+		ExecutorService pool = CommonThreadPool.get();
+		try {
+			pool.submit(() -> {
+				// intermediate state per aggregation index, confined to this task
+				HashMap<Long, MatrixBlock> aggs = new HashMap<>(); // partial aggregates
+				HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks
+				HashMap<Long, Integer> cnt = new HashMap<>(); // processed block count per agg idx
+				try {
+					IndexedMatrixValue tmp = null;
+					while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
+						long idx = rowAgg ?
+							tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
+						if( !aggs.containsKey(idx) ) {
+							// first block for this idx - init aggregate and correction
+							// TODO avoid corr block for inplace incremental aggregation
+							aggs.put(idx, createPartialAggregate(tmp, rowAgg));
+							corrs.put(idx, createPartialAggregate(tmp, rowAgg));
+						}
+						// aggregate block into the partial aggregate of this idx (in place)
+						MatrixBlock ret = aggs.get(idx);
+						aggregateBlock(aggun, tmp, blen, ret, corrs.get(idx));
+
+						if( cnt.merge(idx, 1, Integer::sum) == nBlocks ) {
+							// all input blocks for this idx processed - emit aggregated block
+							ret.dropLastRowsOrColumns(_aop.correction); // drop correction row/col
+							MatrixIndexes midx = rowAgg ?
+								new MatrixIndexes(idx, 1) : new MatrixIndexes(1, idx);
+							qOut.enqueueTask(new IndexedMatrixValue(midx, ret));
+							// drop intermediate states
+							aggs.remove(idx);
+							corrs.remove(idx);
+							cnt.remove(idx);
+						}
+					}
+					qOut.closeInput();
+				}
+				catch(Exception ex) {
+					throw new DMLRuntimeException(ex);
+				}
+			});
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+		finally {
+			pool.shutdown();
+		}
+	}
+	
+	/**
+	 * Creates an empty partial aggregate (or correction) block for the block row
+	 * (row aggregation) or block column (column aggregation) of the given input
+	 * block, including the extra correction columns or rows.
+	 */
+	private MatrixBlock createPartialAggregate(IndexedMatrixValue in, boolean rowAgg) {
+		int extra = _aop.correction.getNumRemovedRowsColumns();
+		return rowAgg ?
+			new MatrixBlock(in.getValue().getNumRows(), 1 + extra, false) :
+			new MatrixBlock(1 + extra, in.getValue().getNumColumns(), false);
+	}
+	
+	/**
+	 * Aggregates all streamed input blocks into a single scalar output.
+	 */
+	private void processFullAggregation(ExecutionContext ec, AggregateUnaryOperator aggun,
+		LocalTaskQueue<IndexedMatrixValue> q, int blen)
+	{
+		//read blocks and aggregate immediately into result
+		int extra = _aop.correction.getNumRemovedRowsColumns();
+		MatrixBlock ret = new MatrixBlock(1,1+extra,false);
+		MatrixBlock corr = new MatrixBlock(1,1+extra,false);
+		try {
+			IndexedMatrixValue tmp = null;
+			while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS)
+				aggregateBlock(aggun, tmp, blen, ret, corr);
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+		
+		//create scalar output
+		ec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0)));
+	}
+	
+	/**
+	 * Aggregates a single input block with the unary aggregate operator and
+	 * accumulates the block result into the given running aggregate (and
+	 * correction, if the aggregate operator uses one).
+	 */
+	private void aggregateBlock(AggregateUnaryOperator aggun, IndexedMatrixValue in,
+		int blen, MatrixBlock ret, MatrixBlock corr)
+	{
+		MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) in.getValue())
+			.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, in.getIndexes());
+		OperationsOnMatrixValues.incrementalAggregation(
+			ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true);
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/ColAggregationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/ColAggregationTest.java
new file mode 100644
index 0000000000..ff430039e2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/ColAggregationTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ColAggregationTest extends AutomatedTestBase{
+	private static final String TEST_NAME = "ColAggregationTest";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + ColAggregationTest.class.getSimpleName() + "/";
+	private static final String INPUT_NAME = "X";
+	private static final String OUTPUT_NAME = "res";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+		addTestConfiguration(TEST_NAME, config);
+	}
+
+	@Test
+	public void testColAggregationNoRewrite() {
+		testColAggregation(false);
+	}
+
+	@Test
+	public void testColAggregationRewrite() {
+		testColAggregation(true);
+	}
+
+	/**
+	 * Test the col aggregation, "colSums(X)", with OOC backend.
+	 *
+	 * @param rewrite whether algebraic simplification rewrites are enabled
+	 */
+	public void testColAggregation(boolean rewrite)
+	{
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+		boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-explain", "-stats", "-ooc",
+				"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
+
+			int rows = 4200, cols = 2700;
+			MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+			writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
+
+			runTest(true, false, null, -1);
+
+			double[][] res = DataConverter.convertToDoubleMatrix(DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), Types.FileFormat.BINARY, 1, cols, 1000, 1000));
+			for(int j = 0; j < cols; j++) {
+				double expected = 0.0;
+				for(int i = 0; i < rows; i++) {
+					expected += mb.get(i, j);
+				}
+				Assert.assertEquals(expected, res[0][j], 1e-10);
+			}
+
+			String prefix = Instruction.OOC_INST_PREFIX;
+			Assert.assertTrue("OOC wasn't used for RBLK",
+				heavyHittersContainsString(prefix + Opcodes.RBLK));
+			// uack+
+			Assert.assertTrue("OOC wasn't used for COLSUMS",
+				heavyHittersContainsString(prefix + Opcodes.UACKP));
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex); // keep the full stack trace, not just the message
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
+			resetExecMode(platformOld);
+		}
+	}
+
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/RowAggregationTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/RowAggregationTest.java
new file mode 100644
index 0000000000..920573bb64
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/RowAggregationTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RowAggregationTest extends AutomatedTestBase{
+	private static final String TEST_NAME = "RowAggregationTest";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RowAggregationTest.class.getSimpleName() + "/";
+	private static final String INPUT_NAME = "X";
+	private static final String OUTPUT_NAME = "res";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+		addTestConfiguration(TEST_NAME, config);
+	}
+
+	@Test
+	public void testRowAggregationNoRewrite() {
+		testRowAggregation(false);
+	}
+
+	@Test
+	public void testRowAggregationRewrite() {
+		testRowAggregation(true);
+	}
+
+	/**
+	 * Test the row aggregation, "rowSums(X)", with OOC backend.
+	 *
+	 * @param rewrite whether algebraic simplification rewrites are enabled
+	 */
+	public void testRowAggregation(boolean rewrite)
+	{
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+		boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-explain", "-stats", "-ooc",
+				"-args", input(INPUT_NAME), output(OUTPUT_NAME)};
+
+			int rows = 3900, cols = 1700;
+			MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7);
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+			writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rows,cols,1000,rows*cols), Types.FileFormat.BINARY);
+
+			runTest(true, false, null, -1);
+
+			double[][] res = DataConverter.convertToDoubleMatrix(DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, 1, 1000, 1000));
+			for(int i = 0; i < rows; i++) {
+				double expected = 0.0;
+				for(int j = 0; j < cols; j++) {
+					expected += mb.get(i, j);
+				}
+				Assert.assertEquals(expected, res[i][0], 1e-10);
+			}
+
+			String prefix = Instruction.OOC_INST_PREFIX;
+			Assert.assertTrue("OOC wasn't used for RBLK",
+				heavyHittersContainsString(prefix + Opcodes.RBLK));
+			// uark+
+			Assert.assertTrue("OOC wasn't used for ROWSUMS",
+				heavyHittersContainsString(prefix + Opcodes.UARKP));
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex); // keep the full stack trace, not just the message
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite;
+			resetExecMode(platformOld);
+		}
+	}
+
+}
diff --git a/src/test/scripts/functions/ooc/ColAggregationTest.dml b/src/test/scripts/functions/ooc/ColAggregationTest.dml
new file mode 100644
index 0000000000..41da91238b
--- /dev/null
+++ b/src/test/scripts/functions/ooc/ColAggregationTest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# column sums of the matrix read from $1, written in binary format to $2
+X = read($1);
+res = colSums(X);
+write(res, $2, format="binary");
diff --git a/src/test/scripts/functions/ooc/RowAggregationTest.dml b/src/test/scripts/functions/ooc/RowAggregationTest.dml
new file mode 100644
index 0000000000..1cc1272cac
--- /dev/null
+++ b/src/test/scripts/functions/ooc/RowAggregationTest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# row sums of the matrix read from $1, written in binary format to $2
+X = read($1);
+res = rowSums(X);
+write(res, $2, format="binary");

Reply via email to