Repository: spark
Updated Branches:
  refs/heads/master d109a1bee -> 72353311d


[SPARK-15092][SPARK-15139][PYSPARK][ML] Pyspark TreeEnsemble missing methods

## What changes were proposed in this pull request?

Add `toDebugString` and `totalNumNodes` to `TreeEnsembleModels`, and add 
`toDebugString` to `DecisionTreeModel`. This change also exposes `trees` and 
`getNumTrees` on the tree ensemble models.

## How was this patch tested?

Extended the doctests in `classification.py` and `regression.py`.

Author: Holden Karau <hol...@us.ibm.com>

Closes #12919 from holdenk/SPARK-15139-pyspark-treeEnsemble-missing-methods.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/72353311
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/72353311
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/72353311

Branch: refs/heads/master
Commit: 72353311d3a37cb523c5bdd8072ffdff99af9749
Parents: d109a1b
Author: Holden Karau <hol...@us.ibm.com>
Authored: Thu Jun 2 15:55:14 2016 -0700
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Thu Jun 2 15:55:14 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py | 20 +++++++++++++
 python/pyspark/ml/regression.py     | 48 +++++++++++++++++++++++++++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/72353311/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py 
b/python/pyspark/ml/classification.py
index ea660d7..177cf9d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -512,6 +512,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
     1
     >>> model.featureImportances
     SparseVector(1, {0: 1.0})
+    >>> print(model.toDebugString)
+    DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> result = model.transform(test0).head()
     >>> result.prediction
@@ -650,6 +652,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPred
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> model.trees
+    [DecisionTreeClassificationModel (uid=...) of depth..., 
DecisionTreeClassificationModel...]
     >>> rfc_path = temp_path + "/rfc"
     >>> rf.save(rfc_path)
     >>> rf2 = RandomForestClassifier.load(rfc_path)
@@ -730,6 +734,12 @@ class RandomForestClassificationModel(TreeEnsembleModels, 
JavaMLWritable, JavaML
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent 
Estimators."""
+        return [DecisionTreeClassificationModel(m) for m in 
list(self._call_java("trees"))]
+
 
 @inherit_doc
 class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, 
HasPredictionCol, HasMaxIter,
@@ -772,6 +782,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredictionCol
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> model.totalNumNodes
+    15
+    >>> print(model.toDebugString)
+    GBTClassificationModel (uid=...)...with 5 trees...
     >>> gbtc_path = temp_path + "gbtc"
     >>> gbt.save(gbtc_path)
     >>> gbt2 = GBTClassifier.load(gbtc_path)
@@ -869,6 +883,12 @@ class GBTClassificationModel(TreeEnsembleModels, 
JavaMLWritable, JavaMLReadable)
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent 
Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in 
list(self._call_java("trees"))]
+
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, 
HasProbabilityCol,

http://git-wip-us.apache.org/repos/asf/spark/blob/72353311/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1b7af7e..7c79ab7 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -593,7 +593,7 @@ class RandomForestParams(TreeEnsembleParams):
     featureSubsetStrategy = \
         Param(Params._dummy(), "featureSubsetStrategy",
               "The number of features to consider for splits at each tree 
node. Supported " +
-              "options: " + ", ".join(supportedFeatureSubsetStrategies),
+              "options: " + ", ".join(supportedFeatureSubsetStrategies) + " 
(0.0-1.0], [1-n].",
               typeConverter=TypeConverters.toString)
 
     def __init__(self):
@@ -744,6 +744,12 @@ class DecisionTreeModel(JavaModel):
         """Return depth of the decision tree."""
         return self._call_java("depth")
 
+    @property
+    @since("2.0.0")
+    def toDebugString(self):
+        """Full description of model."""
+        return self._call_java("toDebugString")
+
     def __repr__(self):
         return self._call_java("toString")
 
@@ -759,11 +765,35 @@ class TreeEnsembleModels(JavaModel):
     """
 
     @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent 
Estimators."""
+        return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]
+
+    @property
+    @since("2.0.0")
+    def getNumTrees(self):
+        """Number of trees in ensemble."""
+        return self._call_java("getNumTrees")
+
+    @property
     @since("1.5.0")
     def treeWeights(self):
         """Return the weights for each tree"""
         return list(self._call_java("javaTreeWeights"))
 
+    @property
+    @since("2.0.0")
+    def totalNumNodes(self):
+        """Total number of nodes, summed over all trees in the ensemble."""
+        return self._call_java("totalNumNodes")
+
+    @property
+    @since("2.0.0")
+    def toDebugString(self):
+        """Full description of model."""
+        return self._call_java("toDebugString")
+
     def __repr__(self):
         return self._call_java("toString")
 
@@ -825,6 +855,10 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredi
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
+    >>> model.trees
+    [DecisionTreeRegressionModel (uid=...) of depth..., 
DecisionTreeRegressionModel...]
+    >>> model.getNumTrees
+    2
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], 
["features"])
     >>> model.transform(test1).head().prediction
     0.5
@@ -898,6 +932,12 @@ class RandomForestRegressionModel(TreeEnsembleModels, 
JavaMLWritable, JavaMLRead
 
     @property
     @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent 
Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in 
list(self._call_java("trees"))]
+
+    @property
+    @since("2.0.0")
     def featureImportances(self):
         """
         Estimate of the importance of each feature.
@@ -1045,6 +1085,12 @@ class GBTRegressionModel(TreeEnsembleModels, 
JavaMLWritable, JavaMLReadable):
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent 
Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in 
list(self._call_java("trees"))]
+
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, 
HasPredictionCol,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to