Copilot commented on code in PR #55758:
URL: https://github.com/apache/spark/pull/55758#discussion_r3207902691
##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,27 @@ private[ml] trait TreeClassifierParams extends Params {
(value: String) =>
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
- setDefault(impurity -> "gini")
+ /**
+ * If true, the trained tree will undergo a pruning process after training, in which nodes
+ * with the same class predictions are merged. The resulting tree will be smaller and have
+ * faster predictions, but class probabilities will be lost.
+ * If false, no pruning is applied after training, and class probabilities are preserved.
+ * (default = true)
+ * @group param
+ */
+ final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+ "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+ " with the same class predictions are merged. The resulting tree will be smaller and have" +
+ " faster predictions, but class probabilities will be lost." +
+ " If false, no pruning is applied after training, and class probabilities are preserved."
Review Comment:
The docstring claims pruning means “class probabilities will be lost” and
“preserved” when disabled, but pruning here only merges sibling leaves with
identical predicted class; the resulting merged leaf still carries an
ImpurityCalculator (and thus probabilities), though probability estimates
become less fine-grained. Please reword this param description to avoid
implying probabilities are unavailable/undefined after pruning.
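To make the distinction concrete, here is a minimal plain-Python sketch, not Spark's implementation. It assumes each leaf stores per-class training counts (a stand-in for the statistics an ImpurityCalculator holds) and that merging two sibling leaves sums those counts. The names `Leaf` and `merge` and the example counts are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Leaf:
    # Per-class training counts that reached this leaf
    # (hypothetical stand-in for ImpurityCalculator statistics).
    counts: List[float]

    def predict(self) -> int:
        # Predicted class is the one with the largest count.
        return max(range(len(self.counts)), key=lambda i: self.counts[i])

    def probabilities(self) -> List[float]:
        # Class probabilities are the normalized counts.
        total = sum(self.counts)
        return [c / total for c in self.counts]


def merge(left: Leaf, right: Leaf) -> Leaf:
    """Merge two sibling leaves that predict the same class by summing counts."""
    assert left.predict() == right.predict(), "only same-prediction siblings merge"
    return Leaf([a + b for a, b in zip(left.counts, right.counts)])


left = Leaf([9.0, 1.0])   # P(class 0) = 9/10
right = Leaf([6.0, 4.0])  # P(class 0) = 6/10
merged = merge(left, right)  # P(class 0) = 15/20
```

Under these assumptions the merged leaf still yields a probability vector; what changes is that the two distinct estimates collapse into a single pooled one.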
##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,27 @@ private[ml] trait TreeClassifierParams extends Params {
(value: String) =>
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
- setDefault(impurity -> "gini")
+ /**
+ * If true, the trained tree will undergo a pruning process after training, in which nodes
+ * with the same class predictions are merged. The resulting tree will be smaller and have
+ * faster predictions, but class probabilities will be lost.
+ * If false, no pruning is applied after training, and class probabilities are preserved.
+ * (default = true)
+ * @group param
+ */
+ final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+ "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+ " with the same class predictions are merged. The resulting tree will be smaller and have" +
+ " faster predictions, but class probabilities will be lost." +
+ " If false, no pruning is applied after training, and class probabilities are preserved."
Review Comment:
Adding pruneTree introduces new user-facing behavior, but the current tests
updated here only exercise the old API (OldStrategy + impl.RandomForest).
Please add ML-level unit tests (e.g., DecisionTreeClassifierSuite /
RandomForestClassifierSuite) verifying setPruneTree(false) affects the
resulting ML model shape/probabilities and that the default matches
pruneTree=true.
##########
python/pyspark/ml/classification.py:
##########
@@ -1861,6 +1864,12 @@ def setMaxBins(self, value: int) -> "DecisionTreeClassifier":
"""
return self._set(maxBins=value)
+ def setPruneTree(self, value: bool) -> "DecisionTreeClassifier":
+ """
+ Sets the value of :py:attr:`pruneTree`.
+ """
+ return self._set(pruneTree=value)
+
Review Comment:
This new public setter should be annotated with `@since("4.3.0")` (matching
the new param/getter) so the Python API is correctly versioned in the generated
docs.
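A minimal sketch of what the annotated setter could look like. To stay self-contained it does not import pyspark: the `since` function below is a simplified stand-in for pyspark's decorator of the same name, and the `DecisionTreeClassifier` stub with its `_params` dict is hypothetical. Only the decorated `setPruneTree` method mirrors the diff.

```python
from typing import Any, Callable, Dict


def since(version: str) -> Callable:
    """Simplified stand-in for pyspark's `since`: appends a versionadded note."""

    def deco(f: Callable) -> Callable:
        f.__doc__ = (f.__doc__ or "").rstrip() + "\n\n.. versionadded:: %s" % version
        return f

    return deco


class DecisionTreeClassifier:
    """Hypothetical stub; the real class lives in pyspark.ml.classification."""

    def __init__(self) -> None:
        # Stand-in for the real param map.
        self._params: Dict[str, Any] = {}

    def _set(self, **kwargs: Any) -> "DecisionTreeClassifier":
        self._params.update(kwargs)
        return self

    @since("4.3.0")
    def setPruneTree(self, value: bool) -> "DecisionTreeClassifier":
        """
        Sets the value of :py:attr:`pruneTree`.
        """
        return self._set(pruneTree=value)
```

In this sketch the decorator only touches the docstring, so the setter's behavior is unchanged while the generated docs gain a version marker.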
##########
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala:
##########
@@ -113,12 +116,13 @@ class Strategy @Since("1.3.0") (
categoricalFeaturesInfo: Map[Int, Int],
minInstancesPerNode: Int,
minInfoGain: Double,
Review Comment:
This change modifies the existing “backwards compatible” constructor
signature by inserting `pruneTree`, which can break binary compatibility for
downstream code compiled against the previous constructor. Please keep the old
constructor overload (without pruneTree) delegating to the new one with
`pruneTree = true`, and/or avoid adding new fields to the primary constructor
for public APIs.
##########
python/pyspark/ml/tree.py:
##########
@@ -424,6 +431,12 @@ def getImpurity(self) -> str:
Gets the value of impurity or its default value.
"""
return self.getOrDefault(self.impurity)
+ @since("4.3.0")
+ def getPruneTree(self) -> bool:
+ """
+ Gets the value of pruneTree or its default value.
+ """
+ return self.getOrDefault(self.pruneTree)
Review Comment:
For consistency with the rest of this module’s param declarations,
`pruneTree` should be type-annotated (e.g., `Param[bool]`) and separated from
the previous method with a blank line (PEP8) before the `@since` decorator.
This improves static typing and keeps formatting consistent with nearby
params/getters.
##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -682,18 +797,45 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr.toImmutableArraySeq)
val numClasses = 2
- val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
- numClasses = numClasses, maxBins = 32)
-
- val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
- seed = 42, instr = None).head
-
- val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
- seed = 42, instr = None, prune = false).head
+ val strategy = new OldStrategy(
+ algo = OldAlgo.Classification,
+ impurity = Gini,
+ maxDepth = 4,
+ numClasses = numClasses,
+ maxBins = 32)
+
+ strategy.pruneTree = true
+ val prunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None)
+ .head
+
+ strategy.pruneTree = false
+ val unprunedTree = RandomForest
+ .run(
+ rdd,
+ strategy,
+ numTrees = 1,
+ featureSubsetStrategy = "auto",
+ seed = 42,
+ instr = None)
+ .head
+
+ strategy.pruneTree = true
+ val defaultBehaviorTree = RandomForest
+ .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None)
+ .head
assert(prunedTree.numNodes === 5)
assert(unprunedTree.numNodes === 7)
+ assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)
Review Comment:
This new assertion uses `==` while the surrounding assertions in this suite use
ScalaTest’s `===`. Please use `===` here for consistency with the rest of the
file.
##########
python/pyspark/ml/classification.py:
##########
@@ -2163,6 +2175,12 @@ def setMaxBins(self, value: int) -> "RandomForestClassifier":
"""
return self._set(maxBins=value)
+ def setPruneTree(self, value: bool) -> "RandomForestClassifier":
+ """
+ Sets the value of :py:attr:`pruneTree`.
+ """
+ return self._set(pruneTree=value)
+
Review Comment:
This new public setter should be annotated with `@since("4.3.0")` (matching
the new param/getter) so the Python API is correctly versioned in the generated
docs.
##########
python/pyspark/ml/tree.py:
##########
@@ -415,6 +415,13 @@ class _TreeClassifierParams(Params):
typeConverter=TypeConverters.toString,
)
+ pruneTree = Param(Params._dummy(), "pruneTree", "" +
+ "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+ " with the same class predictions are merged. The resulting tree will be smaller and have" +
+ " faster predictions, but class probabilities will be lost." +
+ " If false, no pruning is applied after training, and class probabilities are preserved.",
Review Comment:
The pruneTree param description says “class probabilities will be lost,” but
pruning in Spark merges sibling leaves with the same predicted class and still
retains an impurity calculator (so probabilities remain available, though less
fine-grained). Please reword to reflect that pruning may change/coarsen
probability estimates rather than removing probabilities entirely.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]