This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 2e68ad3b1a [SYSTEMDS-3860] Extended codegen row template by var
aggregates
2e68ad3b1a is described below
commit 2e68ad3b1acd4892ed782d01dbfccffc61d2f680
Author: Frxms <[email protected]>
AuthorDate: Fri Apr 18 11:13:02 2025 +0200
[SYSTEMDS-3860] Extended codegen row template by var aggregates
Closes #2244.
---
.../sysds/hops/codegen/cplan/CNodeUnary.java | 4 ++-
.../sysds/hops/codegen/cplan/java/Unary.java | 2 +-
.../sysds/hops/codegen/template/TemplateRow.java | 2 +-
.../sysds/runtime/codegen/LibSpoofPrimitives.java | 14 +++++++-
.../sysds/test/component/misc/DMLScriptTest.java | 1 -
.../functions/builtin/part2/BuiltinMDTest.java | 2 --
.../test/functions/codegen/RowAggTmplTest.java | 38 ++++++++++++++++++++--
.../scripts/functions/codegen/rowAggPattern47.R | 36 ++++++++++++++++++++
.../scripts/functions/codegen/rowAggPattern47.dml | 29 +++++++++++++++++
.../scripts/functions/codegen/rowAggPattern48.R | 36 ++++++++++++++++++++
.../scripts/functions/codegen/rowAggPattern48.dml | 30 +++++++++++++++++
11 files changed, 184 insertions(+), 10 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
index 93cdb2f661..fe67995b6b 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
@@ -33,7 +33,7 @@ public class CNodeUnary extends CNode
public enum UnaryType {
LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific
ROW_SUMS, ROW_SUMSQS, ROW_COUNTNNZS, //codegen specific
- ROW_MEANS, ROW_MINS, ROW_MAXS,
+ ROW_MEANS, ROW_MINS, ROW_MAXS, ROW_VARS,
VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG,
VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN,
VECT_SIN, VECT_COS, VECT_TAN, VECT_ASIN, VECT_ACOS, VECT_ATAN,
@@ -139,6 +139,7 @@ public class CNodeUnary extends CNode
case ROW_MINS: return "u(Rmin)";
case ROW_MAXS: return "u(Rmax)";
case ROW_MEANS: return "u(Rmean)";
+ case ROW_VARS: return "u(Rvar)";
case ROW_COUNTNNZS: return "u(Rnnz)";
case VECT_EXP:
case VECT_POW2:
@@ -210,6 +211,7 @@ public class CNodeUnary extends CNode
case ROW_MINS:
case ROW_MAXS:
case ROW_MEANS:
+ case ROW_VARS:
case ROW_COUNTNNZS:
case EXP:
case LOOKUP_R:
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
index 50ea2bace8..d8a1085df5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
@@ -32,12 +32,12 @@ public class Unary extends CodeTemplate {
case ROW_MINS:
case ROW_MAXS:
case ROW_MEANS:
+ case ROW_VARS:
case ROW_COUNTNNZS: {
String vectName =
StringUtils.capitalize(type.name().substring(4,
type.name().length()-1).toLowerCase());
return sparse ? " double %TMP% =
LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
" double %TMP% =
LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
}
-
case VECT_EXP:
case VECT_POW2:
case VECT_MULT2:
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
index c42ea6c858..955bf778b8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
@@ -67,7 +67,7 @@ import org.apache.sysds.runtime.matrix.data.Pair;
public class TemplateRow extends TemplateBase
{
- private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM,
AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD};
+ private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM,
AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD, AggOp.VAR};
private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND,
OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS,
OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,
diff --git
a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
index 6497b6f321..6c0dc395c3 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
@@ -2151,7 +2151,19 @@ public class LibSpoofPrimitives
new DenseBlockFP64(new int[]{K, PQ}, c), PQ, CRS, 0, K,
0, PQ);
return c;
}
-
+
+	/**
+	 * Computes the sample variance (denominator len-1) of the dense values
+	 * a[ai .. ai+len).
+	 *
+	 * Uses two passes (mean first, then the sum of squared deviations) rather
+	 * than the one-pass form sum(x^2) - len*mean^2, which loses precision through
+	 * catastrophic cancellation when the mean is large relative to the spread of
+	 * the values (and can even yield a negative variance). This also avoids
+	 * materializing a temporary vector of squares.
+	 *
+	 * NOTE(review): for len == 1 this returns 0/0 = NaN, exactly as the previous
+	 * one-pass version did; confirm this matches the non-fused rowVars semantics.
+	 *
+	 * @param a   dense input array
+	 * @param ai  offset of the first value in a
+	 * @param len number of values
+	 * @return the sample variance of the len values
+	 */
+	public static double vectVar(double[] a, int ai, int len) {
+		double mean = vectMean(a, ai, len);
+		double sumSqDev = 0;
+		for( int i = ai; i < ai + len; i++ ) {
+			double dev = a[i] - mean;
+			sumSqDev += dev * dev;
+		}
+		return sumSqDev / (len - 1);
+	}
+
+	/**
+	 * Computes the sample variance (denominator len-1) of a sparse vector of
+	 * logical length len.
+	 *
+	 * Assumes the usual sparse-row layout also passed to vectMean/vectPow2Write:
+	 * avals[ai .. ai+alen) holds the alen stored values and all other len-alen
+	 * entries are zero. Uses two passes (mean first, then squared deviations)
+	 * to avoid the cancellation of sum(x^2) - len*mean^2. It only touches the
+	 * stored values: each implicit zero deviates from the mean by -mean, so the
+	 * zeros contribute (len-alen)*mean^2 in closed form. No dense temporary of
+	 * size len is allocated.
+	 *
+	 * NOTE(review): for len == 1 this returns 0/0 = NaN, as the previous version
+	 * did; confirm this matches the non-fused rowVars semantics.
+	 *
+	 * @param avals stored (non-zero) values
+	 * @param aix   column indexes of the stored values (not needed for the result;
+	 *              kept for the uniform sparse primitive signature used by codegen)
+	 * @param ai    offset of the first stored value
+	 * @param alen  number of stored values
+	 * @param len   logical vector length
+	 * @return the sample variance of the len logical values
+	 */
+	public static double vectVar(double[] avals, int[] aix, int ai, int alen, int len) {
+		double mean = vectMean(avals, aix, ai, alen, len);
+		double sumSqDev = 0;
+		for( int j = ai; j < ai + alen; j++ ) {
+			double dev = avals[j] - mean;
+			sumSqDev += dev * dev;
+		}
+		//contribution of the (len - alen) implicit zeros, each deviating by -mean
+		sumSqDev += (len - alen) * mean * mean;
+		return sumSqDev / (len - 1);
+	}
+
//complex builtin functions that are not directly generated
//(included here in order to reduce the number of imports)
diff --git
a/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java
b/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java
index 5b5483823a..4244ce7421 100644
--- a/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java
+++ b/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java
@@ -24,7 +24,6 @@ package org.apache.sysds.test.component.misc;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.spi.LoggingEvent;
-import org.apache.sysds.api.DMLOptions;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.LanguageException;
import org.apache.sysds.test.LoggingUtils;
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
index b04d476d06..4c51602058 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
@@ -90,8 +90,6 @@ public class BuiltinMDTest extends AutomatedTestBase {
}
@Test
- //@Ignore
- // https://issues.apache.org/jira/browse/SYSTEMDS-3716
public void testMDSP() {
double[][] D = {
{7567, 231, 1231, 1232, 122, 321},
diff --git
a/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
index d3c9edf8e8..b89f3007b4 100644
--- a/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
@@ -87,7 +87,9 @@ public class RowAggTmplTest extends AutomatedTestBase
private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X -
mean(X)) + 7;
private static final String TEST_NAME45 = TEST_NAME+"45"; //vector
allocation;
private static final String TEST_NAME46 = TEST_NAME+"46"; //conv2d(X -
mean(X), F1) + conv2d(X - mean(X), F2);
-
+ private static final String TEST_NAME47 = TEST_NAME+"47"; //sum(X +
rowVars(X))
+ private static final String TEST_NAME48 = TEST_NAME+"48";
//sum(rowVars(X))
+
private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR +
RowAggTmplTest.class.getSimpleName() + "/";
private final static String TEST_CONF = "SystemDS-config-codegen.xml";
@@ -98,7 +100,7 @@ public class RowAggTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- for(int i=1; i<=46; i++)
+ for(int i=1; i<=48; i++)
addTestConfiguration( TEST_NAME+i, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i)
}) );
}
@@ -795,6 +797,36 @@ public class RowAggTmplTest extends AutomatedTestBase
testCodegenIntegration( TEST_NAME46, false, ExecType.SPARK );
}
+	/** Pattern 47, sum(X + rowVars(X)): rewrites enabled, CP backend. */
+	@Test
+	public void testCodegenRowAggRewrite47CP() {
+		final boolean rewrites = true;
+		testCodegenIntegration(TEST_NAME47, rewrites, ExecType.CP);
+	}
+
+	/** Pattern 47, sum(X + rowVars(X)): rewrites disabled, CP backend. */
+	@Test
+	public void testCodegenRowAgg47CP() {
+		final boolean rewrites = false;
+		testCodegenIntegration(TEST_NAME47, rewrites, ExecType.CP);
+	}
+
+	/** Pattern 47, sum(X + rowVars(X)): rewrites disabled, Spark backend. */
+	@Test
+	public void testCodegenRowAgg47SP() {
+		final boolean rewrites = false;
+		testCodegenIntegration(TEST_NAME47, rewrites, ExecType.SPARK);
+	}
+
+	/** Pattern 48, sum(rowVars(X)): rewrites enabled, CP backend. */
+	@Test
+	public void testCodegenRowAggRewrite48CP() {
+		final boolean rewrites = true;
+		testCodegenIntegration(TEST_NAME48, rewrites, ExecType.CP);
+	}
+
+	/** Pattern 48, sum(rowVars(X)): rewrites disabled, CP backend. */
+	@Test
+	public void testCodegenRowAgg48CP() {
+		final boolean rewrites = false;
+		testCodegenIntegration(TEST_NAME48, rewrites, ExecType.CP);
+	}
+
+	/** Pattern 48, sum(rowVars(X)): rewrites disabled, Spark backend. */
+	@Test
+	public void testCodegenRowAgg48SP() {
+		final boolean rewrites = false;
+		testCodegenIntegration(TEST_NAME48, rewrites, ExecType.SPARK);
+	}
+
private void testCodegenIntegration( String testname, boolean rewrites,
ExecType instType )
{
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -807,7 +839,7 @@ public class RowAggTmplTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{"-stats", "-args",
output("S") };
+ programArgs = new String[]{"-explain", "codegen",
"-stats", "-args", output("S") };
fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
diff --git a/src/test/scripts/functions/codegen/rowAggPattern47.R
b/src/test/scripts/functions/codegen/rowAggPattern47.R
new file mode 100644
index 0000000000..9d6d4bc9f6
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern47.R
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+# matrixStats::rowVars computes the sample variance (denominator n-1), which
+# is what the fused codegen kernel computes as well (division by len-1).
+# byrow=TRUE mirrors the row-major fill of matrix() in the DML script.
+X = matrix(seq(7, 50*10+6), 50, 10, byrow=TRUE);
+
+# rowVars(X) has length nrow(X), so R's column-major recycling adds the
+# i-th row variance to every cell of row i.
+R = as.matrix(sum(X + rowVars(X)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
diff --git a/src/test/scripts/functions/codegen/rowAggPattern47.dml
b/src/test/scripts/functions/codegen/rowAggPattern47.dml
new file mode 100644
index 0000000000..e3ee077fc1
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern47.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(7, 50*10+6), 50, 10);
+
+# empty loop splits the script into separate statement blocks, presumably so
+# the compiler cannot constant-fold X into the expression under test
+while(FALSE){}
+
+# add each row's sample variance to every cell of that row, then sum all
+R = as.matrix(sum(X + rowVars(X)));
+
+write(R, $1)
diff --git a/src/test/scripts/functions/codegen/rowAggPattern48.R
b/src/test/scripts/functions/codegen/rowAggPattern48.R
new file mode 100644
index 0000000000..bec1427d61
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern48.R
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+# matrixStats::rowVars computes the sample variance (denominator n-1), which
+# is what the fused codegen kernel computes as well (division by len-1).
+# X is 21x10: ten all-zero rows, the row 1..10, then ten all-zero rows,
+# i.e. a mostly-zero input where most rows have zero variance.
+Z = matrix(seq(1,10), 1, 10)
+Y = matrix(0, 10, 10)
+X = rbind(Y, Z, Y)
+
+R = as.matrix(sum(rowVars(X)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
diff --git a/src/test/scripts/functions/codegen/rowAggPattern48.dml
b/src/test/scripts/functions/codegen/rowAggPattern48.dml
new file mode 100644
index 0000000000..c367a359cd
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern48.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X is 21x10: ten all-zero rows, the row 1..10, then ten all-zero rows,
+# i.e. a mostly-zero input (10 non-zeros out of 210 cells)
+Z = matrix(seq(1,10), 1, 10)
+Y = matrix(0, 10, 10)
+X = rbind(Y, Z, Y)
+
+# empty loop splits the script into separate statement blocks, presumably so
+# the compiler cannot constant-fold X into the expression under test
+while(FALSE){}
+
+# sum of the per-row sample variances
+R = as.matrix(sum(rowVars(X)));
+
+write(R, $1)
\ No newline at end of file