Repository: spark
Updated Branches:
  refs/heads/master 376d90d55 -> 0c8444cf6


[SPARK-14657][SPARKR][ML] RFormula w/o intercept should output reference 
category when encoding string terms

## What changes were proposed in this pull request?

Please see [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657) for 
details of this bug.
I searched online and tested some other cases, and found that when we fit an R glm 
model (or other models powered by R formulas) without an intercept on a dataset 
that includes string/categorical features, one of the categories in the first 
categorical feature is used as the reference category; that is, no category is 
dropped for that feature.
I think Spark RFormula should keep its semantics consistent with those of R 
formulas.
## How was this patch tested?

Added standard unit tests.

cc mengxr

Author: Yanbo Liang <yblia...@gmail.com>

Closes #12414 from yanboliang/spark-14657.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0c8444cf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0c8444cf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0c8444cf

Branch: refs/heads/master
Commit: 0c8444cf6d0620cd219ddcf5f50b12ff648639e9
Parents: 376d90d
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Thu Jun 29 10:32:32 2017 +0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Thu Jun 29 10:32:32 2017 +0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/RFormula.scala  | 10 ++-
 .../apache/spark/ml/feature/RFormulaSuite.scala | 83 ++++++++++++++++++++
 2 files changed, 92 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0c8444cf/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 1fad0a6..4b44878 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -205,12 +205,20 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
     }.toMap
 
     // Then we handle one-hot encoding and interactions between terms.
+    var keepReferenceCategory = false
     val encodedTerms = resolvedFormula.terms.map {
       case Seq(term) if dataset.schema(term).dataType == StringType =>
         val encodedCol = tmpColumn("onehot")
-        encoderStages += new OneHotEncoder()
+        var encoder = new OneHotEncoder()
           .setInputCol(indexed(term))
           .setOutputCol(encodedCol)
+        // Formula w/o intercept, one of the categories in the first category 
feature is
+        // being used as reference category, we will not drop any category for 
that feature.
+        if (!hasIntercept && !keepReferenceCategory) {
+          encoder = encoder.setDropLast(false)
+          keepReferenceCategory = true
+        }
+        encoderStages += encoder
         prefixesToRewrite(encodedCol + "_") = term + "_"
         encodedCol
       case Seq(term) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/0c8444cf/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 41d0062..23570d6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -213,6 +213,89 @@ class RFormulaSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
     assert(result.collect() === expected.collect())
   }
 
+  test("formula w/o intercept, we should output reference category when 
encoding string terms") {
+    /*
+     R code:
+
+     df <- data.frame(id = c(1, 2, 3, 4),
+                  a = c("foo", "bar", "bar", "baz"),
+                  b = c("zq", "zz", "zz", "zz"),
+                  c = c(4, 4, 5, 5))
+     model.matrix(id ~ a + b + c - 1, df)
+
+       abar abaz afoo bzz c
+     1    0    0    1   0 4
+     2    1    0    0   1 4
+     3    1    0    0   1 5
+     4    0    1    0   1 5
+
+     model.matrix(id ~ a:b + c - 1, df)
+
+       c abar:bzq abaz:bzq afoo:bzq abar:bzz abaz:bzz afoo:bzz
+     1 4        0        0        1        0        0        0
+     2 4        0        0        0        1        0        0
+     3 5        0        0        0        1        0        0
+     4 5        0        0        0        0        1        0
+    */
+    val original = Seq((1, "foo", "zq", 4), (2, "bar", "zz", 4), (3, "bar", 
"zz", 5),
+      (4, "baz", "zz", 5)).toDF("id", "a", "b", "c")
+
+    val formula1 = new RFormula().setFormula("id ~ a + b + c - 1")
+      .setStringIndexerOrderType(StringIndexer.alphabetDesc)
+    val model1 = formula1.fit(original)
+    val result1 = model1.transform(original)
+    val resultSchema1 = model1.transformSchema(original.schema)
+    // Note the column order is different between R and Spark.
+    val expected1 = Seq(
+      (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 
1.0),
+      (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0),
+      (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0),
+      (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
+    ).toDF("id", "a", "b", "c", "features", "label")
+    assert(result1.schema.toString == resultSchema1.toString)
+    assert(result1.collect() === expected1.collect())
+
+    val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
+    val expectedAttrs1 = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new BinaryAttribute(Some("a_foo"), Some(1)),
+        new BinaryAttribute(Some("a_baz"), Some(2)),
+        new BinaryAttribute(Some("a_bar"), Some(3)),
+        new BinaryAttribute(Some("b_zz"), Some(4)),
+        new NumericAttribute(Some("c"), Some(5))))
+    assert(attrs1 === expectedAttrs1)
+
+    // There is no impact for string terms interaction.
+    val formula2 = new RFormula().setFormula("id ~ a:b + c - 1")
+      .setStringIndexerOrderType(StringIndexer.alphabetDesc)
+    val model2 = formula2.fit(original)
+    val result2 = model2.transform(original)
+    val resultSchema2 = model2.transformSchema(original.schema)
+    // Note the column order is different between R and Spark.
+    val expected2 = Seq(
+      (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 
1.0),
+      (2, "bar", "zz", 4, Vectors.sparse(7, Array(4, 6), Array(1.0, 4.0)), 
2.0),
+      (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 
3.0),
+      (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
+    ).toDF("id", "a", "b", "c", "features", "label")
+    assert(result2.schema.toString == resultSchema2.toString)
+    assert(result2.collect() === expected2.collect())
+
+    val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
+    val expectedAttrs2 = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a_foo:b_zz"), Some(1)),
+        new NumericAttribute(Some("a_foo:b_zq"), Some(2)),
+        new NumericAttribute(Some("a_baz:b_zz"), Some(3)),
+        new NumericAttribute(Some("a_baz:b_zq"), Some(4)),
+        new NumericAttribute(Some("a_bar:b_zz"), Some(5)),
+        new NumericAttribute(Some("a_bar:b_zq"), Some(6)),
+        new NumericAttribute(Some("c"), Some(7))))
+    assert(attrs2 === expectedAttrs2)
+  }
+
   test("index string label") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to