This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 44e88d843a [SYSTEMDS-3509] Fix IPA pass function forwarding (named
args ordering)
44e88d843a is described below
commit 44e88d843a819b107c13478b329c578ff809d6d6
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 25 16:32:05 2023 +0100
[SYSTEMDS-3509] Fix IPA pass function forwarding (named args ordering)
This patch fixes the inter-procedural analysis (IPA) rewrite pass
'function forwarding', which collapses a chain of function calls into
a single function call. Previously, if the outer and inner calls passed
arguments with the same names in different orders, the arguments could
be misassigned (this only happened when the rewrite applied). We now
wire the function arguments in the correct order according to their
argument names.
---
scripts/builtin/decisionTreePredict.dml | 1 -
.../org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java | 10 ++++++----
.../org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java | 1 +
.../apache/sysds/runtime/instructions/cp/CPInstruction.java | 1 -
src/test/scripts/functions/builtin/decisionTreePredict.dml | 3 +--
5 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/scripts/builtin/decisionTreePredict.dml
b/scripts/builtin/decisionTreePredict.dml
index b312910a48..c4e75b4fe1 100644
--- a/scripts/builtin/decisionTreePredict.dml
+++ b/scripts/builtin/decisionTreePredict.dml
@@ -46,7 +46,6 @@ m_decisionTreePredict = function(Matrix[Double] X,
Matrix[Double] y = matrix(0,0
Matrix[Double] ctypes, Matrix[Double] M, String strategy="TT", Boolean
verbose = FALSE)
return (Matrix[Double] yhat)
{
- print(toString(M))
if( strategy == "TT" )
yhat = predict_TT(M, X);
else if( strategy == "GEMM" )
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
index 8b57742f26..e35f2f50d6 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
@@ -55,7 +55,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
//step 1: basic application filter: simple forwarding
call
if( fstmt.getBody().size() != 1 ||
!singleFunctionOp(fstmt.getBody().get(0).getHops())
- ||
!hasOnlySimplyArguments((FunctionOp)fstmt.getBody().get(0).getHops().get(0)))
+ ||
!hasOnlySimpleArguments((FunctionOp)fstmt.getBody().get(0).getHops().get(0)))
continue;
if( LOG.isDebugEnabled() )
LOG.debug("IPA: Forward-function-call candidate
L1: '"+fkey+"'");
@@ -96,7 +96,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
return hops.get(0) instanceof FunctionOp;
}
- private static boolean hasOnlySimplyArguments(FunctionOp fop) {
+ private static boolean hasOnlySimpleArguments(FunctionOp fop) {
return fop.getInput().stream().allMatch(h -> h instanceof
LiteralOp
|| HopRewriteUtils.isData(h, OpOpData.TRANSIENTREAD));
}
@@ -127,15 +127,17 @@ public class IPAPassForwardFunctionCalls extends IPAPass
for( int i=0; i<call2.getInput().size(); i++ )
probe.put(call2.getInputVariableNames()[i],
call2.getInput().get(i));
- //construct new inputs for call1
+ //construct new named inputs for call1 (in right order)
+ ArrayList<String> varNames = new ArrayList<>();
ArrayList<Hop> inputs = new ArrayList<>();
for( int i=0; i<call1.getInput().size(); i++ )
if( probe.containsKey(call1.getInputVariableNames()[i])
) {
+ varNames.add(call1.getInputVariableNames()[i]);
inputs.add(
(probe.get(call1.getInputVariableNames()[i]) instanceof LiteralOp) ?
probe.get(call1.getInputVariableNames()[i]) : call1.getInput().get(i));
}
HopRewriteUtils.removeAllChildReferences(call1);
call1.addAllInputs(inputs);
- call1.setInputVariableNames(call2.getInputVariableNames());
+ call1.setInputVariableNames(varNames.toArray(new String[0]));
}
}
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
index d4f976a795..1cf5761423 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -57,6 +57,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
Map<Long, Integer> operatorJobCount = new HashMap<>();
markPersistableSparkOps(sparkRoots, operatorJobCount);
// TODO: A rewrite pass to remove less effective chkpoints
+ @SuppressWarnings("unused")
List<Lop> nodesWithChkpt = addChkpointLop(lops,
operatorJobCount);
//New node is added inplace in the Lop DAG
return List.of(sb);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index aa17fa2cab..174e4f2d27 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -31,7 +31,6 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.CPInstructionParser;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstructionUtils;
-import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
diff --git a/src/test/scripts/functions/builtin/decisionTreePredict.dml
b/src/test/scripts/functions/builtin/decisionTreePredict.dml
index e87b01c581..733d363b78 100644
--- a/src/test/scripts/functions/builtin/decisionTreePredict.dml
+++ b/src/test/scripts/functions/builtin/decisionTreePredict.dml
@@ -21,6 +21,5 @@
M = read($1);
X = read($2);
-# FIXME reordering of M and X yields wrong passing
-Y = decisionTreePredict(M=M, X=X, ctypes=matrix(2,1,ncol(X)+1), strategy=$3);
+Y = decisionTreePredict(X=X, M=M, ctypes=matrix(2,1,ncol(X)+1), strategy=$3);
write(Y, $4);