zhengruifeng commented on code in PR #36049:
URL: https://github.com/apache/spark/pull/36049#discussion_r841303261


##########
mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala:
##########
@@ -138,4 +140,61 @@ private[spark] object DatasetUtils {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
   }
+
+  /**
+   * Get the number of classes.  This looks in column metadata first, and if 
that is missing,
+   * then this assumes classes are indexed 0,1,...,numClasses-1 and computes 
numClasses
+   * by finding the maximum label value.
+   *
+   * Label validation (ensuring all labels are integers >= 0) needs to be 
handled elsewhere,
+   * such as in `extractLabeledPoints()`.
+   *
+   * @param dataset  Dataset which contains a column [[labelCol]]
+   * @param maxNumClasses  Maximum number of classes allowed when inferred 
from data.  If numClasses
+   *                       is specified in the metadata, then maxNumClasses is 
ignored.
+   * @return  number of classes
+   * @throws IllegalArgumentException  if metadata does not specify 
numClasses, and the
+   *                                   actual numClasses exceeds maxNumClasses
+   */
+  private[ml] def getNumClasses(
+      dataset: Dataset[_],
+      labelCol: String,
+      maxNumClasses: Int = 100): Int = {
+    MetadataUtils.getNumClasses(dataset.schema(labelCol)) match {
+      case Some(n: Int) => n
+      case None =>
+        // Get number of classes from dataset itself.
+        val maxLabelRow: Array[Row] = dataset
+          .select(max(checkClassificationLabels(labelCol, 
Some(maxNumClasses))))
+          .take(1)
+        if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
+          throw new SparkException("ML algorithm was given empty dataset.")
+        }
+        val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
+        require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label 
value =" +
+          s" $maxDoubleLabel but requires integers in range [0, ... 
${Int.MaxValue})")
+        val numClasses = maxDoubleLabel.toInt + 1
+        require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses 
from label values" +
+          s" in column $labelCol, but this exceeded the max numClasses 
($maxNumClasses) allowed" +
+          s" to be inferred from values.  To avoid this error for labels with 
> $maxNumClasses" +
+          s" classes, specify numClasses explicitly in the metadata; this can 
be done by applying" +
+          s" StringIndexer to the label column.")
+        logInfo(this.getClass.getCanonicalName + s" inferred $numClasses 
classes for" +
+          s" labelCol=$labelCol since numClasses was not specified in the 
column metadata.")
+        numClasses
+    }
+  }
+
+  /**
+   * Obtain the number of features in a vector column.
+   * If no metadata is available, extract it from the dataset.
+   */
+  private[ml] def getNumFeatures(dataset: Dataset[_], vectorCol: String): Int 
= {
+    MetadataUtils.getNumFeatures(dataset.schema(vectorCol)) match {

Review Comment:
   ok, will switch back to getOrElse



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe e-mail: dev-unsubscr...@spark.apache.org

Reply via email to