This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new af36db5b6b [SYSTEMDS-3702,3703] Fix eval function calls in remote parfor
af36db5b6b is described below

commit af36db5b6b5bcc8cb83f0eada02fda9e4d3992aa
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sun Jun 2 11:40:23 2024 +0200

    [SYSTEMDS-3702,3703] Fix eval function calls in remote parfor
    
    This patch adds a safeguard that recompiles the remote parfor body
    program blocks, for both optimized and unoptimized functions, to forced
    CP execution (because Spark instructions are not supported inside Spark
    tasks). New scripts for SHAP sampling revealed missing robustness when
    eval is used in remote parfor (which uses the unoptimized functions).
    
    Thanks to Louis and Christina for catching this issue.
---
 .../runtime/controlprogram/ParForProgramBlock.java |  4 +-
 .../parfor/opt/OptTreeConverter.java               |  4 ++
 .../sysds/runtime/util/ProgramConverter.java       |  6 +++
 .../test/functions/misc/FunctionPotpourriTest.java | 10 ++++-
 .../functions/misc/FunPotpourriParforEvalSpark.dml | 49 ++++++++++++++++++++++
 5 files changed, 69 insertions(+), 4 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 195ed23a0e..c4c75c35e7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -1410,8 +1410,8 @@ public class ParForProgramBlock extends ForProgramBlock {
                }
                
                //try recompile Spark instructions to CP
-               HashSet<String> fnStack = new HashSet<>();
-               Recompiler.recompileProgramBlockHierarchy2Forced(_childBlocks, 
tid, fnStack, ExecType.CP);
+               Recompiler.recompileProgramBlockHierarchy2Forced(
+                       _childBlocks, tid, new HashSet<>(), ExecType.CP);
                return true;
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
index 5cc354d30e..f1075997b9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
@@ -530,6 +530,10 @@ public class OptTreeConverter
                return ret;
        }
 
+       public static boolean rContainsSparkInstruction( List<ProgramBlock> 
pbs, boolean inclFunctions ) {
+               return pbs.stream().anyMatch(pb -> 
rContainsSparkInstruction(pb, inclFunctions));
+       }
+       
        public static boolean rContainsSparkInstruction( ProgramBlock pb, 
boolean inclFunctions )
        {
                boolean ret = false;
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java 
b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 7585012db7..30d412eed7 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -25,6 +25,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.CompilerConfig;
@@ -71,6 +72,7 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
 import org.apache.sysds.runtime.controlprogram.parfor.ParForBody;
+import org.apache.sysds.runtime.controlprogram.parfor.opt.OptTreeConverter;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.instructions.CPInstructionParser;
 import org.apache.sysds.runtime.instructions.Instruction;
@@ -1146,6 +1148,10 @@ public class ProgramConverter
                                sb.append( fkey );
                                sb.append( KEY_VALUE_DELIM );
                                FunctionProgramBlock fpb2 = 
prog.getFunctionProgramBlock(fkey, false);
+                               if( 
OptTreeConverter.rContainsSparkInstruction(fpb2.getChildBlocks(), false) ) {
+                                       
Recompiler.recompileProgramBlockHierarchy2Forced(
+                                               fpb2.getChildBlocks(), -1, new 
HashSet<>(), ExecType.CP);
+                               }
                                sb.append( rSerializeProgramBlock(fpb2, clsMap) 
);
                        }
                        count++;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java 
b/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java
index b30522249f..cfd236a383 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/misc/FunctionPotpourriTest.java
@@ -60,6 +60,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
                "FunPotpourriEvalNamespace2",
                "FunPotpourriBuiltinPrecedence",
                "FunPotpourriParforEvalBuiltin",
+               "FunPotpourriParforEvalSpark",
                "FunPotpourriEvalNamespace3",
        };
        
@@ -265,9 +266,14 @@ public class FunctionPotpourriTest extends 
AutomatedTestBase
                runFunctionTest( TEST_NAMES[27], null, true );
        }
        
+       @Test
+       public void testFunctionParforEvalSpark() {
+               runFunctionTest( TEST_NAMES[28], null, true );
+       }
+       
        @Test
        public void testFunctionEvalNamespace3() {
-               runFunctionTest( TEST_NAMES[28], null, false );
+               runFunctionTest( TEST_NAMES[29], null, false );
        }
        
        private void runFunctionTest(String testName, Class<?> error) {
@@ -291,7 +297,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
        
                        if( testName.equals(TEST_NAMES[17]) )
                                
Assert.assertTrue(heavyHittersContainsString("print"));
-                       if( evalRewrite )
+                       if( evalRewrite && !testName.equals(TEST_NAMES[28]) )
                                
Assert.assertTrue(!heavyHittersContainsString("eval"));
                }
                finally {
diff --git a/src/test/scripts/functions/misc/FunPotpourriParforEvalSpark.dml 
b/src/test/scripts/functions/misc/FunPotpourriParforEvalSpark.dml
new file mode 100644
index 0000000000..4af2f6a4a5
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunPotpourriParforEvalSpark.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+R = myFunction(100)
+print(sum(R))
+
+myFunction = function(Integer X)
+  return(Matrix[Double] A)
+{
+  model_function="dummy_rand_model"
+  Y = 200
+  A = matrix(0, rows=X, cols=Y)
+
+  #use below parfor to force in spark
+  parfor (i in 1:X, opt=CONSTRAINED, mode=REMOTE_SPARK ){
+  #for (i in 1:X){
+    #use function directly
+    #P = rand(rows=1, cols=Y )
+
+    #use eval
+    P = eval(model_function, list(Y=Y))
+    A[i] = P
+  }
+  print("Avg:\n"+toString(avg(A)))
+}
+
+dummy_rand_model = function(Integer Y)
+  return( Matrix[Double] P)
+{
+  P = matrix(1, rows=1, cols=Y)
+}

Reply via email to