Repository: spark Updated Branches: refs/heads/branch-2.0 9e91a1009 -> ed545763a
[SPARK-10835][ML] Word2Vec should accept a non-nullable string array, in addition to the existing nullable string array ## What changes were proposed in this pull request? For compatibility between NGram and Word2Vec, accept a non-nullable string array type (as output by NGram) as Word2Vec input, in addition to the nullable string array type (as output by Tokenizer) ## How was this patch tested? Jenkins tests. Author: Sean Owen <so...@cloudera.com> Closes #15179 from srowen/SPARK-10835. (cherry picked from commit f3fe55439e4c865c26502487a1bccf255da33f4a) Signed-off-by: Sean Owen <so...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ed545763 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ed545763 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ed545763 Branch: refs/heads/branch-2.0 Commit: ed545763adc3f50569581c9b017b396e8997ac31 Parents: 9e91a10 Author: Sean Owen <so...@cloudera.com> Authored: Sat Sep 24 08:06:41 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Sat Sep 24 08:06:56 2016 +0100 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/Word2Vec.scala | 3 ++- .../apache/spark/ml/feature/Word2VecSuite.scala | 21 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ed545763/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 14c0512..d53f3df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } http://git-wip-us.apache.org/repos/asf/spark/blob/ed545763/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 16c74f6..c8f1311 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org