This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a6098beade0 [SPARK-42526][ML] Add Classifier.getNumClasses back a6098beade0 is described below commit a6098beade01eac5cf92727e69b3537fcac31b2d Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Wed Feb 22 19:02:01 2023 +0800 [SPARK-42526][ML] Add Classifier.getNumClasses back ### What changes were proposed in this pull request? Add Classifier.getNumClasses back ### Why are the changes needed? Some popular libraries, such as `xgboost`, depend on this method even though it is not a public API, so adding it back improves compatibility with integrations like xgboost. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? update mima Closes #40119 from zhengruifeng/ml_add_classifier_get_num_class. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../apache/spark/ml/classification/Classifier.scala | 19 +++++++++++++++++++ .../ml/classification/DecisionTreeClassifier.scala | 2 +- .../ml/classification/RandomForestClassifier.scala | 2 +- project/MimaExcludes.scala | 2 -- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 2d7719a29ca..c46be175cb2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -56,6 +56,25 @@ abstract class Classifier[ M <: ClassificationModel[FeaturesType, M]] extends Predictor[FeaturesType, E, M] with ClassifierParams { + /** + * Get the number of classes. This looks in column metadata first, and if that is missing, + * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses + * by finding the maximum label value.
+ * + * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere, + * such as in `extractLabeledPoints()`. + * + * @param dataset Dataset which contains a column [[labelCol]] + * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses + * is specified in the metadata, then maxNumClasses is ignored. + * @return number of classes + * @throws IllegalArgumentException if metadata does not specify numClasses, and the + * actual numClasses exceeds maxNumClasses + */ + protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = { + DatasetUtils.getNumClasses(dataset, $(labelCol), maxNumClasses) + } + /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 688d2d18f48..7deefda2eea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -117,7 +117,7 @@ class DecisionTreeClassifier @Since("1.4.0") ( instr.logPipelineStage(this) instr.logDataset(dataset) val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val numClasses = getNumClasses(dataset, $(labelCol)) + val numClasses = getNumClasses(dataset) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 048e5949e1c..9295425f9d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") ( instr.logDataset(dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val numClasses = getNumClasses(dataset, $(labelCol)) + val numClasses = getNumClasses(dataset) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 70a7c29b8dc..9741e53452a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,8 +55,6 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances"), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org