This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4750f63923 [SYSTEMDS-3678] Fix list append size propagation issue
4750f63923 is described below

commit 4750f63923cdff3abe8398c2d7d99192c1b18fbf
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 23 17:08:20 2024 +0100

    [SYSTEMDS-3678] Fix list append size propagation issue
    
    This patch fixes a list append size propagation issue in loops. During
    the initial propagation, the reconciliation of alternative paths finds
    size discrepancies and therefore needs to reset the loop body, but it
    cannot, because the list size propagation only propagated known sizes
    and thus could never reset a list size back to unknown.
---
 src/main/java/org/apache/sysds/hops/BinaryOp.java        |  6 ++++--
 .../builtin/part1/BuiltinGaussianClassifierTest.java     | 16 +++-------------
 2 files changed, 7 insertions(+), 15 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 74740f10ce..954f0919ab 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -1018,8 +1018,10 @@ public class BinaryOp extends MultiThreadedHop {
                        setDim2(0);
                }
                else if ( getDataType() == DataType.LIST ) {
-                       if( input1.getDataType().isList() && input1.rowsKnown() 
) {
-                               setDim1(input1.getDim1() + 1);
+                       if( (op == OpOp2.CBIND || op == OpOp2.RBIND)
+                               && input1.getDataType().isList() ) {
+                               //always derive from input to allow unsetting
+                               setDim1(input1.rowsKnown() ? input1.getDim1() + 
1 : -1);
                                setDim2(1); //always col-vector
                        }
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
index f73ff70a3d..3dab13f332 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGaussianClassifierTest.java
@@ -89,19 +89,9 @@ public class BuiltinGaussianClassifierTest extends 
AutomatedTestBase
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
 
                double varSmoothing = 1e-9;
-
-               List<String> proArgs = new ArrayList<>();
-               proArgs.add("-args");
-               proArgs.add(input("X"));
-               proArgs.add(input("Y"));
-               proArgs.add(String.valueOf(varSmoothing));
-               proArgs.add(output("priors"));
-               proArgs.add(output("means"));
-               proArgs.add(output("determinants"));
-               proArgs.add(output("invcovs"));
-
-               programArgs = proArgs.toArray(new String[proArgs.size()]);
-
+               programArgs = new String[] {"-args",
+                       input("X"), input("Y"), String.valueOf(varSmoothing),
+                       output("priors"), output("means"), 
output("determinants"), output("invcovs")};
                rCmd = getRCmd(inputDir(), Double.toString(varSmoothing), 
expectedDir());
                
                double[][] X = getRandomMatrix(rows, cols, 0, 100, sparsity, 
-1);

Reply via email to