Repository: spark
Updated Branches:
  refs/heads/master b515768f2 -> 19401a203


[SPARK-15957][ML] RFormula supports forcing to index label

## What changes were proposed in this pull request?
```RFormula``` currently indexes the label only when it is of string type. If 
the label is of numeric type and we use ```RFormula``` for a classification 
model, there are no label attributes in the label column metadata. Label 
attributes are useful when making predictions for classification, so we can 
force the label to be indexed by ```StringIndexer```, whether it is of numeric 
or string type, for classification. SparkR wrappers can then extract the label 
attributes from the label column metadata successfully. This feature can help 
us fix bugs similar to 
[SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153).
For regression, we will still keep the label as numeric type.
In this PR, we add a param ```forceIndexLabel``` to control whether to force 
indexing of the label for ```RFormula```.

## How was this patch tested?
Unit tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #13675 from yanboliang/spark-15957.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/19401a20
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/19401a20
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/19401a20

Branch: refs/heads/master
Commit: 19401a203b441e3355f0d3fc3fd062b6d5bdee1f
Parents: b515768
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Mon Oct 10 22:50:59 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Mon Oct 10 22:50:59 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/RFormula.scala  | 29 ++++++++++++++++++--
 .../apache/spark/ml/feature/RFormulaSuite.scala | 27 +++++++++++++++++-
 2 files changed, 52 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/19401a20/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 2ee899b..3898986 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, 
PipelineStage, Transformer}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.VectorUDT
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -104,6 +104,27 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
   @Since("1.5.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /**
+   * Force to index label whether it is numeric or string type.
+   * Usually we index label only when it is string type.
+   * If the formula was used by classification algorithms,
+   * we can force to index label even it is numeric type by setting this param 
with true.
+   * Default: false.
+   * @group param
+   */
+  @Since("2.1.0")
+  val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel",
+    "Force to index label whether it is numeric or string")
+  setDefault(forceIndexLabel -> false)
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def getForceIndexLabel: Boolean = $(forceIndexLabel)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, 
value)
+
   /** Whether the formula specifies fitting an intercept. */
   private[ml] def hasIntercept: Boolean = {
     require(isDefined(formula), "Formula must be defined first.")
@@ -167,8 +188,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
     encoderStages += new VectorAttributeRewriter($(featuresCol), 
prefixesToRewrite.toMap)
     encoderStages += new ColumnPruner(tempColumns.toSet)
 
-    if (dataset.schema.fieldNames.contains(resolvedFormula.label) &&
-      dataset.schema(resolvedFormula.label).dataType == StringType) {
+    if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
+      dataset.schema(resolvedFormula.label).dataType == StringType) || 
$(forceIndexLabel)) {
       encoderStages += new StringIndexer()
         .setInputCol(resolvedFormula.label)
         .setOutputCol($(labelCol))
@@ -181,6 +202,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
   @Since("1.5.0")
   // optimistic schema; does not contain any ML attributes
   override def transformSchema(schema: StructType): StructType = {
+    require(!hasLabelCol(schema) || !$(forceIndexLabel),
+      "If label column already exists, forceIndexLabel can not be set with 
true.")
     if (hasLabelCol(schema)) {
       StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, 
true))
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/19401a20/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 97c268f..c664460 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -57,7 +57,7 @@ class RFormulaSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
     }
   }
 
-  test("label column already exists") {
+  test("label column already exists and forceIndexLabel was set with false") {
     val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
     val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y")
     val model = formula.fit(original)
@@ -66,6 +66,14 @@ class RFormulaSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
     assert(resultSchema.toString == model.transform(original).schema.toString)
   }
 
+  test("label column already exists but forceIndexLabel was set with true") {
+    val formula = new RFormula().setFormula("y ~ 
x").setLabelCol("y").setForceIndexLabel(true)
+    val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", 
"y")
+    intercept[IllegalArgumentException] {
+      formula.fit(original)
+    }
+  }
+
   test("label column already exists but is not numeric type") {
     val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
     val original = Seq((0, true), (2, false)).toDF("x", "y")
@@ -137,6 +145,23 @@ class RFormulaSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
     assert(result.collect() === expected.collect())
   }
 
+  test("force to index label even it is numeric type") {
+    val formula = new RFormula().setFormula("id ~ a + 
b").setForceIndexLabel(true)
+    val original = spark.createDataFrame(
+      Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val expected = spark.createDataFrame(
+      Seq(
+        (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
+        (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
+        (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
+        (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
+    ).toDF("id", "a", "b", "features", "label")
+    assert(result.collect() === expected.collect())
+  }
+
   test("attribute generation") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 
5))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to