This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cd16f7aa40 [SYSTEMDS-3837] Fix trace error handling (only squared
matrices)
cd16f7aa40 is described below
commit cd16f7aa40f72681d3022ea68e8a45381beca541
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 15 12:31:59 2025 +0100
[SYSTEMDS-3837] Fix trace error handling (only squared matrices)
The trace of a matrix is only defined for square matrices, and our
kernels also assume this internally. However, there was no systematic
error handling, so some invalid invocations failed with
index-out-of-bounds errors while others succeeded.
---
.../sysds/parser/BuiltinFunctionExpression.java | 8 +++++-
.../sysds/test/functions/aggregate/TraceTest.java | 31 +++++++++++++++-------
.../scripts/functions/aggregate/TraceInvalid1.dml | 22 +++++++++++++++
.../scripts/functions/aggregate/TraceInvalid2.dml | 22 +++++++++++++++
4 files changed, 73 insertions(+), 10 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 84e0fe079b..6a68f867f9 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -942,9 +942,15 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
+ case TRACE:
+ if(getFirstExpr().getOutput().dimsKnown()
+ && getFirstExpr().getOutput().getDim1() !=
getFirstExpr().getOutput().getDim2())
+ {
+ raiseValidateError("Trace is only defined on
squared matrices but found ["
+
+getFirstExpr().getOutput().getDim1()+"x"+getFirstExpr().getOutput().getDim2()+"].",
conditional);
+ }
case SUM:
case PROD:
- case TRACE:
case SD:
case VAR:
// sum(X);
diff --git
a/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java
b/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java
index fcd46b260d..7c9043a90d 100644
--- a/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/aggregate/TraceTest.java
@@ -44,16 +44,20 @@ public class TraceTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/aggregate/";
private static final String TEST_CLASS_DIR = TEST_DIR +
TraceTest.class.getSimpleName() + "/";
- private final static String TEST_GENERAL = "General";
- private final static String TEST_SCALAR = "Scalar";
+ private final static String TEST_GENERAL = "TraceTest";
+ private final static String TEST_SCALAR = "TraceScalarTest";
+ private final static String TEST_INVALID1 = "TraceInvalid1";
+ private final static String TEST_INVALID2 = "TraceInvalid2";
@Override
public void setUp() {
// positive tests
- addTestConfiguration(TEST_GENERAL, new
TestConfiguration(TEST_CLASS_DIR, "TraceTest", new String[] {"b"}));
+ addTestConfiguration(TEST_GENERAL, new
TestConfiguration(TEST_CLASS_DIR, TEST_GENERAL, new String[] {"b"}));
// negative tests
- addTestConfiguration(TEST_SCALAR, new
TestConfiguration(TEST_CLASS_DIR, "TraceScalarTest", new String[] {"b"}));
+ addTestConfiguration(TEST_SCALAR, new
TestConfiguration(TEST_CLASS_DIR, TEST_SCALAR, new String[] {"b"}));
+ addTestConfiguration(TEST_INVALID1, new
TestConfiguration(TEST_CLASS_DIR, TEST_INVALID1, new String[] {"b"}));
+ addTestConfiguration(TEST_INVALID2, new
TestConfiguration(TEST_CLASS_DIR, TEST_INVALID2, new String[] {"b"}));
}
@Test
@@ -85,16 +89,25 @@ public class TraceTest extends AutomatedTestBase {
@Test
public void testScalar() {
- int scalar = 12;
-
TestConfiguration config = getTestConfiguration(TEST_SCALAR);
- config.addVariable("scalar", scalar);
+ config.addVariable("scalar", 12);
createHelperMatrix();
-
loadTestConfiguration(config);
-
+ runTest(true, LanguageException.class);
+ }
+
+ @Test
+ public void testInvalid1() {
+ TestConfiguration config = getTestConfiguration(TEST_INVALID1);
+ loadTestConfiguration(config);
runTest(true, LanguageException.class);
}
+ @Test
+ public void testInvalid2() {
+ TestConfiguration config = getTestConfiguration(TEST_INVALID2);
+ loadTestConfiguration(config);
+ runTest(true, LanguageException.class);
+ }
}
diff --git a/src/test/scripts/functions/aggregate/TraceInvalid1.dml
b/src/test/scripts/functions/aggregate/TraceInvalid1.dml
new file mode 100644
index 0000000000..c324fd15c2
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/TraceInvalid1.dml
@@ -0,0 +1,22 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+print(trace(rand(rows=100,cols=10)))
diff --git a/src/test/scripts/functions/aggregate/TraceInvalid2.dml
b/src/test/scripts/functions/aggregate/TraceInvalid2.dml
new file mode 100644
index 0000000000..e25c6a8843
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/TraceInvalid2.dml
@@ -0,0 +1,22 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+print(trace(rand(rows=1000,cols=10)))