This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4750f63923 [SYSTEMDS-3678] Fix list append size propagation issue
4750f63923 is described below
commit 4750f63923cdff3abe8398c2d7d99192c1b18fbf
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 23 17:08:20 2024 +0100
[SYSTEMDS-3678] Fix list append size propagation issue
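This patch fixes a list append size propagation issue in loops. During
reconciliation of alternative paths, the initial propagation finds size
discrepancies and therefore needs to reset the loop body. It could not do so
because list size propagation only propagated known sizes, so a previously
derived size was never reset to unknown. Sizes are now always derived from
the input, which allows unsetting them.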
---
src/main/java/org/apache/sysds/hops/BinaryOp.java | 6 ++++--
.../builtin/part1/BuiltinGaussianClassifierTest.java | 16 +++-------------
2 files changed, 7 insertions(+), 15 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 74740f10ce..954f0919ab 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -1018,8 +1018,10 @@ public class BinaryOp extends MultiThreadedHop {
setDim2(0);
}
else if ( getDataType() == DataType.LIST ) {
- if( input1.getDataType().isList() && input1.rowsKnown()
) {
- setDim1(input1.getDim1() + 1);
+ if( (op == OpOp2.CBIND || op == OpOp2.RBIND)
+ && input1.getDataType().isList() ) {
+ //always derive from input to allow unsetting
+ setDim1(input1.rowsKnown() ? input1.getDim1() +
1 : -1);
setDim2(1); //always col-vector
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
index f73ff70a3d..3dab13f332 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
@@ -89,19 +89,9 @@ public class BuiltinGaussianClassifierTest extends
AutomatedTestBase
fullDMLScriptName = HOME + TEST_NAME + ".dml";
double varSmoothing = 1e-9;
-
- List<String> proArgs = new ArrayList<>();
- proArgs.add("-args");
- proArgs.add(input("X"));
- proArgs.add(input("Y"));
- proArgs.add(String.valueOf(varSmoothing));
- proArgs.add(output("priors"));
- proArgs.add(output("means"));
- proArgs.add(output("determinants"));
- proArgs.add(output("invcovs"));
-
- programArgs = proArgs.toArray(new String[proArgs.size()]);
-
+ programArgs = new String[] {"-args",
+ input("X"), input("Y"), String.valueOf(varSmoothing),
+ output("priors"), output("means"),
output("determinants"), output("invcovs")};
rCmd = getRCmd(inputDir(), Double.toString(varSmoothing),
expectedDir());
double[][] X = getRandomMatrix(rows, cols, 0, 100, sparsity,
-1);