Christian Reiniger created SPARK-20081:
------------------------------------------

             Summary: RandomForestClassifier doesn't seem to support more than 100 labels
                 Key: SPARK-20081
                 URL: https://issues.apache.org/jira/browse/SPARK-20081
             Project: Spark
          Issue Type: Bug
          Components: MLlib
    Affects Versions: 2.1.0
         Environment: Java
            Reporter: Christian Reiniger


When feeding data with more than 100 labels into RandomForestClassifier#fit() 
(called from Java code), I get the following error message:

{code}
Classifier inferred 143 from label values in column rfc_df0e968db9df__labelCol,
but this exceeded the max numClasses (100) allowed to be inferred from values.
To avoid this error for labels with > 100 classes, specify numClasses explicitly
in the metadata; this can be done by applying StringIndexer to the label column.
{code}
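
For reference, a minimal sketch of the kind of call that triggers this (Scala for brevity; the same happens when calling fit() from Java). The local SparkSession setup and column names are just for illustration:

{code:language=scala}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object RfcRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rfc-repro").getOrCreate()
    import spark.implicits._

    // 150 distinct label values, one row each: more than the default limit of 100
    val df = (0 until 150)
      .map(i => (i.toDouble, Vectors.dense(i.toDouble)))
      .toDF("label", "features")

    val rfc = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")

    rfc.fit(df)  // fails: "Classifier inferred 150 from label values ... exceeded the max numClasses (100)"

    spark.stop()
  }
}
{code}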

Setting "numClasses" in the metadata for the label column doesn't make a 
difference. Looking at the code, this is not surprising, since 
MetadataUtils.getNumClasses() ignores this setting:

{code:language=scala}
  def getNumClasses(labelSchema: StructField): Option[Int] = {
    Attribute.fromStructField(labelSchema) match {
      case binAttr: BinaryAttribute => Some(2)
      case nomAttr: NominalAttribute => nomAttr.getNumValues
      case _: NumericAttribute | UnresolvedAttribute => None
    }
  }
{code}
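
As far as I can tell, the only thing that method honours is ML attribute metadata (BinaryAttribute / NominalAttribute) on the label column; a plain "numClasses" key in the metadata is never read. For illustration, attaching NominalAttribute metadata would look roughly like this (the column name "label" and the count 143 are placeholders, and I haven't verified this end to end):

{code:language=scala}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.DataFrame

// Attach NominalAttribute metadata carrying the class count to the label column.
// ("label" and 143 are illustrative values.)
def withLabelMetadata(df: DataFrame): DataFrame = {
  val labelMeta = NominalAttribute.defaultAttr
    .withName("label")
    .withNumValues(143)
    .toMetadata()
  df.withColumn("label", df("label").as("label", labelMeta))
}
// With this in place, getNumClasses() should hit the NominalAttribute case and
// return Some(143) instead of falling back to value-based inference.
{code}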

The alternative would be to pass an explicit "maxNumClasses" parameter to the 
classifier, so that Classifier#getNumClasses() allows a larger number of 
auto-detected labels. However, RandomForestClassifier#train() calls 
getNumClasses() without the "maxNumClasses" parameter, so the default of 100 
is used:

{code:language=scala}
  override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    val numClasses: Int = getNumClasses(dataset)
    // ...
{code}
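
If I read Classifier#getNumClasses() correctly, the value-based path derives the class count from the maximum label value and then checks it against that default cap. A small sketch of the same check done on the caller side, just to illustrate the reasoning (assumes a DoubleType "label" column; names are illustrative):

{code:language=scala}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.max

// Roughly mirrors the value-based inference: take the maximum label value and
// derive numClasses = maxLabel + 1, which is what gets compared against the
// default maxNumClasses of 100. ("label" is an illustrative column name.)
def inferredNumClasses(ds: Dataset[_]): Int = {
  val maxLabel = ds.select(max("label")).head().getDouble(0)
  maxLabel.toInt + 1
}
{code}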

My Scala skills are pretty sketchy, so please correct me if I misinterpreted 
something. But as it stands, there seems to be no way to learn from data with 
more than 100 labels via RandomForestClassifier.


