Hi Kristina,

Currently, StringIndexer is a required preprocessing step before training
DecisionTree, RandomForest, and GBT models.
Although other models such as LogisticRegression and NaiveBayes do not
require it, this preprocessing step is still strongly recommended for them;
otherwise it may lead to an incorrect model.
SPARK-7126 <https://issues.apache.org/jira/browse/SPARK-7126> focuses on
indexing labels automatically, so it will no longer be necessary to run
StringIndexer explicitly once this JIRA is resolved.
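Until then, the workaround applied to the code quoted below might look like the following. This is a minimal sketch that assumes Spark 1.5-era APIs (`spark.ml` estimators with `mllib.linalg` vectors, as in the original question). The object name `IndexedRandomForestExample`, the helper `run`, and the column name `indexedLabel` are hypothetical choices for illustration. The key point is that StringIndexer writes the label values into the output column's metadata, and RandomForestClassifier reads the number of classes from that metadata.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Same record shape as in the question; defined at top level so that
// toDF() can derive the schema via reflection.
case class Record(label: Double, features: Vector)

object IndexedRandomForestExample {
  // Hypothetical helper: indexes the label column, trains the forest on the
  // indexed column, and returns the training data with predictions appended.
  def run(sc: SparkContext): DataFrame = {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(
      Record(0.0, Vectors.dense(1.0, 0.0)),
      Record(0.0, Vectors.dense(1.1, 0.0)),
      Record(0.0, Vectors.dense(1.2, 0.0)),
      Record(1.0, Vectors.dense(0.0, 1.2)),
      Record(1.0, Vectors.dense(0.0, 1.3)),
      Record(1.0, Vectors.dense(0.0, 1.7))
    )).toDF()

    // StringIndexer casts the numeric label to a string, maps each distinct
    // value to an index (most frequent first), and stores the label values
    // in the output column's metadata.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")

    // Train on the indexed label instead of the raw "label" column.
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("features")

    // Chain both stages so the same indexing is applied at fit and transform time.
    val model = new Pipeline().setStages(Array(labelIndexer, rf)).fit(df)
    model.transform(df)
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("IndexedRandomForestExample").setMaster("local[*]"))
    try {
      run(sc).select("label", "indexedLabel", "prediction").show()
    } finally {
      sc.stop()
    }
  }
}
```

Note that the `prediction` column holds indexed labels, not the original label values. In Spark 1.5, IndexToString can map them back.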

BR
Yanbo

2015-09-29 22:14 GMT+08:00 Kristina Rogale Plazonic <kpl...@gmail.com>:

> Hi,
>
> I'm trying out ml.classification.RandomForestClassifier() on a simple
> dataframe, and it throws an exception saying that the number of classes
> has not been set in my dataframe. However, I cannot find a function that
> sets the number of classes, or a way to pass it as an argument anywhere.
> In mllib, numClasses
> is a parameter passed when training the model. In ml, there is an ugly hack
> using StringIndexer, but should you really be using the hack?
> LogisticRegression and NaiveBayes in ml work without setting the number of
> classes.
>
> Thanks for any pointers!
> Kristina
>
> My code:
>
> import org.apache.spark.ml.classification.RandomForestClassifier
> import org.apache.spark.mllib.linalg.{Vector, Vectors}
>
> case class Record(label:Double,
> features:org.apache.spark.mllib.linalg.Vector)
>
> val df = sc.parallelize(Seq( Record(0.0, Vectors.dense(1.0, 0.0) ),
>                         Record(0.0, Vectors.dense(1.1, 0.0) ),
>                         Record(0.0, Vectors.dense(1.2, 0.0) ),
>                         Record(1.0, Vectors.dense(0.0, 1.2) ),
>                         Record(1.0, Vectors.dense(0.0, 1.3) ),
>                         Record(1.0, Vectors.dense(0.0, 1.7) ))
>                        ).toDF()
>
> val rf = new RandomForestClassifier()
> val rfmodel = rf.fit(df)
>
> And the error is:
>
> scala> val rfmodel = rf.fit(df)
> java.lang.IllegalArgumentException: RandomForestClassifier was given input
> with invalid label column label, without the number of classes specified.
> See StringIndexer.
> at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:87)
> at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:42)
> at org.apache.spark.ml.Predictor.fit(Predictor.scala:90)
> at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:31)
> at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:36)
>
>
