This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 61afba5d0f [SYSTEMDS-3888] Fix size propagation over unique operations
61afba5d0f is described below
commit 61afba5d0ffe8e1cbf9f3e4b956d57f0bc3997b4
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Jun 6 12:25:38 2025 +0200
[SYSTEMDS-3888] Fix size propagation over unique operations
This patch fixes the incorrect size propagation of unique, which led to
incorrect results if the output dimensions were used in subsequent operations.
Thanks to Chi-Hsin Huang for catching this bug.
Furthermore, this patch includes minor code-quality updates
(removed unused imports, annotated unused functions).
---
.../java/org/apache/sysds/hops/AggUnaryOp.java | 36 +++++++++++++++++-----
.../sysds/hops/estim/EstimatorLayeredGraph.java | 8 +++--
.../RewriteQuantizationFusedCompression.java | 2 --
.../ParameterizedBuiltinFunctionExpression.java | 20 ++++++------
.../test/functions/misc/SizePropagationTest.java | 15 +++++++++
.../functions/misc/SizePropagationUnique.dml | 28 +++++++++++++++++
6 files changed, 88 insertions(+), 21 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 954114a0a4..0b2d62bbe3 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -323,10 +323,18 @@ public class AggUnaryOp extends MultiThreadedHop
DataCharacteristics ret = null;
Hop input = getInput().get(0);
DataCharacteristics dc = memo.getAllInputStats(input);
- if( _direction == Direction.Col && dc.colsKnown() )
- ret = new MatrixCharacteristics(1, dc.getCols(), -1,
-1);
- else if( _direction == Direction.Row && dc.rowsKnown() )
- ret = new MatrixCharacteristics(dc.getRows(), 1, -1,
-1);
+ if( _op == AggOp.UNIQUE ) {
+ if( _direction == Direction.RowCol && dc.rowsKnown() )
+ ret = new MatrixCharacteristics(dc.getRows(),
1, -1, -1);
+ else
+ ret = new MatrixCharacteristics(dc.getRows(),
dc.getCols(), -1, -1);
+ }
+ else {
+ if( _direction == Direction.Col && dc.colsKnown() )
+ ret = new MatrixCharacteristics(1,
dc.getCols(), -1, -1);
+ else if( _direction == Direction.Row && dc.rowsKnown() )
+ ret = new MatrixCharacteristics(dc.getRows(),
1, -1, -1);
+ }
return ret;
}
@@ -648,9 +656,23 @@ public class AggUnaryOp extends MultiThreadedHop
@Override
public void refreshSizeInformation()
{
- if (getDataType() != DataType.SCALAR)
- {
- Hop input = getInput().get(0);
+ Hop input = getInput().get(0);
+ if( _op == AggOp.UNIQUE ) {
+ if ( _direction == Direction.Col ) {
+ setDim1(-1); //unknown num unique
+ setDim2(input.getDim2());
+ }
+ else if ( _direction == Direction.Row ) {
+ setDim1(input.getDim1());
+ setDim2(-1); //unknown num unique
+ }
+ else {
+ setDim1(-1);
+ setDim2(1);
+ }
+ }
+ //general case: all other unary aggregations
+ else if (getDataType() != DataType.SCALAR) {
if ( _direction == Direction.Col ) //colwise
computations
{
setDim1(1);
diff --git
a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
index f997db6503..1fbdb1fd46 100644
--- a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
+++ b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
@@ -57,9 +57,9 @@ public class EstimatorLayeredGraph extends SparsityEstimator {
@Override
public DataCharacteristics estim(MMNode root) {
- List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>());
- List<OpCode> ops = getOps(root, new ArrayList<>());
- List<LayeredGraph> LGs = new ArrayList<>();
+ //List<MatrixBlock> leafs = getMatrices(root, new
ArrayList<>());
+ //List<OpCode> ops = getOps(root, new ArrayList<>());
+ //List<LayeredGraph> LGs = new ArrayList<>();
LayeredGraph ret = traverse(root);
long nnz = ret.estimateNnz();
return root.setDataCharacteristics(new MatrixCharacteristics(
@@ -125,6 +125,7 @@ public class EstimatorLayeredGraph extends
SparsityEstimator {
}
}
+ @SuppressWarnings("unused")
private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock>
leafs) {
//NOTE: this extraction is only correct and efficient for
chains, no DAGs
if( node.isLeaf() )
@@ -136,6 +137,7 @@ public class EstimatorLayeredGraph extends
SparsityEstimator {
return leafs;
}
+ @SuppressWarnings("unused")
private List<OpCode> getOps(MMNode node, List<OpCode> ops) {
//NOTE: this extraction is only correct and efficient for
chains, no DAGs
if(node.isLeaf()) {
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
index f29d1dce81..1ff5e086ce 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
@@ -27,8 +27,6 @@ import java.util.Map.Entry;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.common.Types.DataType;
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 4ee92e783b..314440628e 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -562,24 +562,26 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
private void validateUniqueAggregationDirection(Identifier dataId,
DataIdentifier output) {
HashMap<String, Expression> varParams = getVarParams();
+ String inputDirection = Types.Direction.RowCol.toString();
if (varParams.containsKey("dir")) {
- String inputDirectionString =
varParams.get("dir").toString().toUpperCase();
-
+ inputDirection =
varParams.get("dir").toString().toUpperCase();
// unrecognized value for "dir" parameter
- if
(!inputDirectionString.equals(Types.Direction.Row.toString())
- &&
!inputDirectionString.equals(Types.Direction.Col.toString())
- &&
!inputDirectionString.equals(Types.Direction.RowCol.toString())) {
- raiseValidateError("Invalid argument: " +
inputDirectionString + " is not recognized");
+ if
(!inputDirection.equals(Types.Direction.Row.toString())
+ &&
!inputDirection.equals(Types.Direction.Col.toString())
+ &&
!inputDirection.equals(Types.Direction.RowCol.toString())) {
+ raiseValidateError("Invalid argument: " +
inputDirection + " is not recognized");
}
}
- // rc/r/c -> unique return value is the same as the input in
the worst case
// default to dir="rc"
output.setDataType(DataType.MATRIX);
- output.setDimensions(dataId.getDim1(), dataId.getDim2());
+ output.setDimensions(
+ inputDirection.equals(Types.Direction.Row.toString()) ?
dataId.getDim1() : -1,
+ inputDirection.equals(Types.Direction.Col.toString()) ?
dataId.getDim2() :
+
inputDirection.equals(Types.Direction.RowCol.toString()) ? 1 : -1);
output.setBlocksize(dataId.getBlocksize());
output.setValueType(ValueType.FP64);
- output.setNnz(dataId.getNnz());
+ output.setNnz(-1);
}
private void checkStringParam(boolean optional, String fname, String
pname, boolean conditional) {
diff --git
a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
index 4b4a76aa19..9d9fa59bc9 100644
---
a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
@@ -27,6 +27,7 @@ import
org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
import java.util.HashMap;
@@ -38,6 +39,7 @@ public class SizePropagationTest extends AutomatedTestBase
private static final String TEST_NAME3 = "SizePropagationLoopIx2";
private static final String TEST_NAME4 = "SizePropagationLoopIx3";
private static final String TEST_NAME5 = "SizePropagationLoopIx4";
+ private static final String TEST_NAME6 = "SizePropagationUnique";
private static final String TEST_DIR = "functions/misc/";
private static final String TEST_CLASS_DIR = TEST_DIR +
SizePropagationTest.class.getSimpleName() + "/";
@@ -52,6 +54,7 @@ public class SizePropagationTest extends AutomatedTestBase
addTestConfiguration( TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
addTestConfiguration( TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
addTestConfiguration( TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
+ addTestConfiguration( TEST_NAME6, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
}
@Test
@@ -104,6 +107,16 @@ public class SizePropagationTest extends AutomatedTestBase
testSizePropagation( TEST_NAME5, true, N );
}
+ @Test
+ public void testSizePropagationUnique1() {
+ testSizePropagation( TEST_NAME6, false, 10 );
+ }
+
+ @Test
+ public void testSizePropagationUnique2() {
+ testSizePropagation( TEST_NAME6, false, 10 );
+ }
+
private void testSizePropagation( String testname, boolean rewrites,
int expect ) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
ExecMode oldPlatform = rtplatform;
@@ -122,6 +135,8 @@ public class SizePropagationTest extends AutomatedTestBase
runTest(true, false, null, -1);
HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
Assert.assertEquals(Double.valueOf(expect),
dmlfile.get(new CellIndex(1,1)));
+ if( testname.equals(TEST_NAME6) )
+ Assert.assertEquals(0,
Statistics.getNoOfCompiledSPInst());
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
diff --git a/src/test/scripts/functions/misc/SizePropagationUnique.dml
b/src/test/scripts/functions/misc/SizePropagationUnique.dml
new file mode 100644
index 0000000000..803cb949ed
--- /dev/null
+++ b/src/test/scripts/functions/misc/SizePropagationUnique.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix("1 2 3 4 5 6 7", rows=7,cols=1)
+B = matrix("4 5 6 7 8 9 10", rows=7,cols=1)
+C = rbind(A,B)
+D = unique(C)
+n = nrow(D);
+R = as.matrix(n);
+write(R, $2);