This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a04783eed2 [SYSTEMDS-1965] Extended constant folding (support for
ternary/nary ops)
a04783eed2 is described below
commit a04783eed26414bf56425f55d083f3e38afd2472
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Jul 19 12:10:30 2023 +0200
[SYSTEMDS-1965] Extended constant folding (support for ternary/nary ops)
This patch extends the existing constant folding with support for ternary
(e.g., ifelse and +*) and n-ary (e.g., nmax, n+) operations. Furthermore,
this includes new tests as well as previously ignored, now re-activated
tests of constant folding in functions during IPA.
---
.../sysds/hops/rewrite/RewriteConstantFolding.java | 15 ++++++-
...nstantFoldingScalarVariablePropagationTest.java | 34 +++++++++++-----
...PAConstantFoldingScalarVariablePropagation2.dml | 2 +-
...PAConstantFoldingScalarVariablePropagation3.dml | 46 ++++++++++++++++++++++
4 files changed, 85 insertions(+), 12 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
index de5b4feacc..6980e5b661 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
@@ -34,7 +34,9 @@ import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -96,7 +98,8 @@ public class RewriteConstantFolding extends HopRewriteRule
- //fold binary op if both are literals / unary op if literal
+ //fold unary/binary/ternary/nary op if all its inputs are literals
if( root.getDataType() == DataType.SCALAR //scalar output
- && ( isApplicableBinaryOp(root) ||
isApplicableUnaryOp(root) ) )
+ && ( isApplicableUnaryOp(root) ||
isApplicableBinaryOp(root)
+ || isApplicableTernaryOp(root) ||
isApplicableNaryOp(root) ) )
{
literal = evalScalarOperation(root);
}
@@ -212,6 +215,16 @@ public class RewriteConstantFolding extends HopRewriteRule
&& hop.getDataType() == DataType.SCALAR);
}
+ //ternary IFELSE, MINUS_MULT or PLUS_MULT (e.g., ifelse, +*) with only literal inputs; the scalar output type is checked by the caller, not here
+ private static boolean isApplicableTernaryOp( Hop hop ) {
+ return HopRewriteUtils.isTernary(hop, OpOp3.IFELSE,
OpOp3.MINUS_MULT, OpOp3.PLUS_MULT)
+ && hop.getInput().stream().allMatch(h -> h
instanceof LiteralOp);
+ }
+
+ //n-ary MIN, MAX or PLUS (e.g., nmax, n+) with only literal inputs; the scalar output type is checked by the caller, not here
+ private static boolean isApplicableNaryOp( Hop hop ) {
+ return HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX,
OpOpN.PLUS)
+ && hop.getInput().stream().allMatch(h -> h instanceof
LiteralOp);
+ }
+
private static boolean isApplicableFalseConjunctivePredicate( Hop hop )
{
ArrayList<Hop> in = hop.getInput();
return ( HopRewriteUtils.isBinary(hop, OpOp2.AND) &&
hop.getDataType().isScalar()
diff --git
a/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
b/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
index 74b60d5ff8..843db6c874 100644
---
a/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java
@@ -25,7 +25,8 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
import org.junit.Test;
/**
@@ -47,16 +48,17 @@ public class
IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
{
private final static String TEST_NAME1 =
"IPAConstantFoldingScalarVariablePropagation1";
private final static String TEST_NAME2 =
"IPAConstantFoldingScalarVariablePropagation2";
+ //script 3: constant folding of scalar expressions inside a non-inlined function
+ private final static String TEST_NAME3 =
"IPAConstantFoldingScalarVariablePropagation3";
+
private final static String TEST_DIR = "functions/misc/";
private final static String TEST_CLASS_DIR = TEST_DIR +
IPAConstantFoldingScalarVariablePropagationTest.class.getSimpleName() + "/";
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- TestConfiguration conf1 = new TestConfiguration(TEST_CLASS_DIR,
TEST_NAME1, new String[]{});
- TestConfiguration conf2 = new TestConfiguration(TEST_CLASS_DIR,
TEST_NAME2, new String[]{});
- addTestConfiguration(TEST_NAME1, conf1);
- addTestConfiguration(TEST_NAME2, conf2);
+ //register one test configuration per DML test script
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{}));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{}));
+ addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{}));
}
@Test
@@ -69,20 +71,26 @@ public class
IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
runIPAScalarVariablePropagationTest(TEST_NAME1, false);
}
- // TODO: this test is ignored because sourcing functions from another
script does not allow named variables, with default values.
@Test
- @Ignore
public void testConstantFoldingScalarPropagation2IPASecondChance() {
runIPAScalarVariablePropagationTest(TEST_NAME2, true);
}
- // TODO: this test is ignored because sourcing functions from another
script does not allow named variables, with default values.
@Test
- @Ignore
public void testConstantFoldingScalarPropagation2NoIPASecondChance() {
runIPAScalarVariablePropagationTest(TEST_NAME2, false);
}
+ //script 3 with IPA second chance; this variant also asserts the executed floor/castvti counts
+ @Test
+ public void testConstantFoldingScalarPropagation3IPASecondChance() {
+ runIPAScalarVariablePropagationTest(TEST_NAME3, true);
+ }
+
+ //script 3 without IPA second chance
+ @Test
+ public void testConstantFoldingScalarPropagation3NoIPASecondChance() {
+ runIPAScalarVariablePropagationTest(TEST_NAME3, false);
+ }
+
/**
* Test for static rewrites + IPA second chance compilation to allow
* for scalar propagation (IPA) of constant-folded DAG of literals
@@ -106,7 +114,7 @@ public class
IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
loadTestConfiguration(config);
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{"-stats"};
+ programArgs = new String[]{"-explain","-stats"};
OptimizerUtils.IPA_NUM_REPETITIONS = IPA_SECOND_CHANCE
? 2 : 1;
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
rtplatform = ExecMode.HYBRID;
@@ -118,6 +126,12 @@ public class
IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe
// (MB: originally, this required a second chance, but
not anymore)
checkNumCompiledSparkInst(0);
checkNumExecutedSparkInst(0);
+
+ //check successful constant folding of entire expressions via the exact number of executed floor and castvti instructions
+ if( testname.equals(TEST_NAME3) && IPA_SECOND_CHANCE ) {
+ Assert.assertEquals(2, Statistics.getCPHeavyHitterCount("floor"));
+ Assert.assertEquals(2, Statistics.getCPHeavyHitterCount("castvti"));
+ }
}
finally {
// Reset
diff --git
a/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml
b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml
index ec1a7fcbf2..6c3632f8eb 100644
---
a/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml
+++
b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation2.dml
@@ -42,7 +42,7 @@ Wf = 3 # filter width
stride = 1
pad = 1 # For same dimensions, (Hf - stride) / 2
F1 = 32 # num conv filters in conv1
-[Wc1, bc1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win)
+# pass the optional 5th argument (-1) explicitly instead of relying on its default (presumably the seed; confirm in conv2d::init)
+[Wc1, bc1] = conv2d::init(F1, C, Hf, Wf, -1) # inputs: (N, C*Hin*Win)
# Create data structure to store gradients computed in parallel
doutc1_agg = matrix(0, rows=num_batches, cols=batch_size*F1*Hin*Win)
diff --git
a/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation3.dml
b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation3.dml
new file mode 100644
index 0000000000..c6bb8445ed
--- /dev/null
+++
b/src/test/scripts/functions/misc/IPAConstantFoldingScalarVariablePropagation3.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# foo: conv2d-style output sizes computed from scalar arguments only (constant-folding candidates)
+foo = function(Int Hin, Int Win, Int Hf, Int Wf,
+ Int strideh, Int stridew, Int padh, Int padw)
+ return(Integer Hout, Integer Wout)
+{
+ Hout = as.integer(floor((Hin + 2*padh - Hf)/strideh + 1))
+ while(FALSE){} #prevent inlining
+ Wout = as.integer(floor((Win + 2*padw - Wf)/stridew + 1 +
sqrt(6/(padh+padw))))
+
+}
+
+Hin = 224 # input height
+Win = 224 # input width
+Hf = 3 # filter height
+Wf = 3 # filter width
+stride = 1
+pad = 1 # For same dimensions, (Hf - stride) / 2
+
+[Hout1, Wout1] = foo(Hin, Win, Hf, Wf, stride, stride, pad, pad);
+
+while(FALSE){} #DAG cut
+
+[Hout2, Wout2] = foo(Hin, Win, Hf, Wf, stride, stride, pad, pad);
+
+print(Hout1+" "+Wout1+" vs "+Hout2+" "+Wout2)
+#check no ops of foo -> constant folding