Github user jkbradley commented on a diff in the pull request: https://github.com/apache/spark/pull/12663#discussion_r60971502 --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala --- @@ -62,6 +65,76 @@ abstract class Classifier[ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] // TODO: defaultEvaluator (follow-up PR) + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + * @throws SparkException if any label is not an integer >= 0 + */ + override protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { + dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => + require(label % 1 == 0 && label >= 0, s"Classifier was given dataset with invalid label" + + s" $label. Labels must be integers in range [0, 1, ..., numClasses-1]") + LabeledPoint(label, features) + } + } + + /** + * Get the number of classes. This looks in column metadata first, and if that is missing, + * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses + * by finding the maximum label value. + * + * Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere, + * such as in [[extractLabeledPoints()]]. + * + * @param dataset Dataset which contains a column [[labelCol]] + * @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses + * is specified in the metadata, then maxNumClasses is ignored. 
+ * @return number of classes + * @throws IllegalArgumentException if metadata does not specify numClasses, and the + * actual numClasses exceeds maxNumClasses + */ + protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 1000): Int = { + MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { + case Some(n: Int) => n + case None => --- End diff -- Logging a warning seems reasonable to me. We could also decrease maxNumClasses to force users to do indexing in iffier situations.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it to be, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org