Repository: spark
Updated Branches:
  refs/heads/master ab197308a -> 3eb52092b


[SPARK-22974][ML] Attach attributes to output column of CountVectorizerModel

## What changes were proposed in this pull request?

The output column from `CountVectorizerModel` lacks ML attributes, so a later 
transformer such as `Interaction` can raise an error because no attributes are available.

## How was this patch tested?

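Added a unit test to `CountVectorizerSuite` that feeds the `CountVectorizerModel` output column into `Interaction`.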

Please review http://spark.apache.org/contributing.html before opening a pull request.

Closes #20313 from viirya/SPARK-22974.

Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
Signed-off-by: DB Tsai <d_t...@apple.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3eb52092
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3eb52092
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3eb52092

Branch: refs/heads/master
Commit: 3eb52092b3aa9d7d2fc1e50ac237d47bfb3b9e92
Parents: ab19730
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Tue Aug 14 05:05:16 2018 +0000
Committer: DB Tsai <d_t...@apple.com>
Committed: Tue Aug 14 05:05:16 2018 +0000

----------------------------------------------------------------------
 .../apache/spark/ml/feature/CountVectorizer.scala   |  5 ++++-
 .../spark/ml/feature/CountVectorizerSuite.scala     | 16 ++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3eb52092/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 10c48c3..dc8eb82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, 
NumericAttribute}
 import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
@@ -317,7 +318,9 @@ class CountVectorizerModel(
 
       Vectors.sparse(dictBr.value.size, effectiveCounts)
     }
-    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
+    val attrs = vocabulary.map(_ => new 
NumericAttribute).asInstanceOf[Array[Attribute]]
+    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
+    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))), metadata)
   }
 
   @Since("1.5.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/3eb52092/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 6121766..bca580d 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -289,4 +289,20 @@ class CountVectorizerSuite extends MLTest with 
DefaultReadWriteTest {
     val newInstance = testDefaultReadWrite(instance)
     assert(newInstance.vocabulary === instance.vocabulary)
   }
+
+  test("SPARK-22974: CountVectorModel should attach proper attribute to output 
column") {
+    val df = spark.createDataFrame(Seq(
+      (0, 1.0, Array("a", "b", "c")),
+      (1, 2.0, Array("a", "b", "b", "c", "a", "d"))
+    )).toDF("id", "features1", "words")
+
+    val cvm = new CountVectorizerModel(Array("a", "b", "c"))
+      .setInputCol("words")
+      .setOutputCol("features2")
+
+    val df1 = cvm.transform(df)
+    val interaction = new Interaction().setInputCols(Array("features1", 
"features2"))
+      .setOutputCol("features")
+    interaction.transform(df1)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to