Copilot commented on code in PR #55728:
URL: https://github.com/apache/spark/pull/55728#discussion_r3200111161
##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -712,17 +853,44 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
)
val rdd = sc.parallelize(arr.toImmutableArraySeq)
- val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity =
Variance, maxDepth = 4,
- numClasses = 0, maxBins = 32)
-
- val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None).head
-
- val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Regression,
+ impurity = Variance,
+ maxDepth = 4,
+ numClasses = 0,
+ maxBins = 32)
+
+ val prunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = true)
+ .head
+
+ val unprunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = false)
+ .head
+
+ val defaultBehaviorTree = RandomForest
+ .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed =
42, instr = None)
Review Comment:
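`RandomForest.run` no longer accepts a `prune` named argument, but the pruned
and unpruned calls above still pass `prune = true` / `prune = false`. This will
not compile; set `strategy.pruneTree` on the `OldStrategy` before each call
instead (e.g., `strategy.pruneTree = false` for the unpruned tree).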
##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,32 @@ private[ml] trait TreeClassifierParams extends Params {
(value: String) =>
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
- setDefault(impurity -> "gini")
+ /**
+ * If true, the trained tree will undergo a 'pruning' process after training
in which nodes
+ * that have the same class predictions will be merged. This drawback means
that the class
+ * probabilities will be lost. The benefit being that at prediction time
the tree will be
+ * smaller and have faster predictions
+ * If false, the post-training tree will undergo no pruning. The benefit
being that you
+ * maintain the class prediction probabilities
+ * (default = true)
+ * @group param
+ */
+ final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+ "If true, the trained tree will undergo a 'pruning' process after training
in which nodes" +
+ " that have the same class predictions will be merged. This drawback
means that the class" +
+ " probabilities will be lost. The benefit being that at prediction time
the tree will be" +
+ " smaller and have faster predictions" +
+ " If false, the post-training tree will undergo no pruning. The benefit
being that you" +
+ " maintain the class prediction probabilities"
Review Comment:
The `pruneTree` Scaladoc/help text reads as a single run-on sentence (e.g.,
"faster predictions If false") and contains grammar issues ("This drawback
means..."). Please add punctuation and rephrase for clarity since this text
becomes user-facing Param documentation.
##########
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala:
##########
@@ -225,7 +234,7 @@ private[spark] object RandomForest extends Logging with
Serializable {
timer.stop("findBestSplits")
if (earlyStopModelSizeThresholdInBytes > 0) {
- val nodes = topNodes.map(_.toNode(prune))
+ val nodes = topNodes.map(_.toNode(strategy.pruneTree))
Review Comment:
Switching pruning from an explicit method parameter to `strategy.pruneTree`
means pruning now depends on the `Strategy` default for all callers. Since
`Strategy.pruneTree` defaults to `false`, this silently changes behavior for
code paths which don’t explicitly set it (e.g., old mllib API, regressors,
GBT). If the intent is only to make pruning configurable, consider defaulting
`pruneTree` to `true` in `Strategy.defaultStrategy(...)` / `Strategy`
constructor (or explicitly setting it in `DecisionTreeParams.getOldStrategy`)
to preserve prior behavior.
##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -682,18 +797,44 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val rdd = sc.parallelize(arr.toImmutableArraySeq)
val numClasses = 2
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 4,
- numClasses = numClasses, maxBins = 32)
-
- val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None).head
-
- val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 4,
+ numClasses = numClasses,
+ maxBins = 32)
+
+ val prunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = true)
+ .head
+
+ val unprunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = false)
+ .head
Review Comment:
`RandomForest.run` no longer accepts a `prune` named argument, but this call
still passes `prune = false`. This will not compile; use `strategy.pruneTree =
false` (or set it once before the call) to test the unpruned behavior.
##########
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala:
##########
@@ -77,6 +79,7 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] =
Map[Int, Int](),
@Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
@Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
+ @Since("3.1.2") @BeanProperty var pruneTree: Boolean = false,
Review Comment:
The new `pruneTree` field is annotated `@Since("3.1.2")`, but this
repository is Spark 5.0.0 (see `python/pyspark/version.py`). Please update the
`@Since` version to the correct release where this parameter is actually
introduced.
##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,32 @@ private[ml] trait TreeClassifierParams extends Params {
(value: String) =>
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
- setDefault(impurity -> "gini")
+ /**
+ * If true, the trained tree will undergo a 'pruning' process after training
in which nodes
+ * that have the same class predictions will be merged. This drawback means
that the class
+ * probabilities will be lost. The benefit being that at prediction time
the tree will be
+ * smaller and have faster predictions
+ * If false, the post-training tree will undergo no pruning. The benefit
being that you
+ * maintain the class prediction probabilities
+ * (default = true)
+ * @group param
+ */
+ final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+ "If true, the trained tree will undergo a 'pruning' process after training
in which nodes" +
+ " that have the same class predictions will be merged. This drawback
means that the class" +
+ " probabilities will be lost. The benefit being that at prediction time
the tree will be" +
+ " smaller and have faster predictions" +
+ " If false, the post-training tree will undergo no pruning. The benefit
being that you" +
+ " maintain the class prediction probabilities"
+ )
+
+ // HERE
Review Comment:
Leftover debug/comment marker `// HERE` should be removed before merging.
##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -431,18 +512,32 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq)
// Must set maxBins s.t. the feature will be treated as an ordered
categorical feature.
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
-
- val model = RandomForest.run(input, strategy, numTrees = 1,
featureSubsetStrategy = "all",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 1,
+ numClasses = 2,
+ categoricalFeaturesInfo = Map(0 -> 3),
+ maxBins = 3)
+
+ val model = RandomForest
+ .run(
+ input,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "all",
+ seed = 42,
+ instr = None,
+ prune = false)
+ .head
Review Comment:
`RandomForest.run` no longer takes a `prune` parameter (it was removed from
`ml/tree/impl/RandomForest.scala`), but this test still passes `prune = false`
as a named argument. This will not compile; update the test to set
`strategy.pruneTree` instead (or reintroduce the `prune` parameter if it’s
still intended for testing).
##########
mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala:
##########
@@ -74,6 +74,10 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ /** @group setParam */
+ @Since("3.1.2")
Review Comment:
`pruneTree` is introduced as a new API but is annotated `@Since("3.1.2")`.
This should be updated to the Spark version which will actually ship this
change (the repo is currently 5.0.0).
##########
python/pyspark/ml/classification.py:
##########
@@ -1838,7 +1841,7 @@ def setParams(
"""
setParams(self, \\*, featuresCol="features", labelCol="label",
predictionCol="prediction", \
probabilityCol="probability",
rawPredictionCol="rawPrediction", \
- maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, pruneTree=True\
Review Comment:
The generated docstring signature is missing a comma between `pruneTree=True`
and the trailing `\` line continuation, so `pruneTree=True` runs straight into
the next parameter (`maxMemoryInMB=...`) and the documented call signature is
malformed. Please write `pruneTree=True, \` so the docstring matches the real
Python signature.
##########
python/pyspark/ml/classification.py:
##########
@@ -2097,7 +2108,7 @@ def __init__(
"""
__init__(self, \\*, featuresCol="features", labelCol="label",
predictionCol="prediction", \
probabilityCol="probability",
rawPredictionCol="rawPrediction", \
- maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, \
+ maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0, pruneTree=True\
Review Comment:
The generated docstring signature is missing a comma between `pruneTree=True`
and the trailing `\` line continuation, so `pruneTree=True` runs straight into
the next parameter (`maxMemoryInMB=...`) and the documented call signature is
malformed. Please write `pruneTree=True, \` so the docstring matches the real
Python signature.
##########
python/pyspark/ml/tree.py:
##########
@@ -424,6 +432,12 @@ def getImpurity(self) -> str:
Gets the value of impurity or its default value.
"""
return self.getOrDefault(self.impurity)
+ @since("3.1.2")
+ def getPruneTree(self):
+ """
+ Gets the value of pruneTree or its default value.
+ """
+ return self.getOrDefault(self.pruneTree)
Review Comment:
`getPruneTree` should follow the typing and versioning conventions used by
neighboring getters: add a return type annotation (`-> bool`) and update the
`@since` version (currently `3.1.2`) to the Spark version where this new param
is introduced (repo is 5.0.0).
##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -682,18 +797,44 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
val rdd = sc.parallelize(arr.toImmutableArraySeq)
val numClasses = 2
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity =
Gini, maxDepth = 4,
- numClasses = numClasses, maxBins = 32)
-
- val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None).head
-
- val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 4,
+ numClasses = numClasses,
+ maxBins = 32)
+
+ val prunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = true)
+ .head
+
+ val unprunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = false)
+ .head
+
+ val defaultBehaviorTree = RandomForest
+ .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed =
42, instr = None)
Review Comment:
`RandomForest.run` no longer accepts a `prune` named argument, but this call
still passes `prune = true`. This will not compile; use `strategy.pruneTree =
true` for the call path being tested (and similarly update the corresponding
unpruned call).
##########
python/pyspark/ml/tree.py:
##########
@@ -415,6 +415,14 @@ class _TreeClassifierParams(Params):
typeConverter=TypeConverters.toString,
)
+ pruneTree = Param(Params._dummy(), "pruneTree", "" +
+ "If true, the trained tree will undergo a 'pruning'
process after training in which nodes" +
+ " that have the same class predictions will be merged.
This drawback means that the class" +
+ " probabilities will be lost. The benefit being that at
prediction time the tree will be" +
+ " smaller and have faster predictions" +
+ " If false, the post-training tree will undergo no
pruning. The benefit being that you" +
+ " maintain the class prediction probabilities",
typeConverter=TypeConverters.toBoolean)
Review Comment:
This PR adds a new user-facing Python param (`pruneTree`) and new setter(s),
but there is no corresponding PySpark test verifying that the param is exposed
and round-trips correctly. For example, a test could check that
`setPruneTree(False)` is reflected in `getPruneTree()` and in the JVM param
map, and that the default matches the intended behavior. Consider adding
coverage in `python/pyspark/ml/tests/test_classification.py` alongside the
existing DecisionTreeClassifier/RandomForestClassifier tests.
##########
mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala:
##########
@@ -76,6 +76,10 @@ class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ /** @group setParam */
+ @Since("3.1.2")
Review Comment:
`pruneTree` is introduced as a new API but is annotated `@Since("3.1.2")`.
This should be updated to the Spark version which will actually ship this
change (the repo is currently 5.0.0).
##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -712,17 +853,44 @@ class RandomForestSuite extends SparkFunSuite with
MLlibTestSparkContext {
)
val rdd = sc.parallelize(arr.toImmutableArraySeq)
- val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity =
Variance, maxDepth = 4,
- numClasses = 0, maxBins = 32)
-
- val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None).head
-
- val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1,
featureSubsetStrategy = "auto",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Regression,
+ impurity = Variance,
+ maxDepth = 4,
+ numClasses = 0,
+ maxBins = 32)
+
+ val prunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = true)
+ .head
+
+ val unprunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None,
+ prune = false)
+ .head
+
+ val defaultBehaviorTree = RandomForest
+ .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed =
42, instr = None)
Review Comment:
`RandomForest.run` no longer accepts a `prune` named argument, but this
regression test still passes `prune = true`. This will not compile; set
`strategy.pruneTree = true` before calling `run` (and mirror for the unpruned
call).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]