This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
commit cf74661016928e3413d693f939a67964f3256b19 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Tue Apr 21 22:00:46 2020 +0200 [SYSTEMDS-207] Fix dml-builtin-function hoisting from expressions Function calls to dml-bodied functions need to bind their outputs to logical variable names and hence require a cut of the basic block for correctness. To still allow such functions in expressions (which is very common), we perform function call hoisting from expressions during parsing in order to be able to cut after entire statements. This automatically applied to the new dml-bodied builtin functions too, but because these functions were not yet loaded at the time of hoisting, we ran into null pointer exceptions during validation (thanks Arnab for catching this). This fix extends the function hoisting by probing for dml-bodied builtin functions and lazily loading, parsing, and adding the required functions if needed. By reusing the recently added mechanics from lazy function loading in eval functions, we keep the number of alternative entry points very small. 
--- docs/Tasks.txt | 1 + .../java/org/apache/sysds/parser/DMLProgram.java | 4 +- .../org/apache/sysds/parser/StatementBlock.java | 59 ++++++++++++++-------- .../functions/misc/FunctionInExpressionTest.java | 7 +++ .../scripts/functions/misc/FunInExpression7.dml | 26 ++++++++++ 5 files changed, 73 insertions(+), 24 deletions(-) diff --git a/docs/Tasks.txt b/docs/Tasks.txt index 5ae71b1..7a61c05 100644 --- a/docs/Tasks.txt +++ b/docs/Tasks.txt @@ -165,6 +165,7 @@ SYSTEMDS-200 Various Fixes * 204 Fix rewrite simplify sequences of binary comparisons OK * 205 Fix scoping of builtin dml-bodied functions (vs user-defined) * 206 Fix codegen outer template compilation (tsmm) OK + * 207 Fix builtin function call hoisting from expressions OK SYSTEMDS-210 Extended lists Operations * 211 Cbind and Rbind over lists of matrices OK diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java b/src/main/java/org/apache/sysds/parser/DMLProgram.java index 4e5e229..2487aec 100644 --- a/src/main/java/org/apache/sysds/parser/DMLProgram.java +++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java @@ -166,11 +166,11 @@ public class DMLProgram try { //handle statement blocks of all functions for( FunctionStatementBlock fsb : getFunctionStatementBlocks() ) - StatementBlock.rHoistFunctionCallsFromExpressions(fsb); + StatementBlock.rHoistFunctionCallsFromExpressions(fsb, this); //handle statement blocks of main program ArrayList<StatementBlock> tmp = new ArrayList<>(); for( StatementBlock sb : _blocks ) - tmp.addAll(StatementBlock.rHoistFunctionCallsFromExpressions(sb)); + tmp.addAll(StatementBlock.rHoistFunctionCallsFromExpressions(sb, this)); _blocks = tmp; } catch(LanguageException ex) { diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java index f275a84..f6a8f72 100644 --- a/src/main/java/org/apache/sysds/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java @@ -23,6 
+23,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -37,6 +39,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.parser.Expression.FormatType; import org.apache.sysds.parser.LanguageException.LanguageErrorCodes; import org.apache.sysds.parser.PrintStatement.PRINTTYPE; +import org.apache.sysds.parser.dml.DmlSyntacticValidator; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.utils.MLContextProxy; @@ -460,13 +463,13 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo } - public static List<StatementBlock> rHoistFunctionCallsFromExpressions(StatementBlock current) { + public static List<StatementBlock> rHoistFunctionCallsFromExpressions(StatementBlock current, DMLProgram prog) { if (current instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)current; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); ArrayList<StatementBlock> tmp = new ArrayList<>(); for (StatementBlock sb : fstmt.getBody()) - tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog)); fstmt.setBody(tmp); } else if (current instanceof WhileStatementBlock) { @@ -475,7 +478,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo //TODO handle predicates ArrayList<StatementBlock> tmp = new ArrayList<>(); for (StatementBlock sb : wstmt.getBody()) - tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog)); wstmt.setBody(tmp); } else if (current instanceof IfStatementBlock) { @@ -484,12 +487,12 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo //TODO handle predicates ArrayList<StatementBlock> tmp = 
new ArrayList<>(); for (StatementBlock sb : istmt.getIfBody()) - tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog)); istmt.setIfBody(tmp); if( istmt.getElseBody() != null && !istmt.getElseBody().isEmpty() ) { ArrayList<StatementBlock> tmp2 = new ArrayList<>(); for (StatementBlock sb : istmt.getElseBody()) - tmp2.addAll(rHoistFunctionCallsFromExpressions(sb)); + tmp2.addAll(rHoistFunctionCallsFromExpressions(sb, prog)); istmt.setElseBody(tmp2); } } @@ -499,25 +502,25 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo //TODO handle predicates ArrayList<StatementBlock> tmp = new ArrayList<>(); for (StatementBlock sb : fstmt.getBody()) - tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog)); fstmt.setBody(tmp); } else { //generic (last-level) ArrayList<Statement> tmp = new ArrayList<>(); for(Statement stmt : current.getStatements()) - tmp.addAll(rHoistFunctionCallsFromExpressions(stmt)); + tmp.addAll(rHoistFunctionCallsFromExpressions(stmt, prog)); if( current.getStatements().size() != tmp.size() ) return createStatementBlocks(current, tmp); } return Arrays.asList(current); } - public static List<Statement> rHoistFunctionCallsFromExpressions(Statement stmt) { + public static List<Statement> rHoistFunctionCallsFromExpressions(Statement stmt, DMLProgram prog) { ArrayList<Statement> tmp = new ArrayList<>(); if( stmt instanceof AssignmentStatement ) { AssignmentStatement astmt = (AssignmentStatement)stmt; boolean ix = (astmt.getTargetList().get(0) instanceof IndexedIdentifier); - rHoistFunctionCallsFromExpressions(astmt.getSource(), !ix, tmp); + rHoistFunctionCallsFromExpressions(astmt.getSource(), !ix, tmp, prog); if( ix && astmt.getSource() instanceof FunctionCallIdentifier ) { AssignmentStatement lstmt = (AssignmentStatement) tmp.get(tmp.size()-1); astmt.setSource(copy(lstmt.getTarget())); @@ -525,13 +528,13 @@ public class 
StatementBlock extends LiveVariableAnalysis implements ParseInfo } else if( stmt instanceof MultiAssignmentStatement ) { MultiAssignmentStatement mstmt = (MultiAssignmentStatement)stmt; - rHoistFunctionCallsFromExpressions(mstmt.getSource(), true, tmp); + rHoistFunctionCallsFromExpressions(mstmt.getSource(), true, tmp, prog); } else if( stmt instanceof PrintStatement ) { PrintStatement pstmt = (PrintStatement)stmt; for(int i=0; i<pstmt.expressions.size(); i++) { Expression lexpr = pstmt.getExpressions().get(i); - rHoistFunctionCallsFromExpressions(lexpr, false, tmp); + rHoistFunctionCallsFromExpressions(lexpr, false, tmp, prog); if( lexpr instanceof FunctionCallIdentifier ) { AssignmentStatement lstmt = (AssignmentStatement) tmp.get(tmp.size()-1); pstmt.getExpressions().set(i, copy(lstmt.getTarget())); @@ -550,52 +553,64 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo return ret; } - public static Expression rHoistFunctionCallsFromExpressions(Expression expr, boolean root, ArrayList<Statement> tmp) { + public static Expression rHoistFunctionCallsFromExpressions(Expression expr, boolean root, ArrayList<Statement> tmp, DMLProgram prog) { if( expr == null || expr instanceof ConstIdentifier ) return expr; //do nothing if( expr instanceof BinaryExpression ) { BinaryExpression lexpr = (BinaryExpression) expr; - lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp)); - lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp)); + lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp, prog)); + lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp, prog)); } else if( expr instanceof RelationalExpression ) { RelationalExpression lexpr = (RelationalExpression) expr; - lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp)); - lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp)); + 
lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp, prog)); + lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp, prog)); } else if( expr instanceof BooleanExpression ) { BooleanExpression lexpr = (BooleanExpression) expr; - lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp)); - lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp)); + lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp, prog)); + lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp, prog)); } else if( expr instanceof BuiltinFunctionExpression ) { BuiltinFunctionExpression lexpr = (BuiltinFunctionExpression) expr; Expression[] clexpr = lexpr.getAllExpr(); for( int i=0; i<clexpr.length; i++ ) - clexpr[i] = rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp); + clexpr[i] = rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp, prog); } else if( expr instanceof ParameterizedBuiltinFunctionExpression ) { ParameterizedBuiltinFunctionExpression lexpr = (ParameterizedBuiltinFunctionExpression) expr; HashMap<String, Expression> clexpr = lexpr.getVarParams(); for( String key : clexpr.keySet() ) - clexpr.put(key, rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp)); + clexpr.put(key, rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp, prog)); } else if( expr instanceof DataExpression ) { DataExpression lexpr = (DataExpression) expr; HashMap<String, Expression> clexpr = lexpr.getVarParams(); for( String key : clexpr.keySet() ) - clexpr.put(key, rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp)); + clexpr.put(key, rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp, prog)); } else if( expr instanceof FunctionCallIdentifier ) { FunctionCallIdentifier fexpr = (FunctionCallIdentifier) expr; for( ParameterExpression pexpr : fexpr.getParamExprs() ) - 
pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp)); + pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp, prog)); if( !root ) { //core hoisting String varname = StatementBlockRewriteRule.createCutVarName(true); DataIdentifier di = new DataIdentifier(varname); di.setDataType(fexpr.getDataType()); di.setValueType(fexpr.getValueType()); tmp.add(new AssignmentStatement(di, fexpr, di)); + //add hoisted dml-bodied builtin function to program (if not already loaded) + if( Builtins.contains(fexpr.getName(), true, false) + && !prog.containsFunctionStatementBlock(Builtins.getInternalFName(fexpr.getName(), DataType.SCALAR)) + && !prog.containsFunctionStatementBlock(Builtins.getInternalFName(fexpr.getName(), DataType.MATRIX))) { + Map<String,FunctionStatementBlock> fsbs = DmlSyntacticValidator + .loadAndParseBuiltinFunction(fexpr.getName(), fexpr.getNamespace()); + for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() ) { + if( !prog.containsFunctionStatementBlock(fsb.getKey()) ) + prog.addFunctionStatementBlock(fsb.getKey(), fsb.getValue()); + fsb.getValue().setDMLProg(prog); + } + } return di; } } diff --git a/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java b/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java index bac9e63..11441c5 100644 --- a/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java +++ b/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java @@ -35,6 +35,7 @@ public class FunctionInExpressionTest extends AutomatedTestBase private final static String TEST_NAME4 = "FunInExpression4"; private final static String TEST_NAME5 = "FunInExpression5"; private final static String TEST_NAME6 = "FunInExpression6"; + private final static String TEST_NAME7 = "FunInExpression7"; //dml-bodied builtin private final static String TEST_DIR = "functions/misc/"; private final static String TEST_CLASS_DIR 
= TEST_DIR + FunctionInExpressionTest.class.getSimpleName() + "/"; @@ -48,6 +49,7 @@ public class FunctionInExpressionTest extends AutomatedTestBase addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) ); addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] { "R" }) ); } @Test @@ -80,6 +82,11 @@ public class FunctionInExpressionTest extends AutomatedTestBase runFunInExpressionTest( TEST_NAME6 ); } + @Test + public void testFunInExpression7() { + runFunInExpressionTest( TEST_NAME7 ); + } + private void runFunInExpressionTest( String testName ) { TestConfiguration config = getTestConfiguration(testName); diff --git a/src/test/scripts/functions/misc/FunInExpression7.dml b/src/test/scripts/functions/misc/FunInExpression7.dml new file mode 100644 index 0000000..1bb4725 --- /dev/null +++ b/src/test/scripts/functions/misc/FunInExpression7.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = matrix(4, rows=10, cols=10); +R1 = log(sigmoid(A) + 7); +R2 = log(1/(1+exp(-A)) + 7); +R = as.matrix(sum(abs(R1-R2)<1e-10)*(7/100)) +write( R, $1 );