Repository: spark Updated Branches: refs/heads/master 7c382524a -> f3fe55439
[SPARK-10835][ML] Word2Vec should accept non-nullable string array, in addition to existing nullable string array ## What changes were proposed in this pull request? To be compatible with both Tokenizer (nullable string array output) and NGram (non-nullable string array output), accept either a nullable or a non-nullable string array type for the Word2Vec input column ## How was this patch tested? Jenkins tests. Author: Sean Owen <so...@cloudera.com> Closes #15179 from srowen/SPARK-10835. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f3fe5543 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f3fe5543 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f3fe5543 Branch: refs/heads/master Commit: f3fe55439e4c865c26502487a1bccf255da33f4a Parents: 7c38252 Author: Sean Owen <so...@cloudera.com> Authored: Sat Sep 24 08:06:41 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Sat Sep 24 08:06:41 2016 +0100 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/Word2Vec.scala | 3 ++- .../apache/spark/ml/feature/Word2VecSuite.scala | 21 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f3fe5543/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 14c0512..d53f3df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } http://git-wip-us.apache.org/repos/asf/spark/blob/f3fe5543/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 0b441f8..613cc3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org