[ 
https://issues.apache.org/jira/browse/SPARK-19449?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15851576#comment-15851576
 ] 

Aseem Bansal commented on SPARK-19449:
--------------------------------------

Doesn't the decision tree debug string print the model as a series of IF-ELSE 
statements? I printed the debug string for the 2 random forest models and the 
output was exactly the same. In other words, the 2 implementations should be 
mathematically equivalent. 

The random processes for selecting data should not cause any issues, as I 
ensured that the exact same data goes to both versions. The conversion works 
for decision trees, and a random forest classifier is just a majority vote of a 
bunch of decision tree classifiers, so I cannot see how the results could differ.
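
To make the "majority vote" reasoning concrete, here is a minimal plain-Java sketch (no Spark; the class and method names are hypothetical). It also includes a probability-summing variant for contrast, since an ensemble that sums per-tree class probabilities can disagree with a strict majority vote even when the trees are identical. Which rule each Spark package actually applies is not established in this thread and would need to be checked against the source.

```java
import java.util.Arrays;
import java.util.List;

public class VotingSketch {
    // Hard voting: each tree casts one vote for its most probable class.
    static int hardVote(List<double[]> perTreeProbabilities) {
        int numClasses = perTreeProbabilities.get(0).length;
        double[] votes = new double[numClasses];
        for (double[] probabilities : perTreeProbabilities) {
            votes[argMax(probabilities)] += 1.0;
        }
        return argMax(votes);
    }

    // Soft voting: per-tree class probabilities are summed; the largest sum wins.
    static int softVote(List<double[]> perTreeProbabilities) {
        int numClasses = perTreeProbabilities.get(0).length;
        double[] sums = new double[numClasses];
        for (double[] probabilities : perTreeProbabilities) {
            for (int c = 0; c < numClasses; c++) {
                sums[c] += probabilities[c];
            }
        }
        return argMax(sums);
    }

    // Index of the largest value; ties resolve to the lowest index.
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical leaf probabilities for one input row across three trees.
        List<double[]> trees = Arrays.asList(
                new double[]{0.6, 0.4},
                new double[]{0.6, 0.4},
                new double[]{0.1, 0.9});
        System.out.println("hard vote: " + hardVote(trees));
        System.out.println("soft vote: " + softVote(trees));
    }
}
```

With these made-up probabilities, two of three trees favour class 0, but the summed probabilities (1.3 vs 1.7) favour class 1, so the two rules pick different classes from the same trees.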

> Inconsistent results between ml package RandomForestClassificationModel and 
> mllib package RandomForestModel
> -----------------------------------------------------------------------------------------------------------
>
>                 Key: SPARK-19449
>                 URL: https://issues.apache.org/jira/browse/SPARK-19449
>             Project: Spark
>          Issue Type: Bug
>          Components: ML, MLlib
>    Affects Versions: 2.1.0
>            Reporter: Aseem Bansal
>
> I worked on some code to convert an ml package RandomForestClassificationModel 
> to an mllib package RandomForestModel. It was needed because we need to make 
> predictions on the order of ms. I found that the results are inconsistent 
> although the underlying DecisionTreeModel instances are exactly the same. So 
> the behavior of the 2 implementations is inconsistent, which should not be 
> the case.
> The code below can be used to reproduce the issue. It can be run as a simple 
> Java app as long as you have the Spark dependencies set up properly.
> {noformat}
> import org.apache.spark.ml.Transformer;
> import org.apache.spark.ml.classification.*;
> import org.apache.spark.ml.linalg.*;
> import org.apache.spark.ml.regression.RandomForestRegressionModel;
> import org.apache.spark.mllib.linalg.DenseVector;
> import org.apache.spark.mllib.linalg.Vector;
> import org.apache.spark.mllib.tree.configuration.Algo;
> import org.apache.spark.mllib.tree.model.DecisionTreeModel;
> import org.apache.spark.mllib.tree.model.RandomForestModel;
> import org.apache.spark.sql.Dataset;
> import org.apache.spark.sql.Row;
> import org.apache.spark.sql.RowFactory;
> import org.apache.spark.sql.SparkSession;
> import org.apache.spark.sql.types.DataTypes;
> import org.apache.spark.sql.types.Metadata;
> import org.apache.spark.sql.types.StructField;
> import org.apache.spark.sql.types.StructType;
> import scala.Enumeration;
> import java.util.ArrayList;
> import java.util.List;
> import java.util.Random;
> abstract class Predictor {
>     abstract double predict(Vector vector);
> }
> public class MainConvertModels {
>     public static final int seed = 42;
>     public static void main(String[] args) {
>         int numRows = 1000;
>         int numFeatures = 3;
>         int numClasses = 2;
>         double trainFraction = 0.8;
>         double testFraction = 0.2;
>         SparkSession spark = SparkSession.builder()
>                 .appName("conversion app")
>                 .master("local")
>                 .getOrCreate();
>         Dataset<Row> data = getDummyData(spark, numRows, numFeatures, 
> numClasses);
>         Dataset<Row>[] splits = data.randomSplit(new double[]{trainFraction, 
> testFraction}, seed);
>         Dataset<Row> trainingData = splits[0];
>         Dataset<Row> testData = splits[1];
>         testData.cache();
>         List<Double> labels = getLabels(testData);
>         List<DenseVector> features = getFeatures(testData);
>         DecisionTreeClassifier classifier1 = new DecisionTreeClassifier();
>         DecisionTreeClassificationModel model1 = 
> classifier1.fit(trainingData);
>         final DecisionTreeModel convertedModel1 = 
> convertDecisionTreeModel(model1, Algo.Classification());
>         RandomForestClassifier classifier = new RandomForestClassifier();
>         RandomForestClassificationModel model2 = classifier.fit(trainingData);
>         final RandomForestModel convertedModel2 = 
> convertRandomForestModel(model2);
>         System.out.println(
>                 "****** DecisionTreeClassifier\n" +
>                         "** Original **" + getInfo(model1, testData) + "\n" +
>                         "** New      **" + getInfo(new Predictor() {
>                     double predict(Vector vector) {
>                         return convertedModel1.predict(vector);
>                     }
>                 }, labels, features) + "\n" +
>                         "\n" +
>                 "****** RandomForestClassifier\n" +
>                         "** Original **" + getInfo(model2, testData) + "\n" +
>                         "** New      **" + getInfo(new Predictor() {
>                     double predict(Vector vector) {
>                         return convertedModel2.predict(vector);
>                     }
>                 }, labels, features) + "\n" +
>                         "\n" +
>                         "");
>     }
>     static Dataset<Row> getDummyData(SparkSession spark, int numberRows, int 
> numberFeatures, int labelUpperBound) {
>         StructType schema = new StructType(new StructField[]{
>                 new StructField("label", DataTypes.DoubleType, false, 
> Metadata.empty()),
>                 new StructField("features", new VectorUDT(), false, 
> Metadata.empty())
>         });
>         double[][] vectors = prepareData(numberRows, numberFeatures);
>         Random random = new Random(seed);
>         List<Row> dataTest = new ArrayList<>();
>         for (double[] vector : vectors) {
>             double label = (double) random.nextInt(2);
>             dataTest.add(RowFactory.create(label, Vectors.dense(vector)));
>         }
>         return spark.createDataFrame(dataTest, schema);
>     }
>     static double[][] prepareData(int numRows, int numFeatures) {
>         Random random = new Random(seed);
>         double[][] result = new double[numRows][numFeatures];
>         for (int row = 0; row < numRows; row++) {
>             for (int feature = 0; feature < numFeatures; feature++) {
>                 result[row][feature] = random.nextDouble();
>             }
>         }
>         return result;
>     }
>     static String getInfo(Predictor predictor,
>                           List<Double> labels,
>                           List<DenseVector> features) {
>         Long startTime = System.currentTimeMillis();
>         List<Double> predictions = new ArrayList<>();
>         for (DenseVector feature : features) {
>             predictions.add(predictor.predict(feature));
>         }
>         return getInfo(startTime, labels, predictions);
>     }
>     static List<Double> getLabels(Dataset<Row> testData) {
>         List<Double> labels = new ArrayList<>();
>         for (Row row : testData.collectAsList()) {
>             labels.add(row.getDouble(0));
>         }
>         return labels;
>     }
>     static List<DenseVector> getFeatures(Dataset<Row> testData) {
>         List<DenseVector> features = new ArrayList<>();
>         for (Row row : testData.collectAsList()) {
>             features.add(new DenseVector(((org.apache.spark.ml.linalg.Vector) 
> row.get(1)).toArray()));
>         }
>         return features;
>     }
>     static String getInfo(Transformer model, Dataset<Row> testData) {
>         Dataset<Row> predictions = model.transform(testData);
>         predictions.cache();
>         Dataset<Row> correctPredictions = predictions.filter("label == prediction");
>         correctPredictions.cache();
>         Dataset<Row> incorrectPredictions = predictions.filter("label != prediction");
>         incorrectPredictions.cache();
>         Long truePositives = correctPredictions.filter("prediction == 1.0").count();
>         Long trueNegatives = correctPredictions.filter("prediction == 0.0").count();
>         Long falsePositives = incorrectPredictions.filter("prediction == 1.0").count();
>         Long falseNegatives = incorrectPredictions.filter("prediction == 0.0").count();
>         return getInfo(null, truePositives, trueNegatives, falsePositives, falseNegatives);
>     }
>     static String getInfo(Long startTime, List<Double> labels, List<Double> 
> predictions) {
>         Long endTime = System.currentTimeMillis();
>         if (labels.size() != predictions.size()) {
>             throw new RuntimeException("labels size is " + labels.size() +
>                     " but predictions size is " + predictions.size());
>         }
>         Long truePositives = 0L;
>         Long trueNegatives = 0L;
>         Long falsePositives = 0L;
>         Long falseNegatives = 0L;
>         for (int i = 0; i < labels.size(); i++) {
>             double label = labels.get(i);
>             double prediction = predictions.get(i);
>             if (label == prediction) {
>                 if (prediction == 1.0) {
>                     truePositives += 1;
>                 } else {
>                     trueNegatives += 1;
>                 }
>             } else {
>                 if (prediction == 1.0) {
>                     falsePositives += 1;
>                 } else {
>                     falseNegatives += 1;
>                 }
>             }
>         }
>         return getInfo(endTime - startTime, truePositives, trueNegatives, 
> falsePositives, falseNegatives);
>     }
>     static double ratio(Long numerator, Long denominator) {
>         if (numerator == 0 || denominator == 0) {
>             return 0;
>         }
>         return ((double) numerator) / denominator;
>     }
>     static String getInfo(Long timeTakenMilliseconds, Long truePositives, 
> Long trueNegatives, Long falsePositives,
>                           Long falseNegatives) {
>         Long testDataCount = truePositives + trueNegatives + falsePositives + 
> falseNegatives;
>         double accuracy = ratio(truePositives + trueNegatives, testDataCount);
>         double precision = ratio(truePositives, truePositives + 
> falsePositives);
>         double recall = ratio(truePositives, truePositives + falseNegatives);
>         String last = "";
>         if (timeTakenMilliseconds != null) {
>             last = ", Average time taken (ms) " + 
> ratio(timeTakenMilliseconds, testDataCount);
>         }
>         return (
>                 "true positives " + truePositives
>                         + ", true negatives " + trueNegatives
>                         + ", false positives " + falsePositives
>                         + ", false negatives " + falseNegatives
>                         + ", total " + testDataCount
>                         + "\n\t accuracy " + accuracy
>                         + ", precision " + precision
>                         + ", recall " + recall
>                         + last
>         );
>     }
>     static DecisionTreeModel 
> convertDecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel model,
>                                                       Enumeration.Value algo) 
> {
>         return new DecisionTreeModel(model.rootNode().toOld(1), algo);
>     }
>     static RandomForestModel 
> convertRandomForestModel(org.apache.spark.ml.tree.TreeEnsembleModel model) {
>         Enumeration.Value algo;
>         if (model instanceof RandomForestRegressionModel) {
>             algo = Algo.Regression();
>         } else {
>             algo = Algo.Classification();
>         }
>         Object[] decisionTreeModels = model.trees();
>         DecisionTreeModel[] convertedDecisionTreeModels = new 
> DecisionTreeModel[decisionTreeModels.length];
>         for (int i = 0; i < decisionTreeModels.length; i++) {
>             org.apache.spark.ml.tree.DecisionTreeModel originalModel = 
> (org.apache.spark.ml.tree.DecisionTreeModel) decisionTreeModels[i];
>             DecisionTreeModel convertedModel = 
> convertDecisionTreeModel(originalModel, algo);
>             convertedDecisionTreeModels[i] = convertedModel;
>         }
>         RandomForestModel result = new RandomForestModel(algo, 
> convertedDecisionTreeModels);
>         return result;
>     }
> }
> {noformat}
> The output looks like the below. Here, Original refers to the ml 
> package version and New refers to the mllib package version. 
> - I converted the ml version of the decision tree to the mllib version. 
> I gave both versions the same input and received exactly the same output. 
> - I then converted the ml version of the random forest to the mllib version, 
> giving both the same underlying decision trees (using the previous 
> conversion method). I gave both versions the same input but received 
> different output. 
> {noformat}
> ****** DecisionTreeClassifier
> ** Original **true positives 8128, true negatives 1923, false positives 7942, false negatives 1897, total 19890
>      accuracy 0.5053293112116641, precision 0.5057871810827629, recall 0.8107730673316709
> ** New      **true positives 8128, true negatives 1923, false positives 7942, false negatives 1897, total 19890
>      accuracy 0.5053293112116641, precision 0.5057871810827629, recall 0.8107730673316709, Average time taken (ms) 0.001558572146807441
> ****** RandomForestClassifier
> ** Original **true positives 3940, true negatives 5915, false positives 3950, false negatives 6085, total 19890
>      accuracy 0.49547511312217196, precision 0.49936628643852976, recall 0.39301745635910224
> ** New      **true positives 2461, true negatives 7350, false positives 2515, false negatives 7564, total 19890
>      accuracy 0.4932629462041227, precision 0.4945739549839228, recall 0.2454862842892768, Average time taken (ms) 0.01085972850678733
> {noformat}
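
The aggregate counts in the quoted output show that the two forests disagree, but not on which rows. A row-level comparison might help narrow this down. Below is a minimal plain-Java sketch (the helper is hypothetical and not part of the reproduction code). It assumes the ml predictions have been collected into a list in the same row order as the mllib predictions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PredictionDiff {
    // Returns the indices at which the two prediction lists disagree.
    static List<Integer> disagreements(List<Double> first, List<Double> second) {
        if (first.size() != second.size()) {
            throw new IllegalArgumentException("first size is " + first.size()
                    + " but second size is " + second.size());
        }
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < first.size(); i++) {
            // Compare primitive values; predictions are class labels such as 0.0 or 1.0.
            if (first.get(i).doubleValue() != second.get(i).doubleValue()) {
                result.add(i);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Made-up predictions, only to show the call shape.
        List<Double> mlPredictions = Arrays.asList(1.0, 0.0, 1.0, 0.0);
        List<Double> mllibPredictions = Arrays.asList(1.0, 1.0, 1.0, 0.0);
        System.out.println("rows that differ: " + disagreements(mlPredictions, mllibPredictions));
    }
}
```

The feature vectors at the returned indices could then be passed through each tree individually to see how the per-tree outputs are combined in each package.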


