This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 27caa8ba2c [SYSTEMDS-3149] Improved decisionTree 
(robustness/rmEmpty-short-circuit)
27caa8ba2c is described below

commit 27caa8ba2c40111b486d3691bccafb23c1d285bc
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri May 5 20:58:00 2023 +0200

    [SYSTEMDS-3149] Improved decisionTree (robustness/rmEmpty-short-circuit)
    
    This patch improves the robustness and performance of the decisionTree /
    randomForest builtins when performing value and/or feature sampling.
    
    First, when sampling features (max_features<1.0) one might sample zero
    features, which led to crashes. The likelihood of this happening depends
    directly on number_of_features^max_features.
    
    Second, value and feature sampling might sample all values/features,
    especially when very close to max_features. So far, removeEmpty scanned
    the selection vector, despite a form of short-circuiting, and created new
    output matrix objects, which polluted the buffer pool. We now check
    upfront the number of non-zeros in the given selection vectors,
    short-circuit the entire computation, and even return the original input
    metadata object, avoiding duplicates in the buffer pool.
    
    On a scenario of running 1000 removeEmpty with full selection vector on
    a 10M x 10 matrix, this patch improved the runtime stats as follows:
    
    Cache writes (Li/WB/FS/HDFS):   1/1003/0/0.
    Cache times (ACQr/m, RLS, EXP): 0.004/0.012/0.077/0.000 sec.
    Heavy hitter instructions:
      1  rmempty      105.928   1000
    
    -->
    
    Cache writes (Li/WB/FS/HDFS):   1/3/0/0.
    Cache times (ACQr/m, RLS, EXP): 0.004/0.000/0.006/0.000 sec.
    Heavy hitter instructions:
      5  rmempty        0.026   1000
---
 scripts/builtin/decisionTree.dml                                   | 2 ++
 .../runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java | 7 +++++--
 .../java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java  | 6 ++++++
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/scripts/builtin/decisionTree.dml b/scripts/builtin/decisionTree.dml
index 384f591b73..85e414c61d 100644
--- a/scripts/builtin/decisionTree.dml
+++ b/scripts/builtin/decisionTree.dml
@@ -162,6 +162,8 @@ findBestSplit = function(Matrix[Double] X2, Matrix[Double] 
y2, Matrix[Double] fo
   if( max_features < 1.0 ) {
     rI = rand(rows=n, cols=1, seed=seed) <= (n^max_features/n);
     feat = removeEmpty(target=feat, margin="rows", select=rI);
+    if( sum(feat) == 0 ) #sample at least one
+      feat[1,1] = round(rand(rows=1, cols=1, min=1, max=n));
   }
 
   # evaluate features and feature splits
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index a67f8cd20d..9dfbdbec7f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -227,10 +227,13 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
 
                                // compute the result
                                boolean emptyReturn = 
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
-                               MatrixBlock soresBlock = 
target.removeEmptyOperations(new MatrixBlock(), margin.equals("rows"), 
emptyReturn, select);
+                               MatrixBlock ret = 
target.removeEmptyOperations(new MatrixBlock(), margin.equals("rows"), 
emptyReturn, select);
 
                                // release locks
-                               ec.setMatrixOutput(output.getName(), 
soresBlock);
+                               if( target == ret ) //short-circuit (avoid 
buffer pool pollution)
+                                       ec.setVariable(output.getName(), 
ec.getVariable(params.get("target")));
+                               else
+                                       ec.setMatrixOutput(output.getName(), 
ret);
                                ec.releaseMatrixInput(params.get("target"));
                                if(params.containsKey("select"))
                                        
ec.releaseMatrixInput(params.get("select"));
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 260c8b26a6..587fe74af1 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -661,6 +661,12 @@ public class LibMatrixReorg {
                        return ret;
                }
                
+               // short-circuit for select-all (shallow-copy input)
+               if( select != null && (select.nonZeros == 
(rows?in.rlen:in.clen)) ) {
+                       return in;
+               }
+               
+               // core removeEmpty
                if( rows )
                        return removeEmptyRows(in, ret, select, emptyReturn);
                else //cols

Reply via email to