Repository: spark
Updated Branches:
  refs/heads/master 3434572b1 -> f54ff19b1


[SPARK-11349][ML] Support transform string label for RFormula

Currently ```RFormula``` can only handle a label of ```NumericType``` or
```BinaryType``` (cast to ```DoubleType``` and used as the label for Linear
Regression training). We should also support a label of ```StringType```, which
is needed for Logistic Regression (glm with family = "binomial").
For a label of ```StringType```, we use ```StringIndexer``` to transform it
into a 0-based index.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #9302 from yanboliang/spark-11349.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f54ff19b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f54ff19b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f54ff19b

Branch: refs/heads/master
Commit: f54ff19b1edd4903950cb334987a447445fa97ef
Parents: 3434572
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Tue Nov 3 08:32:37 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Nov 3 08:32:37 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/RFormula.scala   | 10 +++++++++-
 .../apache/spark/ml/feature/RFormulaSuite.scala  | 19 +++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f54ff19b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index f9b8400..5c43a41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,14 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
       .setOutputCol($(featuresCol))
     encoderStages += new VectorAttributeRewriter($(featuresCol), 
prefixesToRewrite.toMap)
     encoderStages += new ColumnPruner(tempColumns.toSet)
+
+    if (dataset.schema.fieldNames.contains(resolvedFormula.label) &&
+      dataset.schema(resolvedFormula.label).dataType == StringType) {
+      encoderStages += new StringIndexer()
+        .setInputCol(resolvedFormula.label)
+        .setOutputCol($(labelCol))
+    }
+
     val pipelineModel = new 
Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
     copyValues(new RFormulaModel(uid, resolvedFormula, 
pipelineModel).setParent(this))
   }
@@ -172,7 +180,7 @@ class RFormulaModel private[feature](
   override def transformSchema(schema: StructType): StructType = {
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
-    if (hasLabelCol(schema)) {
+    if (hasLabelCol(withFeatures)) {
       withFeatures
     } else if (schema.exists(_.name == resolvedFormula.label)) {
       val nullable = schema(resolvedFormula.label).dataType match {

http://git-wip-us.apache.org/repos/asf/spark/blob/f54ff19b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index b560130..dc20a5e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -107,6 +107,25 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(result.collect() === expected.collect())
   }
 
+  test("index string label") {
+    val formula = new RFormula().setFormula("id ~ a + b")
+    val original = sqlContext.createDataFrame(
+      Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), 
("male", "baz", 5))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val resultSchema = model.transformSchema(original.schema)
+    val expected = sqlContext.createDataFrame(
+      Seq(
+        ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+        ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
+        ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
+        ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0))
+    ).toDF("id", "a", "b", "features", "label")
+    // assert(result.schema.toString == resultSchema.toString)
+    assert(result.collect() === expected.collect())
+  }
+
   test("attribute generation") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original = sqlContext.createDataFrame(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to