This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7ebf913c17 [SYSTEMDS-3868] Fix missing function hoisting from if
predicates
7ebf913c17 is described below
commit 7ebf913c17518190a82b216696b0a08c93ba2892
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Apr 25 17:42:25 2025 +0200
[SYSTEMDS-3868] Fix missing function hoisting from if predicates
This patch adds the missing hoisting of DML function calls
(which must always be bound to variables) out of basic if
predicates, for convenience and to prevent
unexpected errors. Furthermore, this patch simplifies the
existing DML-bodied ampute() builtin by using this feature
and by calling the existing sigmoid() builtin instead of a custom one.
---
scripts/builtin/ampute.dml | 13 +++----------
src/main/java/org/apache/sysds/parser/StatementBlock.java | 9 ++++++++-
2 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/scripts/builtin/ampute.dml b/scripts/builtin/ampute.dml
index 691e5b48e2..90557789dd 100644
--- a/scripts/builtin/ampute.dml
+++ b/scripts/builtin/ampute.dml
@@ -184,8 +184,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns,
Matrix[Double] weights) {
u_handleDefaults = function(Matrix[Double] freq, Matrix[Double] patterns,
Matrix[Double] weights, String mech, Integer numFeatures)
return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
# Patterns: Default is a quadratic matrix wherein pattern i amputes feature
i.
- empty = u_isEmpty(patterns)
- if (empty) { # FIX ME
+ if (u_isEmpty(patterns)) {
patterns = matrix(1, rows=numFeatures, cols=numFeatures) - diag(matrix(1,
rows=numFeatures, cols=1))
}
@@ -205,8 +204,7 @@ return (Matrix[Double] freq, Matrix[Double] patterns,
Matrix[Double] weights) {
}
# Frequencies: Uniform by default.
- empty = u_isEmpty(freq) # FIX ME
- if (empty) {
+ if (u_isEmpty(freq)) {
freq = matrix(1 / numPatterns, rows=numPatterns, cols=1)
}
}
@@ -282,7 +280,7 @@ return (Matrix[Double] probsArray) {
while (counter < maxIter & (is.na(currentProb) | abs(currentProb - prop) >=
epsilon)) {
counter += 1
shift = lowerRange + (upperRange - lowerRange) / 2
- probsArray = u_sigmoid(zScores + shift) # Calculates Right-Sigmoid
probability (R implementation's default).
+ probsArray = sigmoid(zScores + shift) # Calculates Right-Sigmoid
probability (R implementation's default).
currentProb = mean(probsArray)
if (currentProb - prop > 0) {
upperRange = shift
@@ -293,11 +291,6 @@ return (Matrix[Double] probsArray) {
}
}
-u_sigmoid = function(Matrix[Double] X)
-return (Matrix[Double] sigmoided) {
- sigmoided = 1 / (1 + exp(-X))
-}
-
u_getBounds = function(Matrix[Double] numPerGroup, Integer groupSize, Integer
patternNum)
return(Integer start, Integer end) {
if (patternNum == 1) {
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index b81a603e7c..2e62cc7f2e 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -503,7 +503,12 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
else if (current instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) current;
IfStatement istmt = (IfStatement)isb.getStatement(0);
- //TODO handle predicates
+ //handle predicate
+ ArrayList<Statement> tmpPred = new ArrayList<>();
+ istmt.getConditionalPredicate().setPredicate(
+ rHoistFunctionCallsFromExpressions(
+
istmt.getConditionalPredicate().getPredicate(), false, tmpPred, prog));
+ //handle if and else body
ArrayList<StatementBlock> tmp = new ArrayList<>();
for (StatementBlock sb : istmt.getIfBody())
tmp.addAll(rHoistFunctionCallsFromExpressions(sb, prog));
@@ -514,6 +519,8 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
tmp2.addAll(rHoistFunctionCallsFromExpressions(sb, prog));
istmt.setElseBody(tmp2);
}
+ if( !tmpPred.isEmpty() )
+ return createStatementBlocks(current, tmpPred);
}
else if (current instanceof ForStatementBlock) { //incl parfor
ForStatementBlock fsb = (ForStatementBlock) current;