This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 0cc2c9f [SYSTEMDS-2949] Fix function call hoisting out of expressions
0cc2c9f is described below
commit 0cc2c9f98ca2bd242d6d7d2e20c3802e52f83f9b
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed May 12 23:10:16 2021 +0200
[SYSTEMDS-2949] Fix function call hoisting out of expressions
This patch fixes parsing issues where partially incorrect hoisting of
function calls out of complex expressions led to null pointer
exceptions (as a result of hoisting multiple functions). We now add
more tests and use a conservative approach that cuts statement blocks
before and after each hoisted function call. These blocks are later
merged with other blocks, but the cuts ensure the validity of the
simplifying assumption made in the parser: that a function call
statement is the first statement in its block.
---
.../org/apache/sysds/parser/DMLTranslator.java | 52 ++++++++++------------
.../org/apache/sysds/parser/StatementBlock.java | 34 +++++++-------
.../functions/misc/FunctionInExpressionTest.java | 48 +++++++++++---------
.../scripts/functions/misc/FunInExpression8.dml | 27 +++++++++++
.../scripts/functions/misc/FunInExpression9.dml | 32 +++++++++++++
5 files changed, 128 insertions(+), 65 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index b050a3b..3b2146a 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1736,44 +1736,40 @@ public class DMLTranslator
Hop left = processExpression(source.getLeft(), null, hops);
Hop right = processExpression(source.getRight(), null, hops);
- if (left == null || right == null){
- left = processExpression(source.getLeft(), null,
hops);
- right = processExpression(source.getRight(), null,
hops);
+ if (left == null || right == null) {
+ throw new ParseException("Missing input in binary
expressions (" + source.toString()+"): "
+ +
((left==null)?source.getLeft():source.getRight())+",
line="+source.getBeginLine());
}
-
- Hop currBop = null;
-
+
//prepare target identifier and ensure that output type is of
inferred type
- //(type should not be determined by target (e.g., string for print)
+ //(type should not be determined by target (e.g., string for
print)
if (target == null) {
- target = createTarget(source);
+ target = createTarget(source);
}
target.setValueType(source.getOutput().getValueType());
- if (source.getOpCode() == Expression.BinaryOp.PLUS) {
- currBop = new BinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.PLUS, left, right);
- } else if (source.getOpCode() == Expression.BinaryOp.MINUS) {
- currBop = new BinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.MINUS, left, right);
- } else if (source.getOpCode() == Expression.BinaryOp.MULT) {
- currBop = new BinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.MULT, left, right);
- } else if (source.getOpCode() == Expression.BinaryOp.DIV) {
- currBop = new BinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.DIV, left, right);
- } else if (source.getOpCode() == Expression.BinaryOp.MODULUS) {
- currBop = new BinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.MODULUS, left, right);
- } else if (source.getOpCode() == Expression.BinaryOp.INTDIV) {
- currBop = new BinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.INTDIV, left, right);
- } else if (source.getOpCode() == Expression.BinaryOp.MATMULT) {
- currBop = new AggBinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.MULT,
org.apache.sysds.common.Types.AggOp.SUM, left, right);
- } else if (source.getOpCode() == Expression.BinaryOp.POW) {
- currBop = new BinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.POW, left, right);
- }
- else {
- throw new ParseException("Unsupported parsing of binary
expression: "+source.getOpCode());
+ Hop currBop = null;
+ switch( source.getOpCode() ) {
+ case PLUS:
+ case MINUS:
+ case MULT:
+ case DIV:
+ case MODULUS:
+ case POW:
+ case INTDIV:
+ currBop = new BinaryOp(target.getName(),
target.getDataType(),
+ target.getValueType(),
OpOp2.valueOf(source.getOpCode().name()), left, right);
+ break;
+ case MATMULT:
+ currBop = new AggBinaryOp(target.getName(),
target.getDataType(), target.getValueType(), OpOp2.MULT,
org.apache.sysds.common.Types.AggOp.SUM, left, right);
+ break;
+ default:
+ throw new ParseException("Unsupported parsing
of binary expression: "+source.getOpCode());
}
+
setIdentifierParams(currBop, source.getOutput());
currBop.setParseInfo(source);
return currBop;
-
}
private Hop processRelationalExpression(RelationalExpression source,
DataIdentifier target, HashMap<String, Hop> hops) {
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index c4876d4..570af39 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -183,10 +183,10 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
public boolean isMergeableFunctionCallBlock(DMLProgram dmlProg) {
// check whether targetIndex stmt block is for a mergable
function call
Statement stmt = this.getStatement(0);
-
+
// Check whether targetIndex block is: control stmt block or
stmt block for un-mergable function call
if ( stmt instanceof WhileStatement || stmt instanceof
IfStatement || stmt instanceof ForStatement
- || stmt instanceof FunctionStatement ||
isMergeablePrintStatement(stmt) /*|| stmt instanceof ELStatement*/ )
+ || stmt instanceof FunctionStatement ||
isMergeablePrintStatement(stmt) )
{
return false;
}
@@ -232,7 +232,7 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
}
}
- // regular function block
+ // regular statement block
return true;
}
@@ -360,18 +360,17 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
ArrayList<StatementBlock> result = new ArrayList<>();
StatementBlock currentBlock = null;
- for (int i = 0; i < body.size(); i++){
+ for (int i = 0; i < body.size(); i++) {
StatementBlock current = body.get(i);
if (current.isMergeableFunctionCallBlock(dmlProg)){
- if (currentBlock != null) {
+ if (currentBlock != null)
currentBlock.addStatementBlock(current);
- } else {
+ else
currentBlock = current;
- }
- } else {
- if (currentBlock != null) {
+ }
+ else {
+ if (currentBlock != null)
result.add(currentBlock);
- }
result.add(current);
currentBlock = null;
}
@@ -465,7 +464,6 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
}
return result;
-
}
public static List<StatementBlock>
rHoistFunctionCallsFromExpressions(StatementBlock current, DMLProgram prog) {
@@ -634,11 +632,17 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
List<StatementBlock> ret = new ArrayList<>();
StatementBlock current = new StatementBlock(sb);
for(Statement stmt : stmts) {
+ //cut the statement block before and after the current
function
+ //(cut before is precondition for subsequent merge
steps which
+ //assume function statements as the first statement in
the block)
+ boolean cut = stmt instanceof AssignmentStatement
+ && ((AssignmentStatement)stmt).getSource()
instanceof FunctionCallIdentifier;
+ if( cut && current.getNumStatements() > 0 ) { //before
+ ret.add(current);
+ current = new StatementBlock(sb);
+ }
current.addStatement(stmt);
- //cut the statement block after the current function
- if( stmt instanceof AssignmentStatement
- && ((AssignmentStatement)stmt).getSource()
- instanceof FunctionCallIdentifier ) {
+ if( cut ) { //after
ret.add(current);
current = new StatementBlock(sb);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
b/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
index 78ad721..da54458 100644
---
a/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
@@ -29,13 +29,12 @@ import org.apache.sysds.test.TestUtils;
public class FunctionInExpressionTest extends AutomatedTestBase
{
- private final static String TEST_NAME1 = "FunInExpression1";
- private final static String TEST_NAME2 = "FunInExpression2";
- private final static String TEST_NAME3 = "FunInExpression3";
- private final static String TEST_NAME4 = "FunInExpression4";
- private final static String TEST_NAME5 = "FunInExpression5";
- private final static String TEST_NAME6 = "FunInExpression6";
- private final static String TEST_NAME7 = "FunInExpression7";
//dml-bodied builtin
+ private final static String[] TEST_NAMES = new String[] {
+ "FunInExpression1", "FunInExpression2", "FunInExpression3",
+ "FunInExpression4", "FunInExpression5", "FunInExpression6",
+ //dml-bodied functions (w/ and w/o CSEs)
+ "FunInExpression7", "FunInExpression8", "FunInExpression9"
+ };
private final static String TEST_DIR = "functions/misc/";
private final static String TEST_CLASS_DIR = TEST_DIR +
FunctionInExpressionTest.class.getSimpleName() + "/";
@@ -43,48 +42,53 @@ public class FunctionInExpressionTest extends
AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration( TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
- addTestConfiguration( TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
- addTestConfiguration( TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
- addTestConfiguration( TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
- addTestConfiguration( TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
- addTestConfiguration( TEST_NAME6, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
- addTestConfiguration( TEST_NAME7, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] { "R" }) );
+ for(int i=0; i<TEST_NAMES.length; i++)
+ addTestConfiguration(TEST_NAMES[i], new
TestConfiguration(TEST_CLASS_DIR, TEST_NAMES[i], new String[] {"R"}));
}
@Test
public void testFunInExpression1() {
- runFunInExpressionTest( TEST_NAME1 );
+ runFunInExpressionTest( TEST_NAMES[0] );
}
@Test
public void testFunInExpression2() {
- runFunInExpressionTest( TEST_NAME2 );
+ runFunInExpressionTest( TEST_NAMES[1] );
}
@Test
public void testFunInExpression3() {
- runFunInExpressionTest( TEST_NAME3 );
+ runFunInExpressionTest( TEST_NAMES[2] );
}
@Test
public void testFunInExpression4() {
- runFunInExpressionTest( TEST_NAME4 );
+ runFunInExpressionTest( TEST_NAMES[3] );
}
@Test
public void testFunInExpression5() {
- runFunInExpressionTest( TEST_NAME5 );
+ runFunInExpressionTest( TEST_NAMES[4] );
}
@Test
public void testFunInExpression6() {
- runFunInExpressionTest( TEST_NAME6 );
+ runFunInExpressionTest( TEST_NAMES[5] );
}
@Test
public void testFunInExpression7() {
- runFunInExpressionTest( TEST_NAME7 );
+ runFunInExpressionTest( TEST_NAMES[6] );
+ }
+
+ @Test
+ public void testFunInExpression8() {
+ runFunInExpressionTest( TEST_NAMES[7] );
+ }
+
+ @Test
+ public void testFunInExpression9() {
+ runFunInExpressionTest( TEST_NAMES[8] );
}
private void runFunInExpressionTest( String testName )
@@ -94,7 +98,7 @@ public class FunctionInExpressionTest extends
AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testName + ".dml";
- programArgs = new String[]{"-args", output("R") };
+ programArgs = new String[]{"-explain","-args", output("R") };
fullRScriptName = HOME + testName + ".R";
rCmd = getRCmd(expectedDir());
diff --git a/src/test/scripts/functions/misc/FunInExpression8.dml
b/src/test/scripts/functions/misc/FunInExpression8.dml
new file mode 100644
index 0000000..70c8080
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression8.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix(4, rows=10, cols=10);
+R1 = sigmoid(A) + 7;
+R2 = sigmoid(A) - 7;
+R = as.matrix(sum(abs(R2-R1+14)<1e-10)*7/100)
+print(toString(R))
+write(R, $1);
diff --git a/src/test/scripts/functions/misc/FunInExpression9.dml
b/src/test/scripts/functions/misc/FunInExpression9.dml
new file mode 100644
index 0000000..254b870
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression9.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(Matrix[Double] A, Double val)
+ return(Matrix[Double] R)
+{
+ R1 = sigmoid(A) + val;
+ R2 = sigmoid(A) - val;
+ R = as.matrix(sum(abs(R2-R1+14)<1e-10)*7/100)
+}
+
+A = matrix(4, rows=10, cols=10);
+R = foo(A, 7)
+write(R, $1);