This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b77f42129 [SYSTEMDS-3506] Improved randomForestPredict (parfor, 
lmPredictStats)
5b77f42129 is described below

commit 5b77f42129b728297010b1a3b41fb899706ec481
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Mar 14 18:48:48 2023 +0100

    [SYSTEMDS-3506] Improved randomForestPredict (parfor, lmPredictStats)
    
    The recently added randomForestPredict revealed a few issues with
    calling decisionTreePredict in a parfor context. In detail, the complex
    function call graphs did not allow for inferring the number of parfor
    iterations at compile time, leading to nested parfor loops with
    different degrees of parallelism but only a workload of 1 task in the
    outer and inner loops. This special case led to rewrites that changed
    left indexing operations into result variables to assignments, which in
    turn caused issues during result merge.
    
    This patch resolves these issues by falling back from parfor loops with
    1 iteration to basic for-loops at runtime. Besides preventing many
    special cases, this also avoids unnecessary optimization overhead.
    
    Furthermore, this patch also introduces the new lmPredictStats builtin
    function (called from lmPredict and randomForestPredict) to avoid code
    duplication.
---
 scripts/builtin/lmPredict.dml                      | 15 +------
 scripts/builtin/lmPredictStats.dml                 | 51 ++++++++++++++++++++++
 scripts/builtin/randomForestPredict.dml            | 25 +++--------
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../runtime/controlprogram/ParForProgramBlock.java |  9 +++-
 .../recompile/PredicateRecompileTest.java          |  4 ++
 .../recompile/SparsityFunctionRecompileTest.java   |  8 ++--
 .../functions/recompile/SparsityRecompileTest.java |  6 +--
 8 files changed, 78 insertions(+), 41 deletions(-)

diff --git a/scripts/builtin/lmPredict.dml b/scripts/builtin/lmPredict.dml
index d3f9a3e130..ce332a300c 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/scripts/builtin/lmPredict.dml
@@ -43,17 +43,6 @@ m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
   intercept = ifelse(icpt>0 | ncol(X)+1==nrow(B), as.scalar(B[nrow(B),]), 0);
   yhat = X %*% B[1:ncol(X),] + intercept;
 
-  if( verbose ) {
-    y_residual = ytest - yhat;
-    avg_res = sum(y_residual) / nrow(ytest);
-    ss_res = sum(y_residual^2);
-    ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
-    R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * 
(sum(ytest)/nrow(ytest))^2);
-    print("\nAccuracy:" +
-          "\n--sum(ytest) = " + sum(ytest) +
-          "\n--sum(yhat) = " + sum(yhat) +
-          "\n--AVG_RES_Y: " + avg_res +
-          "\n--SS_AVG_RES_Y: " + ss_avg_res +
-          "\n--R2: " + R2 );
-  }
+  if( verbose )
+    lmPredictStats(yhat, ytest);
 }
diff --git a/scripts/builtin/lmPredictStats.dml 
b/scripts/builtin/lmPredictStats.dml
new file mode 100644
index 0000000000..986b98233c
--- /dev/null
+++ b/scripts/builtin/lmPredictStats.dml
@@ -0,0 +1,51 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# This builtin function computes and prints a summary of accuracy
+# measures for regression problems.
+#
+# INPUT:
+# 
------------------------------------------------------------------------------
+# yhat     column vector of predicted response values y
+# ytest    column vector of actual response values y
+# 
------------------------------------------------------------------------------
+#
+# OUTPUT:
+# 
------------------------------------------------------------------------------
+# R        column vector holding avg_res, ss_avg_res, and R2
+# 
------------------------------------------------------------------------------
+
+m_lmPredictStats = function(Matrix[Double] yhat, Matrix[Double] ytest)
+  return (Matrix[Double] R)
+{
+  y_residual = ytest - yhat;
+  avg_res = sum(y_residual) / nrow(ytest);
+  ss_res = sum(y_residual^2);
+  ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
+  R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * (sum(ytest)/nrow(ytest))^2);
+  print("\nAccuracy:" +
+        "\n--sum(ytest) = " + sum(ytest) +
+        "\n--sum(yhat) = " + sum(yhat) +
+        "\n--AVG_RES_Y: " + avg_res +
+        "\n--SS_AVG_RES_Y: " + ss_avg_res +
+        "\n--R2: " + R2 );
+  R = as.matrix(list(avg_res, ss_avg_res, R2));
+}
diff --git a/scripts/builtin/randomForestPredict.dml 
b/scripts/builtin/randomForestPredict.dml
index af2ec157ef..1e08acb6ac 100644
--- a/scripts/builtin/randomForestPredict.dml
+++ b/scripts/builtin/randomForestPredict.dml
@@ -53,8 +53,7 @@ m_randomForestPredict = function(Matrix[Double] X, 
Matrix[Double] y = matrix(0,0
 
   # scoring of num_tree decision trees
   Ytmp = matrix(0, rows=nrow(M), cols=nrow(X));
-  # TODO parfor issue with decisionTreePredict
-  for(i in 1:nrow(M)) {
+  parfor(i in 1:nrow(M)) {
     if( verbose )
       print("randomForest: start scoring tree "+i+"/"+nrow(M)+".");
 
@@ -82,24 +81,10 @@ m_randomForestPredict = function(Matrix[Double] X, 
Matrix[Double] y = matrix(0,0
 
   # summary statistics
   if( yExists & verbose ) {
-    if( classify ) {
-      accuracy = sum(yhat == y) / nrow(y) * 100;
-      print("Accuracy (%): " + accuracy);
-    }
-    else {
-      # TODO eliminate redundancy with lmPredict
-      y_residual = y - yhat;
-      avg_res = sum(y_residual) / nrow(y);
-      ss_res = sum(y_residual^2);
-      ss_avg_res = ss_res - nrow(y) * avg_res^2;
-      R2 = 1 - ss_res / (sum(y^2) - nrow(y) * (sum(y)/nrow(y))^2);
-      print("\nAccuracy:" +
-            "\n--sum(y) = " + sum(y) +
-            "\n--sum(yhat) = " + sum(yhat) +
-            "\n--AVG_RES_Y: " + avg_res +
-            "\n--SS_AVG_RES_Y: " + ss_avg_res +
-            "\n--R2: " + R2 );
-    }
+    if( classify )
+      print("Accuracy (%): " + (sum(yhat == y) / nrow(y) * 100));
+    else
+      lmPredictStats(yhat, y);
   }
 
   if(verbose) {
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 068968eb87..424a01438e 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -197,6 +197,7 @@ public enum Builtins {
        LMCG("lmCG", true),
        LMDS("lmDS", true),
        LMPREDICT("lmPredict", true),
+       LMPREDICT_STATS("lmPredictStats", true),
        LOCAL("local", false),
        LOG("log", false),
        LOGSUMEXP("logSumExp", true),
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 6dbfc34123..b287b4957c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -589,8 +589,15 @@ public class ParForProgramBlock extends ForProgramBlock
                //early exit on num iterations (e.g., for invalid loop bounds)
                _numIterations = UtilFunctions.getSeqLength( 
                        from0.getDoubleValue(), to0.getDoubleValue(), 
incr0.getDoubleValue(), false);
+               // avoid unnecessary optimization/initialization, and issue 
with simplification rewrites
+               // (e.g., rewriting leftindexing into a result variable, to 
assignment)
                if( _numIterations <= 0 )
-                       return; //avoid unnecessary optimization/initialization
+                       return; 
+               if( _numIterations == 1 ) {
+                       //fallback to basic for loop.
+                       super.execute(ec);
+                       return;
+               }
                
                IntObject from = new IntObject(from0.getLongValue());
                IntObject to = new IntObject(to0.getLongValue());
diff --git 
a/src/test/java/org/apache/sysds/test/functions/recompile/PredicateRecompileTest.java
 
b/src/test/java/org/apache/sysds/test/functions/recompile/PredicateRecompileTest.java
index b98c108b1a..e4dd440989 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/recompile/PredicateRecompileTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/recompile/PredicateRecompileTest.java
@@ -282,6 +282,8 @@ public class PredicateRecompileTest extends 
AutomatedTestBase
                                                4 - ((evalExpr||constFold)?4:0) 
:
                                                3 - 
((evalExpr||constFold)?3:0));
                                                //+ 
((!testname.equals(TEST_NAME2)&&!(evalExpr||constFold))?1:0); //loop checkpoint
+                                       if(testname.equals(TEST_NAME4))
+                                               expected += 
(evalExpr||constFold)?0:1;
                                        Assert.assertEquals("Unexpected number 
of executed Spark instructions.",
                                                expected, 
Statistics.getNoOfExecutedSPInst());
                                }
@@ -291,6 +293,8 @@ public class PredicateRecompileTest extends 
AutomatedTestBase
                                                4 - ((evalExpr||constFold)?1:0) 
:
                                                3 - 
((evalExpr||constFold)?1:0));
                                                //+ 
(!testname.equals(TEST_NAME2)?1:0); //loop checkpoint
+                                       if(testname.equals(TEST_NAME4))
+                                               expected += 1;
                                        Assert.assertEquals("Unexpected number 
of executed Spark instructions.", 
                                                expected, 
Statistics.getNoOfExecutedSPInst());
                                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/recompile/SparsityFunctionRecompileTest.java
 
b/src/test/java/org/apache/sysds/test/functions/recompile/SparsityFunctionRecompileTest.java
index 5534914958..2b2af0a102 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/recompile/SparsityFunctionRecompileTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/recompile/SparsityFunctionRecompileTest.java
@@ -174,20 +174,20 @@ public class SparsityFunctionRecompileTest extends 
AutomatedTestBase
                        HDFSTool.writeMetaDataFile(input("V.mtd"), 
ValueType.FP64, mc, FileFormat.TEXT);
                        
                        boolean exceptionExpected = false;
-                       runTest(true, exceptionExpected, null, -1); 
+                       runTest(true, exceptionExpected, null, -1);
                        
                        //CHECK compiled Spark jobs
                        int expectNumCompiled = 1 //rblk
                                + (testname.equals(TEST_NAME2) ? (IPA?2:5) : 
(IPA?3:4)) //if no write on IPA
-                               + (testname.equals(TEST_NAME4)? 2 : 0); //(+2 
parfor resultmerge);
+                               + (testname.equals(TEST_NAME4)? 0 : 0); //(+2 
parfor resultmerge);
                        Assert.assertEquals("Unexpected number of compiled 
Spark jobs.", 
                                expectNumCompiled, 
Statistics.getNoOfCompiledSPInst());
                
                        //CHECK executed Spark jobs
                        int expectNumExecuted = recompile ? 
-                               (testname.equals(TEST_NAME4)?2:0) : //(2x 
resultmerge) 
+                               (testname.equals(TEST_NAME4)?0:0) : //(2x 
resultmerge) 
                                (testname.equals(TEST_NAME2) ? (IPA?3:5) :
-                                       (testname.equals(TEST_NAME4) ? 
(IPA?6:7) : (IPA?4:5)));
+                                       (testname.equals(TEST_NAME4) ? 
(IPA?4:5) : (IPA?4:5)));
                        Assert.assertEquals("Unexpected number of executed 
Spark jobs.", 
                                expectNumExecuted, 
Statistics.getNoOfExecutedSPInst());
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/recompile/SparsityRecompileTest.java
 
b/src/test/java/org/apache/sysds/test/functions/recompile/SparsityRecompileTest.java
index d3665c57b8..6f07068ac9 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/recompile/SparsityRecompileTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/recompile/SparsityRecompileTest.java
@@ -130,15 +130,15 @@ public class SparsityRecompileTest extends 
AutomatedTestBase
                        
                        //CHECK compiled Spark jobs
                        int expectNumCompiled = 
(testname.equals(TEST_NAME2)?3:4) //-1 for if
-                               + (testname.equals(TEST_NAME4)?3:0);//(+2 
resultmerge, 1 sum)
+                               + (testname.equals(TEST_NAME4)?1:0);//(+2 
resultmerge, 1 sum)
                        Assert.assertEquals("Unexpected number of compiled 
Spark jobs.", 
                                expectNumCompiled, 
Statistics.getNoOfCompiledSPInst());
                
                        //CHECK executed Spark jobs
                        int expectNumExecuted = recompile ?
-                               ((testname.equals(TEST_NAME4))?2:0) : //(+2 
resultmerge)
+                               ((testname.equals(TEST_NAME4))?0:0) : //(+2 
resultmerge)
                                (testname.equals(TEST_NAME2)?3:4) //reblock + 3 
(-1 for if)
-                                       + ((testname.equals(TEST_NAME4))?3:0); 
//(+2 resultmerge, 1 sum) 
+                                       + ((testname.equals(TEST_NAME4))?1:0); 
//(+2 resultmerge, 1 sum) 
                        Assert.assertEquals("Unexpected number of executed 
Spark jobs.", 
                                expectNumExecuted, 
Statistics.getNoOfExecutedSPInst());
                        

Reply via email to