GitHub user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/12663#discussion_r60955574
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala ---
    @@ -62,6 +65,76 @@ abstract class Classifier[
       def setRawPredictionCol(value: String): E = set(rawPredictionCol, 
value).asInstanceOf[E]
     
       // TODO: defaultEvaluator (follow-up PR)
    +
    +  /**
    +   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
    +   * and put it in an RDD with strong types.
    +   * @throws SparkException  if any label is not an integer >= 0
    +   */
    +  override protected def extractLabeledPoints(dataset: Dataset[_]): 
RDD[LabeledPoint] = {
    +    dataset.select(col($(labelCol)).cast(DoubleType), 
col($(featuresCol))).rdd.map {
    +      case Row(label: Double, features: Vector) =>
    +        require(label % 1 == 0 && label >= 0, s"Classifier was given 
dataset with invalid label" +
    +          s" $label.  Labels must be integers in range [0, 1, ..., 
numClasses-1]")
    +        LabeledPoint(label, features)
    +    }
    +  }
    +
    +  /**
    +   * Get the number of classes.  This looks in column metadata first, and 
if that is missing,
    +   * then this assumes classes are indexed 0,1,...,numClasses-1 and 
computes numClasses
    +   * by finding the maximum label value.
    +   *
    +   * Label validation (ensuring all labels are integers >= 0) needs to be 
handled elsewhere,
    +   * such as in [[extractLabeledPoints()]].
    +   *
    +   * @param dataset  Dataset which contains a column [[labelCol]]
    +   * @param maxNumClasses  Maximum number of classes allowed when inferred 
from data.  If numClasses
    +   *                       is specified in the metadata, then 
maxNumClasses is ignored.
    +   * @return  number of classes
    +   * @throws IllegalArgumentException  if metadata does not specify 
numClasses, and the
    +   *                                   actual numClasses exceeds 
maxNumClasses
    +   */
    +  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 
1000): Int = {
    +    MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
    +      case Some(n: Int) => n
    +      case None =>
    +        // Get number of classes from dataset itself.
    +        val maxLabelRow: Array[Row] = 
dataset.select(max($(labelCol))).take(1)
    +        if (maxLabelRow.isEmpty) {
    +          throw new SparkException("ML algorithm was given empty dataset.")
    +        }
    +        val maxLabel: Int = maxLabelRow.head.getDouble(0).toInt
    +        val numClasses = maxLabel + 1
    +        require(numClasses <= maxNumClasses, s"Classifier inferred 
$numClasses from label values" +
    +          s" in column $labelCol since numClasses were not specified in 
dataset metadata, but" +
    +          s" this exceeded the max numClasses ($maxNumClasses) allowed to 
be inferred from" +
    +          s" values.  To avoid this error, specify numClasses in metadata, 
such as by applying" +
    +          s" StringIndexer to the label column.")
    +        numClasses
    +    }
    +  }
    +}
    +
    +private[ml] object Classifier {
    +  /**
    +   * Extract labelCol and featuresCol from the given dataset,
    +   * and put it in an RDD with strong types.
    +   * @throws SparkException  if any label is not an integer >= 0 and < 
numClasses
    +   */
    +  def extractLabeledPoints(
    --- End diff --
    
    Note that this is `private[ml]` in the companion object rather than `protected` within class Classifier,
since GBTClassifier does not yet extend Classifier.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to