This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 27ed89b7be5 [SPARK-38775][ML] cleanup validation functions
27ed89b7be5 is described below

commit 27ed89b7be5ebb91e4a0b106b1669a7867a6012d
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Jun 18 21:51:50 2022 -0700

    [SPARK-38775][ML] cleanup validation functions

    ### What changes were proposed in this pull request?
    1. Remove the unused `extractInstances` and `extractLabeledPoints` in `Predictor`.
    2. Remove the unused `checkNonNegativeWeight` in `functions`.
    3. Move `getNumClasses` from `Classifier` to `DatasetUtils`.
    4. Move `getNumFeatures` from `MetadataUtils` to `DatasetUtils`.

    ### Why are the changes needed?
    To unify these validation helpers in one place (`DatasetUtils`).

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing test suites.

    Closes #36049 from zhengruifeng/validate_cleanup.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
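
For orientation before the patch itself, here is a minimal, self-contained sketch of how the consolidated `DatasetUtils` helpers are invoked after this change. This is an editor's illustration, not part of the commit: the `DatasetUtilsSketch` object and the toy two-row DataFrame are hypothetical, and the file must live under the `org.apache.spark.ml` package tree because all three helpers are `private[ml]`.

```scala
package org.apache.spark.ml.util

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Hypothetical driver exercising the unified validation helpers.
object DatasetUtilsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Toy data; "label"/"features" are the default Predictor column names.
    val df = Seq(
      (0.0, Vectors.dense(0.1, 0.2)),
      (1.0, Vectors.dense(0.3, 0.4))
    ).toDF("label", "features")

    // Moved here from Classifier: numClasses from label metadata if present,
    // otherwise max(label) + 1, capped by maxNumClasses (default 100).
    val numClasses = DatasetUtils.getNumClasses(df, "label")

    // Moved here from MetadataUtils: vector size from column metadata if
    // present, otherwise taken from the first row of the dataset.
    val numFeatures = DatasetUtils.getNumFeatures(df, "features")

    // Unified extraction: because the params object is a ClassifierParams,
    // labels are checked against [0, numClasses); weights and features are
    // validated in the same pass.
    val instances = DatasetUtils.extractInstances(new DecisionTreeClassifier(), df, Some(numClasses))

    println(s"numClasses=$numClasses, numFeatures=$numFeatures, rows=${instances.count()}")
    spark.stop()
  }
}
```

The same `extractInstances(this, dataset, Some(2))` shape is what `GBTClassifier` uses in the patch below, while regressors such as `GBTRegressor` omit the `numClasses` argument and get regression-label checking instead.
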
---
 .../spark/examples/ml/DeveloperApiExample.scala    |   7 +-
 .../main/scala/org/apache/spark/ml/Predictor.scala |  51 +---
 .../spark/ml/classification/Classifier.scala       | 106 +--------------------
 .../ml/classification/DecisionTreeClassifier.scala |   3 +-
 .../spark/ml/classification/FMClassifier.scala     |   2 +-
 .../spark/ml/classification/GBTClassifier.scala    |  20 +---
 .../ml/classification/RandomForestClassifier.scala |   2 +-
 .../spark/ml/clustering/GaussianMixture.scala      |   2 +-
 .../evaluation/BinaryClassificationEvaluator.scala |   7 +-
 .../spark/ml/evaluation/ClusteringEvaluator.scala  |  21 ++--
 .../spark/ml/evaluation/ClusteringMetrics.scala    |   6 +-
 .../MulticlassClassificationEvaluator.scala        |   8 +-
 .../spark/ml/evaluation/RegressionEvaluator.scala  |  16 ++--
 .../scala/org/apache/spark/ml/feature/LSH.scala    |   2 +-
 .../org/apache/spark/ml/feature/RobustScaler.scala |   2 +-
 .../org/apache/spark/ml/feature/Selector.scala     |   2 +-
 .../ml/feature/UnivariateFeatureSelector.scala     |   2 +-
 .../apache/spark/ml/feature/VectorIndexer.scala    |   2 +-
 .../main/scala/org/apache/spark/ml/functions.scala |   6 --
 .../apache/spark/ml/regression/FMRegressor.scala   |   2 +-
 .../apache/spark/ml/regression/GBTRegressor.scala  |  20 +---
 .../regression/GeneralizedLinearRegression.scala   |   2 +-
 .../spark/ml/regression/LinearRegression.scala     |   2 +-
 .../org/apache/spark/ml/util/DatasetUtils.scala    |  82 +++++++++++++++-
 .../org/apache/spark/ml/util/MetadataUtils.scala   |  14 +--
 .../spark/ml/classification/ClassifierSuite.scala  |  44 +--------
 project/MimaExcludes.scala                         |  16 +++-
 27 files changed, 152 insertions(+), 297 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 487cb27b93f..bfee3301f8e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -24,6 +24,7 @@ import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
+import org.apache.spark.sql.functions.col
 
 /**
  * A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -120,8 +121,10 @@ private class MyLogisticRegression(override val uid: String)
 
   // This method is used by fit()
   override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = {
-    // Extract columns from data using helper method.
-    val oldDataset = extractLabeledPoints(dataset)
+    // Extract columns from data.
+    val oldDataset = dataset.select(col($(labelCol)).cast("double"), col($(featuresCol)))
+      .rdd
+      .map { case Row(l: Double, f: Vector) => LabeledPoint(l, f) }
 
     // Do learning to estimate the coefficients vector.
     val numFeatures = oldDataset.take(1)(0).features.size
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index e0b128e3698..9c6eb880c80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,14 +18,11 @@
 package org.apache.spark.ml
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
-import org.apache.spark.ml.functions.checkNonNegativeWeight
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
@@ -63,40 +60,6 @@ private[ml] trait PredictorParams extends Params
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
-
-  /**
-   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   */
-  protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
-    val w = this match {
-      case p: HasWeightCol =>
-        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
-          checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType)))
-        } else {
-          lit(1.0)
-        }
-    }
-
-    dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-      case Row(label: Double, weight: Double, features: Vector) =>
-        Instance(label, weight, features)
-    }
-  }
-
-  /**
-   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   * Validate the output instances with the given function.
-   */
-  protected def extractInstances(
-      dataset: Dataset[_],
-      validateInstance: Instance => Unit): RDD[Instance] = {
-    extractInstances(dataset).map { instance =>
-      validateInstance(instance)
-      instance
-    }
-  }
 }
 
 /**
@@ -176,16 +139,6 @@ abstract class Predictor[
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, featuresDataType)
   }
-
-  /**
-   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   */
-  protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
-    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
-    }
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 09324e2087d..2d7719a29ca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,17 +17,13 @@
 package org.apache.spark.ml.classification
 
-import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -44,23 +40,6 @@ private[spark] trait ClassifierParams
     val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
     SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
   }
-
-  /**
-   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
-   */
-  protected def extractInstances(
-      dataset: Dataset[_],
-      numClasses: Int): RDD[Instance] = {
-    val validateInstance = (instance: Instance) => {
-      val label = instance.label
-      require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
-        s" dataset with invalid label $label. Labels must be integers in range" +
-        s" [0, $numClasses).")
-    }
-    extractInstances(dataset, validateInstance)
-  }
 }
 
 /**
@@ -81,89 +60,6 @@ abstract class Classifier[
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
 
   // TODO: defaultEvaluator (follow-up PR)
-
-  /**
-   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   *
-   * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
-   *                and features (`Vector`).
-   * @param numClasses Number of classes label can take. Labels must be integers in the range
-   *                   [0, numClasses).
-   * @note Throws `SparkException` if any label is a non-integer or is negative
-   */
-  protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
-    validateNumClasses(numClasses)
-    dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) =>
-        validateLabel(label, numClasses)
-        LabeledPoint(label, features)
-    }
-  }
-
-  /**
-   * Validates that number of classes is greater than zero.
-   *
-   * @param numClasses Number of classes label can take.
-   */
-  protected def validateNumClasses(numClasses: Int): Unit = {
-    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
-      s" $numClasses, but requires numClasses > 0.")
-  }
-
-  /**
-   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
-   *
-   * @param label The label to validate.
-   * @param numClasses Number of classes label can take. Labels must be integers in the range
-   *                   [0, numClasses).
-   */
-  protected def validateLabel(label: Double, numClasses: Int): Unit = {
-    require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
-      s" dataset with invalid label $label. Labels must be integers in range" +
-      s" [0, $numClasses).")
-  }
-
-  /**
-   * Get the number of classes. This looks in column metadata first, and if that is missing,
-   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
-   * by finding the maximum label value.
-   *
-   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
-   * such as in `extractLabeledPoints()`.
-   *
-   * @param dataset Dataset which contains a column [[labelCol]]
-   * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
-   *                      is specified in the metadata, then maxNumClasses is ignored.
-   * @return number of classes
-   * @throws IllegalArgumentException if metadata does not specify numClasses, and the
-   *                                  actual numClasses exceeds maxNumClasses
-   */
-  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
-    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) => n
-      case None =>
-        // Get number of classes from dataset itself.
-        val maxLabelRow: Array[Row] = dataset
-          .select(max(checkClassificationLabels($(labelCol), Some(maxNumClasses))))
-          .take(1)
-        if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
-          throw new SparkException("ML algorithm was given empty dataset.")
-        }
-        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
-        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
-          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
-        val numClasses = maxDoubleLabel.toInt + 1
-        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
-          s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
-          s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
-          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
-          s" StringIndexer to the label column.")
-        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
-          s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
-        numClasses
-    }
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index ec9e779709d..688d2d18f48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -117,14 +117,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses = getNumClasses(dataset)
+    val numClasses = getNumClasses(dataset, $(labelCol))
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".train() called with non-matching numClasses and thresholds.length." +
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
-    validateNumClasses(numClasses)
 
     val instances = dataset.select(
       checkClassificationLabels($(labelCol), Some(numClasses)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index a2e6f0c49ee..51f312cf183 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -190,7 +190,7 @@ class FMClassifier @Since("3.0.0") (
       miniBatchFraction, initStd, maxIter, stepSize, tol, solver, thresholds)
     instr.logNumClasses(numClasses)
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index a767bc01445..3910beda3d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -22,14 +22,13 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
+import org.apache.spark.ml.util.DatasetUtils.extractInstances
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -169,21 +168,12 @@ class GBTClassifier @Since("1.4.0") (
 
   override protected def train(
       dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
-
-    def extractInstances(df: Dataset[_]) = {
-      df.select(
-        checkClassificationLabels($(labelCol), Some(2)),
-        checkNonNegativeWeights(get(weightCol)),
-        checkNonNanVectors($(featuresCol))
-      ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
-    }
-
     val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
 
     val (trainDataset, validationDataset) = if (withValidation) {
-      (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
-        extractInstances(dataset.filter(col($(validationIndicatorCol)))))
+      (extractInstances(this, dataset.filter(not(col($(validationIndicatorCol)))), Some(2)),
+        extractInstances(this, dataset.filter(col($(validationIndicatorCol))), Some(2)))
     } else {
-      (extractInstances(dataset), null)
+      (extractInstances(this, dataset, Some(2)), null)
    }
 
     val numClasses = 2
@@ -390,7 +380,7 @@ class GBTClassificationModel private[ml](
    */
   @Since("2.4.0")
   def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
-    val data = extractInstances(dataset)
+    val data = extractInstances(this, dataset, Some(2))
     GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
       loss, OldAlgo.Classification)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 25f4e103ac7..048e5949e1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
     instr.logDataset(dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = getNumClasses(dataset)
+    val numClasses = getNumClasses(dataset, $(labelCol))
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index bc2fcc03768..03315554b81 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -381,7 +381,7 @@ class GaussianMixture @Since("2.0.0") (
     val spark = dataset.sparkSession
     import spark.implicits._
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
       s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
       s" matrix is quadratic in the number of features.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 93b66f3ab70..1a97eb29100 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,11 +18,10 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -129,8 +128,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
     dataset.select(
       col($(rawPredictionCol)),
       col($(labelCol)).cast(DoubleType),
-      if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
-      else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
+      DatasetUtils.checkNonNegativeWeights(get(weightCol))
+    ).rdd.map {
       case Row(rawPrediction: Vector, label: Double, weight: Double) =>
         (rawPrediction(1), label, weight)
       case Row(rawPrediction: Double, label: Double, weight: Double) =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index fa2c25a5912..143e26f2f74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -18,13 +18,11 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
 
 /**
  * Evaluator for clustering results.
@@ -130,18 +128,13 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
       SchemaUtils.checkNumericType(schema, $(weightCol))
     }
 
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
-    val df = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
-      dataset.select(col($(predictionCol)),
-        vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
-        lit(1.0).as(weightColName))
-    } else {
-      dataset.select(col($(predictionCol)),
-        vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
-        checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
-    }
+    val df = dataset.select(
+      col($(predictionCol)),
+      DatasetUtils.columnToVector(dataset, $(featuresCol))
+        .as($(featuresCol), dataset.schema($(featuresCol)).metadata),
+      DatasetUtils.checkNonNegativeWeights(get(weightCol))
+        .as(if (!isDefined(weightCol)) "weightCol" else $(weightCol))
+    )
 
     val metrics = new ClusteringMetrics(df)
     metrics.setDistanceMeasure($(distanceMeasure))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
index ffeb9492777..0106c872297 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.DatasetUtils._
 import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
@@ -293,7 +293,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
       predictionCol: String,
       featuresCol: String,
       weightCol: String): Map[Double, ClusterStats] = {
-    val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
+    val numFeatures = getNumFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
       col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"), col(weightCol))
       .rdd
@@ -509,7 +509,7 @@ private[evaluation] object CosineSilhouette extends Silhouette {
       featuresCol: String,
       predictionCol: String,
       weightCol: String): Map[Double, (Vector, Double)] = {
-    val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
+    val numFeatures = getNumFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
       col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName), col(weightCol))
       .rdd
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index beeefde8c5f..023987d09ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -180,18 +179,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
     SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
-    val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
-    } else {
-      lit(1.0)
-    }
-
     if ($(metricName) == "logLoss") {
       // probabilityCol is only needed to compute logloss
       require(schema.fieldNames.contains($(probabilityCol)),
         "probabilityCol is needed to compute logloss")
     }
 
+    val w = DatasetUtils.checkNonNegativeWeights(get(weightCol))
     val rdd = if (schema.fieldNames.contains($(probabilityCol))) {
       val p = DatasetUtils.columnToVector(dataset, $(probabilityCol))
       dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), w, p)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 902869cc681..9503e9ea11b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -120,12 +119,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
     val predictionAndLabelsWithWeights = dataset
-      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
-        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
-        else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)))
-      .rdd
-      .map { case Row(prediction: Double, label: Double, weight: Double) =>
-        (prediction, label, weight) }
+      .select(
+        col($(predictionCol)).cast(DoubleType),
+        col($(labelCol)).cast(DoubleType),
+        DatasetUtils.checkNonNegativeWeights(get(weightCol))
+      ).rdd.map { case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight)
+      }
 
     new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 7963fc88697..5254762d210 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -346,7 +346,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
 
   override def fit(dataset: Dataset[_]): T = {
     transformSchema(dataset.schema, logging = true)
-    val inputDim = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+    val inputDim = DatasetUtils.getNumFeatures(dataset, $(inputCol))
     val model = createRawLSHModel(inputDim).setParent(this)
     copyValues(model)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index e8f325ec584..85352d6bcbd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -145,7 +145,7 @@ class RobustScaler @Since("3.0.0") (@Since("3.0.0") override val uid: String)
 
   override def fit(dataset: Dataset[_]): RobustScalerModel = {
     transformSchema(dataset.schema, logging = true)
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(inputCol))
     val vectors = dataset.select($(inputCol)).rdd.map {
       case Row(vec: Vector) =>
         require(vec.size == numFeatures,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
index e24593a01b6..1afab326dd7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
@@ -206,7 +206,7 @@ private[ml] abstract class Selector[T <: SelectorModel[T]]
     val spark = dataset.sparkSession
     import spark.implicits._
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
     val resultDF = getSelectionTestResult(dataset.toDF)
 
     def getTopIndices(k: Int): Array[Int] = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 7412c42986f..3b43404072d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -164,7 +164,7 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v
   @Since("3.1.1")
   override def fit(dataset: Dataset[_]): UnivariateFeatureSelectorModel = {
     transformSchema(dataset.schema, logging = true)
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
 
     var threshold = Double.NaN
     if (isSet(selectionThreshold)) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 874b4213872..0e571ad508f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -140,7 +140,7 @@ class VectorIndexer @Since("1.4.0") (
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): VectorIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(inputCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(inputCol))
     val vectorDataset = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
     val maxCats = $(maxCategories)
     val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
index 43622a4f3ed..2bd7233f3ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -85,10 +85,4 @@ object functions {
   def array_to_vector(v: Column): Column = {
     arrayToVectorUdf(v)
   }
-
-  private[ml] def checkNonNegativeWeight = udf {
-    value: Double =>
-      require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
-      value
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index c0178ac6c76..e6e8c2f1fa4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -413,7 +413,7 @@ class FMRegressor @Since("3.0.0") (
     instr.logParams(this, factorSize, fitIntercept, fitLinear, regParam,
       miniBatchFraction, initStd, maxIter, stepSize, tol, solver)
 
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 10a203e9ee6..0c58cc2449b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -22,13 +22,12 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{BLAS, Vector}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DatasetUtils._
+import org.apache.spark.ml.util.DatasetUtils.extractInstances
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -166,21 +165,12 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
   override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
-
-    def extractInstances(df: Dataset[_]) = {
-      df.select(
-        checkRegressionLabels($(labelCol)),
-        checkNonNegativeWeights(get(weightCol)),
-        checkNonNanVectors($(featuresCol))
-      ).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
-    }
-
     val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
 
     val (trainDataset, validationDataset) = if (withValidation) {
-      (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
-        extractInstances(dataset.filter(col($(validationIndicatorCol)))))
+      (extractInstances(this, dataset.filter(not(col($(validationIndicatorCol))))),
+        extractInstances(this, dataset.filter(col($(validationIndicatorCol)))))
     } else {
-      (extractInstances(dataset), null)
+      (extractInstances(this, dataset), null)
     }
 
     instr.logPipelineStage(this)
@@ -349,7 +339,7 @@ class GBTRegressionModel private[ml](
    */
   @Since("2.4.0")
   def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
-    val data = extractInstances(dataset)
+    val data = extractInstances(this, dataset)
     GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
       convertToOldLossType(loss), OldAlgo.Regression)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 88581d03084..6d8507239eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -384,7 +384,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
     instr.logParams(this, labelCol, featuresCol, weightCol, offsetCol, predictionCol,
       linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol,
       aggregationDepth)
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index a53ef8c79b4..46986249e0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -338,7 +338,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     }
 
     // Extract the number of features before deciding optimization solver.
-    val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
+    val numFeatures = getNumFeatures(dataset, $(featuresCol))
     instr.logNumFeatures(numFeatures)
 
     val instances = dataset.select(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index c32e901e5cd..130790ac909 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -17,7 +17,13 @@
 
 package org.apache.spark.ml.util
 
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.classification.ClassifierParams
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
@@ -25,7 +31,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
 
-private[spark] object DatasetUtils {
+private[spark] object DatasetUtils extends Logging {
 
   private[ml] def checkNonNanValues(colName: String, displayed: String): Column = {
     val casted = col(colName).cast(DoubleType)
@@ -96,6 +102,26 @@ private[spark] object DatasetUtils {
     }
   }
 
+  private[ml] def extractInstances(
+      p: PredictorParams,
+      df: Dataset[_],
+      numClasses: Option[Int] = None): RDD[Instance] = {
+    val labelCol = p match {
+      case c: ClassifierParams =>
+        checkClassificationLabels(c.getLabelCol, numClasses)
+      case _ => // TODO: there is no RegressorParams, maybe add it in the future?
+        checkRegressionLabels(p.getLabelCol)
+    }
+
+    val weightCol = p match {
+      case w: HasWeightCol => checkNonNegativeWeights(w.get(w.weightCol))
+      case _ => lit(1.0)
+    }
+
+    df.select(labelCol, weightCol, checkNonNanVectors(p.getFeaturesCol))
+      .rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
+  }
+
   /**
    * Cast a column in a Dataset to Vector type.
    *
@@ -138,4 +164,58 @@ private[spark] object DatasetUtils {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
   }
+
+  /**
+   * Get the number of classes. This looks in column metadata first, and if that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
+   * by finding the maximum label value.
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
+   * such as in `extractLabeledPoints()`.
+   *
+   * @param dataset Dataset which contains a column [[labelCol]]
+   * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
+   *                      is specified in the metadata, then maxNumClasses is ignored.
+   * @return number of classes
+   * @throws IllegalArgumentException if metadata does not specify numClasses, and the
+   *                                  actual numClasses exceeds maxNumClasses
+   */
+  private[ml] def getNumClasses(
+      dataset: Dataset[_],
+      labelCol: String,
+      maxNumClasses: Int = 100): Int = {
+    MetadataUtils.getNumClasses(dataset.schema(labelCol)) match {
+      case Some(n: Int) => n
+      case None =>
+        // Get number of classes from dataset itself.
+        val maxLabelRow: Array[Row] = dataset
+          .select(max(checkClassificationLabels(labelCol, Some(maxNumClasses))))
+          .take(1)
+        if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
+          throw new SparkException("ML algorithm was given empty dataset.")
+        }
+        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
+          s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
+        val numClasses = maxDoubleLabel.toInt + 1
+        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
+          s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
+          s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
+          s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
+          s" StringIndexer to the label column.")
+        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
+          s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
+        numClasses
+    }
+  }
+
+  /**
+   * Obtain the number of features in a vector column.
+   * If no metadata is available, extract it from the dataset.
+   */
+  private[ml] def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int = {
+    MetadataUtils.getNumFeatures(dataset.schema(vectorCol)).getOrElse {
+      dataset.select(columnToVector(dataset, vectorCol)).head.getAs[Vector](0).size
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index 6db0408e8d2..631261af249 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.util
 import scala.collection.immutable.HashMap
 
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.Dataset
+import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.sql.types.StructField
 
 
@@ -42,17 +41,6 @@ private[spark] object MetadataUtils {
     }
   }
 
-  /**
-   * Obtain the number of features in a vector column.
-   * If no metadata is available, extract it from the dataset.
-   */
-  def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int = {
-    getNumFeatures(dataset.schema(vectorCol)).getOrElse {
-      dataset.select(DatasetUtils.columnToVector(dataset, vectorCol))
-        .head.getAs[Vector](0).size
-    }
-  }
-
   /**
    * Examine a schema to identify the number of features in a vector column.
    * Returns None if the number of features is not specified.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index 1aea4b47cd8..57cd99ecced 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -22,9 +22,8 @@ import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -35,41 +34,6 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) }.toDF()
   }
 
-  test("extractLabeledPoints") {
-    val c = new MockClassifier
-    // Valid dataset
-    val df0 = getTestData(Seq(0.0, 2.0, 1.0, 5.0))
-    c.extractLabeledPoints(df0, 6).count()
-    // Invalid datasets
-    val df1 = getTestData(Seq(0.0, -2.0, 1.0, 5.0))
-    withClue("Classifier should fail if label is negative") {
-      val e: SparkException = intercept[SparkException] {
-        c.extractLabeledPoints(df1, 6).count()
-      }
-      assert(e.getMessage.contains("given dataset with invalid label"))
-    }
-    val df2 = getTestData(Seq(0.0, 2.1, 1.0, 5.0))
-    withClue("Classifier should fail if label is not an integer") {
-      val e: SparkException = intercept[SparkException] {
-        c.extractLabeledPoints(df2, 6).count()
-      }
-      assert(e.getMessage.contains("given dataset with invalid label"))
-    }
-    // extractLabeledPoints with numClasses specified
-    withClue("Classifier should fail if label is >= numClasses") {
-      val e: SparkException = intercept[SparkException] {
-        c.extractLabeledPoints(df0, numClasses = 5).count()
-      }
-      assert(e.getMessage.contains("given dataset with invalid label"))
-    }
-    withClue("Classifier.extractLabeledPoints should fail if numClasses <= 0") {
-      val e: IllegalArgumentException = intercept[IllegalArgumentException] {
-        c.extractLabeledPoints(df0, numClasses = 0).count()
-      }
-      assert(e.getMessage.contains("but requires numClasses > 0"))
-    }
-  }
-
   test("getNumClasses") {
     val c = new MockClassifier
     // Valid dataset
@@ -122,10 +86,8 @@ object ClassifierSuite {
     override def train(dataset: Dataset[_]): MockClassificationModel =
       throw new UnsupportedOperationException()
 
-    // Make methods public
-    override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
-      super.extractLabeledPoints(dataset, numClasses)
-    def getNumClasses(dataset: Dataset[_]): Int = super.getNumClasses(dataset)
+    def getNumClasses(dataset: Dataset[_]): Int =
+      DatasetUtils.getNumClasses(dataset, $(labelCol))
   }
 
   class MockClassificationModel(override val uid: String)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 927384d4f1e..01fc5d65c03 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,7 +40,21 @@
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.recommendation.ALSModel.checkedCast"),
 
     // [SPARK-39110] Show metrics properties in HistoryServer environment tab
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"),
+
+    // [SPARK-38775][ML] Cleanup validation functions
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PredictionModel.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.Predictor.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.ClassificationModel.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.extractLabeledPoints"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateNumClasses"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.validateLabel"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.Classifier.getNumClasses$default$2"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRest.extractInstances"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.extractInstances")
   )
 
   // Exclude rules for 3.3.x from 3.2.0
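
One practical note for downstream developers, since the removed helpers were `protected` members of the public `Predictor`/`Classifier` abstractions (hence the MiMa excludes above): a subclass that relied on `extractLabeledPoints` can inline the extraction, mirroring the `DeveloperApiExample` change at the top of this patch. A hedged sketch follows; the `toLabeledPoints` name is illustrative only.

```scala
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.col

// Inline replacement for the removed Predictor.extractLabeledPoints:
// cast the label to double, keep the features vector, map to LabeledPoint.
def toLabeledPoints(dataset: Dataset[_], labelCol: String, featuresCol: String): RDD[LabeledPoint] =
  dataset.select(col(labelCol).cast("double"), col(featuresCol))
    .rdd
    .map { case Row(l: Double, f: Vector) => LabeledPoint(l, f) }
```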