This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new c0822512655 [SPARK-42526][ML] Add Classifier.getNumClasses back
c0822512655 is described below

commit c082251265584c442140b152701a58f571048be7
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Feb 22 19:02:01 2023 +0800

    [SPARK-42526][ML] Add Classifier.getNumClasses back
    
    ### What changes were proposed in this pull request?
    Add Classifier.getNumClasses back
    
    ### Why are the changes needed?
    Some well-known libraries, such as `xgboost`, depend on this method even though it
    is not a public API, so restoring it keeps integrations like xgboost working.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Updated the MiMa excludes: removed the `getNumClasses` exclusions now that the method is restored.
    
    Closes #40119 from zhengruifeng/ml_add_classifier_get_num_class.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit a6098beade01eac5cf92727e69b3537fcac31b2d)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../apache/spark/ml/classification/Classifier.scala   | 19 +++++++++++++++++++
 .../ml/classification/DecisionTreeClassifier.scala    |  2 +-
 .../ml/classification/RandomForestClassifier.scala    |  2 +-
 project/MimaExcludes.scala                            |  2 --
 4 files changed, 21 insertions(+), 4 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 2d7719a29ca..c46be175cb2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -56,6 +56,25 @@ abstract class Classifier[
     M <: ClassificationModel[FeaturesType, M]]
   extends Predictor[FeaturesType, E, M] with ClassifierParams {
 
+  /**
+   * Get the number of classes.  This looks in column metadata first, and if 
that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes 
numClasses
+   * by finding the maximum label value.
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be 
handled elsewhere,
+   * such as in `extractLabeledPoints()`.
+   *
+   * @param dataset       Dataset which contains a column [[labelCol]]
+   * @param maxNumClasses Maximum number of classes allowed when inferred from 
data.  If numClasses
+   *                      is specified in the metadata, then maxNumClasses is 
ignored.
+   * @return number of classes
+   * @throws IllegalArgumentException if metadata does not specify numClasses, 
and the
+   *                                  actual numClasses exceeds maxNumClasses
+   */
+  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): 
Int = {
+    DatasetUtils.getNumClasses(dataset, $(labelCol), maxNumClasses)
+  }
+
   /** @group setParam */
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, 
value).asInstanceOf[E]
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 688d2d18f48..7deefda2eea 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -117,7 +117,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     val categoricalFeatures = 
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses = getNumClasses(dataset, $(labelCol))
+    val numClasses = getNumClasses(dataset)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 048e5949e1c..9295425f9d6 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
     instr.logDataset(dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses = getNumClasses(dataset, $(labelCol))
+    val numClasses = getNumClasses(dataset)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3bb8deb2561..0b0fdefd6b6 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -51,8 +51,6 @@ object MimaExcludes {
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"),
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"),
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"),
-    
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"),
-    
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"),
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"),
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances"),
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to