This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 75e7e64f22 [SYSTEMDS-3149] Fix misc issues decisionTree/randomForest training
75e7e64f22 is described below

commit 75e7e64f228cccfe71017199799298e227e4bd23
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Apr 11 21:31:27 2023 +0200

    [SYSTEMDS-3149] Fix misc issues decisionTree/randomForest training
    
    This patch fixes various issues in the new decisionTree and randomForest
    built-in functions and adds new, stricter tests:
    
    * randomForest validation checks and parameters (consistent to DT)
    * randomForest correct feature map with feature_frac=1.0
    * decisionTree simplification of leaf label computation
    * synchronized deep copy of hop-DAGs to avoid race conditions in parfor
    * added missing size propagation on spark rev operations
    * new tests with randomForest that check equivalent results to DT
      with num_trees=1 and reasonable results with larger ensembles
---
 scripts/builtin/decisionTree.dml                     |  4 ++--
 scripts/builtin/randomForest.dml                     | 20 +++++++++++++-------
 .../instructions/spark/ReorgSPInstruction.java       |  5 ++++-
 .../apache/sysds/runtime/util/ProgramConverter.java  |  5 ++++-
 .../part1/BuiltinDecisionTreeRealDataTest.java       | 20 +++++++++++++++++---
 .../functions/builtin/decisionTreeRealData.dml       | 17 +++++++++++++----
 6 files changed, 53 insertions(+), 18 deletions(-)

diff --git a/scripts/builtin/decisionTree.dml b/scripts/builtin/decisionTree.dml
index 5e72127dd1..4d4e273c65 100644
--- a/scripts/builtin/decisionTree.dml
+++ b/scripts/builtin/decisionTree.dml
@@ -226,8 +226,8 @@ computeLeafLabel = function(Matrix[Double] y2, Matrix[Double] I, Boolean classif
   return(Double label)
 {
   f = (I %*% y2) / sum(I);
-  label = ifelse(classify,
-    as.scalar(rowIndexMax(f)), sum(t(f)*seq(1,ncol(f))));
+  label = as.scalar(ifelse(classify,
+    rowIndexMax(f), f %*% seq(1,ncol(f))));
   if(verbose)
     print("-- leaf node label: " + label +" ("+sum(I)*max(f)+"/"+sum(I)+")");
 }
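
The simplification keeps both ifelse branches as 1x1 matrices and applies
as.scalar once to the selected result. A minimal standalone sketch of the
patched computation, with a made-up leaf distribution f (illustrative only):

    # f: label distribution at a leaf over 3 classes
    f = matrix("0.2 0.5 0.3", rows=1, cols=3);
    classify = TRUE;
    label = as.scalar(ifelse(classify,
      rowIndexMax(f),          # majority class -> 2
      f %*% seq(1,ncol(f))));  # expected label 0.2*1 + 0.5*2 + 0.3*3 = 2.1
    print("leaf label: " + label);
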
diff --git a/scripts/builtin/randomForest.dml b/scripts/builtin/randomForest.dml
index 176628e3a6..37f7f64fb4 100644
--- a/scripts/builtin/randomForest.dml
+++ b/scripts/builtin/randomForest.dml
@@ -37,6 +37,7 @@
 # feature_frac    Sample fraction of features for each tree in the forest
 # max_depth       Maximum depth of the learned tree (stopping criterion)
 # min_leaf        Minimum number of samples in leaf nodes (stopping criterion)
+# min_split       Minimum number of samples in leaf for attempting a split
 # max_features    Parameter controlling the number of features used as split
 #                 candidates at tree nodes: m = ceil(num_features^max_features)
 # impurity        Impurity measure: entropy, gini (default)
@@ -68,7 +69,7 @@
 
 m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] ctypes,
     Int num_trees = 16, Double sample_frac = 0.1, Double feature_frac = 1.0,
-    Int max_depth = 10, Int min_leaf = 20, Double max_features = 0.5,
+    Int max_depth = 10, Int min_leaf = 20, Int min_split = 50, Double max_features = 0.5,
     String impurity = "gini", Int seed = -1, Boolean verbose = FALSE)
   return(Matrix[Double] M)
 {
@@ -81,6 +82,8 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty
   }
   if(ncol(ctypes) != ncol(X)+1)
     stop("randomForest: inconsistent num features (incl. label) and col types: 
"+ncol(X)+" vs "+ncol(ctypes)+".");
+  if( sum(X<=0) != 0 )
+    stop("randomForest: feature matrix X is not properly recoded/binned: "+sum(X<=0));
   if(sum(y <= 0) != 0)
     stop("randomForest: y is not properly recoded and binned (contiguous 
positive integers).");
   if(max(y) == 1)
@@ -91,16 +94,19 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty
 
   # training of num_tree decision trees
   M = matrix(0, rows=num_trees, cols=2*(2^max_depth-1));
-  F = matrix(0, rows=num_trees, cols=ncol(X));
+  F = matrix(1, rows=num_trees, cols=ncol(X));
   parfor(i in 1:num_trees) {
     if( verbose )
       print("randomForest: start training tree "+i+"/"+num_trees+".");
 
     # step 1: sample data
-    si1 = as.integer(as.scalar(randSeeds[3*(i-1)+1,1]));
-    I1 = rand(rows=nrow(X), cols=1, seed=si1) <= sample_frac;
-    Xi = removeEmpty(target=X, margin="rows", select=I1);
-    yi = removeEmpty(target=y, margin="rows", select=I1);
+    Xi = X; yi = y;
+    if( sample_frac < 1.0 ) {
+      si1 = as.integer(as.scalar(randSeeds[3*(i-1)+1,1]));
+      I1 = rand(rows=nrow(X), cols=1, seed=si1) <= sample_frac;
+      Xi = removeEmpty(target=X, margin="rows", select=I1);
+      yi = removeEmpty(target=y, margin="rows", select=I1);
+    }
 
     # step 2: sample features
     if( feature_frac < 1.0 ) {
@@ -116,7 +122,7 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty
     # step 3: train decision tree
     t2 = time();
     si3 = as.integer(as.scalar(randSeeds[3*(i-1)+3,1]));
-    Mtemp = decisionTree(X=Xi, y=yi, ctypes=ctypes, max_depth=max_depth,
+    Mtemp = decisionTree(X=Xi, y=yi, ctypes=ctypes, max_depth=max_depth, min_split=min_split,
       min_leaf=min_leaf, max_features=max_features, impurity=impurity, seed=si3, verbose=verbose);
     M[i,1:length(Mtemp)] = matrix(Mtemp, rows=1, cols=length(Mtemp));
     if( verbose )
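
Note on the two fixes above: guarding the row sampling (Xi=X, yi=y for
sample_frac=1.0) and initializing the feature map F to ones for
feature_frac=1.0 mean a single-tree forest sees exactly the rows and features
of a plain decisionTree call. A hedged sketch of the equivalence the new
tests build on (X, y, R stand for any properly recoded placeholder inputs):

    M1 = decisionTree(X=X, y=y, ctypes=R, seed=7);
    M2 = randomForest(X=X, y=y, ctypes=R, num_trees=1,
           sample_frac=1.0, feature_frac=1.0, seed=7);
    # tree 1 of M2 is trained on the same data as M1, so both models
    # should predict equivalently (the Titanic tests below assert the
    # same 0.875 accuracy for both)
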
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index 5b6f2e4e3d..f14afa9009 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -234,6 +234,9 @@ public class ReorgSPInstruction extends UnarySPInstruction {
                               boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
                               mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
                        }
+                       else { //e.g., rev
+                               mcOut.set(mc1);
+                       }
                }
                
                //infer initially unknown nnz from input
@@ -241,7 +244,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
                       boolean sortIx = getOpcode().equalsIgnoreCase("rsort") && sec.getScalarInput(_ixret.getName(), _ixret.getValueType(), _ixret.isLiteral()).getBooleanValue();
                        if( sortIx )
                                mcOut.setNonZeros(mc1.getRows());
-                       else //default (r', rdiag, rsort data)
+                       else //default (r', rdiag, rev, rsort data)
                                mcOut.setNonZeros(mc1.getNonZeros());
                }
        }
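
Rationale: rev only reverses the row order, so the output inherits the
input's dimensions, blocksize, and number of non-zeros unchanged, and the new
else branch propagates these otherwise unknown output characteristics. A
small DML illustration (sizes are arbitrary):

    X = rand(rows=10000, cols=10, sparsity=0.1, seed=7);
    Y = rev(X);   # same dims and nnz as X, rows in reverse order
    print(sum(Y != 0) == sum(X != 0));   # TRUE
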
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 34a5287b70..8fbfe31125 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -593,7 +593,10 @@ public class ProgramConverter
                                ret.setReadVariables( sb.variablesRead() );
                                
                                //deep copy hops dag for concurrent recompile
-                               ArrayList<Hop> hops = Recompiler.deepCopyHopsDag( sb.getHops() );
+                               ArrayList<Hop> hops = sb.getHops();
+                               synchronized(hops) { // guard concurrent recompile
+                                       hops = Recompiler.deepCopyHopsDag( hops );
+                               }
                                if( !plain )
                                       Recompiler.updateFunctionNames( hops, pid );
                                ret.setHops( hops );
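
Context: parfor workers create program copies of the same statement block
concurrently, and the deep copy traverses the shared hop-DAG, which holds
mutable state and is not safe for concurrent traversal. Synchronizing on the
shared hops list serializes only the copy itself; each worker then
recompiles its own private copy without further locking.
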
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
index 2af6784c36..f797cfff09 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
@@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -42,10 +43,22 @@ public class BuiltinDecisionTreeRealDataTest extends AutomatedTestBase {
 
        @Test
        public void testDecisionTreeTitanic() {
-               runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, ExecType.CP);
+               runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, 1, ExecType.CP);
+       }
+       
+       @Test
+       public void testRandomForestTitanic1() {
+               //one tree with sample_frac=1 should be equivalent to decision tree
+               runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, 2, ExecType.CP);
+       }
+       
+       @Test
+       public void testRandomForestTitanic8() {
+               //8 trees with sample fraction 0.125 each, accuracy 0.785 due to randomness
+               runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.793, 9, ExecType.CP);
        }
 
-       private void runDecisionTree(String data, String tfspec, double minAcc, ExecType instType) {
+       private void runDecisionTree(String data, String tfspec, double minAcc, int dt, ExecType instType) {
                Types.ExecMode platformOld = setExecMode(instType);
                try {
                        loadTestConfiguration(getTestConfiguration(TEST_NAME));
@@ -53,12 +66,13 @@ public class BuiltinDecisionTreeRealDataTest extends AutomatedTestBase {
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
                        programArgs = new String[] {"-stats",
-                               "-args", data, tfspec, output("R")};
+                               "-args", data, tfspec, String.valueOf(dt), output("R")};
 
                        runTest(true, false, null, -1);
 
                        double acc = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
                        Assert.assertTrue(acc >= minAcc);
+                       Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
                }
                finally {
                        rtplatform = platformOld;
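
The added assertion makes these tests stricter than the accuracy check
alone: requiring zero executed Spark instructions guards against silent
distributed fallbacks, e.g., when unknown output sizes (as fixed for rev
above) would otherwise force Spark operations.
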
diff --git a/src/test/scripts/functions/builtin/decisionTreeRealData.dml b/src/test/scripts/functions/builtin/decisionTreeRealData.dml
index 775a73a263..f61b2de77e 100644
--- a/src/test/scripts/functions/builtin/decisionTreeRealData.dml
+++ b/src/test/scripts/functions/builtin/decisionTreeRealData.dml
@@ -30,11 +30,20 @@ Y = X[, ncol(X)]
 X = X[, 1:ncol(X)-1]
 X = imputeByMode(X);
 
-M = decisionTree(X=X, y=Y, ctypes=R, max_features=1, min_split=8, min_leaf=5, verbose=TRUE);
-yhat = decisionTreePredict(X=X, y=Y, ctypes=R,  M=M)
+if( $3==1 ) {
+  M = decisionTree(X=X, y=Y, ctypes=R, max_features=1,
+                   min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+  yhat = decisionTreePredict(X=X, y=Y, ctypes=R, M=M)
+}
+else {
+  sf = 1.0/($3-1);
+  M = randomForest(X=X, y=Y, ctypes=R, sample_frac=sf, num_trees=$3-1, max_features=1,
+                   min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+  yhat = randomForestPredict(X=X, y=Y, ctypes=R,  M=M)
+}
 
 acc = as.matrix(mean(yhat == Y))
 err = 1-(acc);
-print("accuracy of DT: "+as.scalar(acc))
+print("accuracy: "+as.scalar(acc))
 
-write(acc, $3);
+write(acc, $4);
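
For reference, the driver's third argument selects the learner and implies
the per-tree sample fraction:

    # $3 = 1 -> plain decisionTree
    # $3 = 2 -> randomForest, num_trees = 1, sf = 1/1 = 1.0   (DT-equivalent)
    # $3 = 9 -> randomForest, num_trees = 8, sf = 1/8 = 0.125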
