This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cd16f7aa40 [SYSTEMDS-3837] Fix trace error handling (only squared 
matrices)
cd16f7aa40 is described below

commit cd16f7aa40f72681d3022ea68e8a45381beca541
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 15 12:31:59 2025 +0100

    [SYSTEMDS-3837] Fix trace error handling (only squared matrices)
    
    The trace of a matrix is only defined for square matrices, and our
    kernels also assume that internally. However, there was no systematic
    error handling, which led to some invalid invocations failing with
    index-out-of-bounds errors while others succeeded.
---
 .../sysds/parser/BuiltinFunctionExpression.java    |  8 +++++-
 .../sysds/test/functions/aggregate/TraceTest.java  | 31 +++++++++++++++-------
 .../scripts/functions/aggregate/TraceInvalid1.dml  | 22 +++++++++++++++
 .../scripts/functions/aggregate/TraceInvalid2.dml  | 22 +++++++++++++++
 4 files changed, 73 insertions(+), 10 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 84e0fe079b..6a68f867f9 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -942,9 +942,15 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        output.setBlocksize (id.getBlocksize());
                        output.setValueType(id.getValueType());
                        break;
+               case TRACE:
+                       if(getFirstExpr().getOutput().dimsKnown() 
+                               && getFirstExpr().getOutput().getDim1() != 
getFirstExpr().getOutput().getDim2()) 
+                       {
+                               raiseValidateError("Trace is only defined on 
squared matrices but found ["
+                                       
+getFirstExpr().getOutput().getDim1()+"x"+getFirstExpr().getOutput().getDim2()+"].",
 conditional);
+                       }
                case SUM:
                case PROD:
-               case TRACE:
                case SD:
                case VAR:
                        // sum(X);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java 
b/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java
index fcd46b260d..7c9043a90d 100644
--- a/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java
@@ -44,16 +44,20 @@ public class TraceTest extends AutomatedTestBase {
 
        private final static String TEST_DIR = "functions/aggregate/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
TraceTest.class.getSimpleName() + "/";
-       private final static String TEST_GENERAL = "General";
-       private final static String TEST_SCALAR = "Scalar";
+       private final static String TEST_GENERAL = "TraceTest";
+       private final static String TEST_SCALAR = "TraceScalarTest";
+       private final static String TEST_INVALID1 = "TraceInvalid1";
+       private final static String TEST_INVALID2 = "TraceInvalid2";
 
        @Override
        public void setUp() {
                // positive tests
-               addTestConfiguration(TEST_GENERAL, new 
TestConfiguration(TEST_CLASS_DIR, "TraceTest", new String[] {"b"}));
+               addTestConfiguration(TEST_GENERAL, new 
TestConfiguration(TEST_CLASS_DIR, TEST_GENERAL, new String[] {"b"}));
 
                // negative tests
-               addTestConfiguration(TEST_SCALAR, new 
TestConfiguration(TEST_CLASS_DIR, "TraceScalarTest", new String[] {"b"}));
+               addTestConfiguration(TEST_SCALAR, new 
TestConfiguration(TEST_CLASS_DIR, TEST_SCALAR, new String[] {"b"}));
+               addTestConfiguration(TEST_INVALID1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_INVALID1, new String[] {"b"}));
+               addTestConfiguration(TEST_INVALID2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_INVALID2, new String[] {"b"}));
        }
 
        @Test
@@ -85,16 +89,25 @@ public class TraceTest extends AutomatedTestBase {
 
        @Test
        public void testScalar() {
-               int scalar = 12;
-
                TestConfiguration config = getTestConfiguration(TEST_SCALAR);
-               config.addVariable("scalar", scalar);
+               config.addVariable("scalar", 12);
 
                createHelperMatrix();
-
                loadTestConfiguration(config);
-
+               runTest(true, LanguageException.class);
+       }
+       
+       @Test
+       public void testInvalid1() {
+               TestConfiguration config = getTestConfiguration(TEST_INVALID1);
+               loadTestConfiguration(config);
                runTest(true, LanguageException.class);
        }
 
+       @Test
+       public void testInvalid2() {
+               TestConfiguration config = getTestConfiguration(TEST_INVALID2);
+               loadTestConfiguration(config);
+               runTest(true, LanguageException.class);
+       }
 }
diff --git a/src/test/scripts/functions/aggregate/TraceInvalid1.dml 
b/src/test/scripts/functions/aggregate/TraceInvalid1.dml
new file mode 100644
index 0000000000..c324fd15c2
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/TraceInvalid1.dml
@@ -0,0 +1,22 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+print(trace(rand(rows=100,cols=10)))
diff --git a/src/test/scripts/functions/aggregate/TraceInvalid2.dml 
b/src/test/scripts/functions/aggregate/TraceInvalid2.dml
new file mode 100644
index 0000000000..e25c6a8843
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/TraceInvalid2.dml
@@ -0,0 +1,22 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+print(trace(rand(rows=1000,cols=10)))

Reply via email to