Repository: spark
Updated Branches:
  refs/heads/branch-1.5 73fab8849 -> acda9d954


[SPARK-8874] [ML] Add missing methods in Word2Vec

Add missing methods

1. getVectors
2. findSynonyms

to the Word2Vec Scala and Python APIs

mengxr

Author: MechCoder <manojkumarsivaraj...@gmail.com>

Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following 
commits:

149d5ca [MechCoder] minor doc
69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec

(cherry picked from commit 13675c742a71cbdc8324701c3694775ce1dd5c62)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/acda9d95
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/acda9d95
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/acda9d95

Branch: refs/heads/branch-1.5
Commit: acda9d9546fa3f54676e48d76a2b66016d204074
Parents: 73fab88
Author: MechCoder <manojkumarsivaraj...@gmail.com>
Authored: Mon Aug 3 16:44:25 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Aug 3 16:46:00 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Word2Vec.scala  | 38 +++++++++++-
 .../apache/spark/ml/feature/Word2VecSuite.scala | 62 ++++++++++++++++++++
 2 files changed, 99 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/acda9d95/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 6ea6590..b4f46ce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -18,15 +18,17 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.SparkContext
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types._
 
 /**
@@ -146,6 +148,40 @@ class Word2VecModel private[ml] (
     wordVectors: feature.Word2VecModel)
   extends Model[Word2VecModel] with Word2VecBase {
 
+
+  /**
+   * Returns a dataframe with two fields, "word" and "vector", where "word" is a String
+   * and "vector" is the dense vector that the word is mapped to.
+   *
+   * Declared `@transient lazy` so the dataframe is built on first access rather than every
+   * time a model instance is constructed, and is not kept as serialized state of the model.
+   */
+  @transient lazy val getVectors: DataFrame = {
+    val sc = SparkContext.getOrCreate()
+    val sqlContext = SQLContext.getOrCreate(sc)
+    import sqlContext.implicits._
+    // Vectors.dense takes Double components, hence the element-wise conversion.
+    val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
+    sc.parallelize(wordVec.toSeq).toDF("word", "vector")
+  }
+
+  /**
+   * Looks up the "num" words that are most similar to the given word.
+   * The result is a dataframe holding each synonym together with its cosine
+   * similarity to the given word.
+   */
+  def findSynonyms(word: String, num: Int): DataFrame = {
+    // Map the word to its vector, then delegate to the vector-based overload.
+    val queryVector = wordVectors.transform(word)
+    findSynonyms(queryVector, num)
+  }
+
+  /**
+   * Looks up the "num" words that are most similar to the given vector representation
+   * of a word. The result is a dataframe with the columns "word" and "similarity", holding
+   * each synonym and its cosine similarity to the given word vector.
+   */
+  def findSynonyms(word: Vector, num: Int): DataFrame = {
+    val sparkCtx = SparkContext.getOrCreate()
+    val sqlContext = SQLContext.getOrCreate(sparkCtx)
+    import sqlContext.implicits._
+    val synonyms = wordVectors.findSynonyms(word, num)
+    sparkCtx.parallelize(synonyms).toDF("word", "similarity")
+  }
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/acda9d95/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index aa6ce53..adcda0e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with 
MLlibTestSparkContext {
         assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is 
different with expected.")
     }
   }
+
+  // Checks that getVectors exposes one vector per vocabulary word, matching the expected codes.
+  test("getVectors") {
+
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    val sentence = "a b " * 100 + "a c " * 10
+    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+
+    val codes = Map(
+      "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
+      "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
+      "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
+    )
+    // Sorted by word so the order lines up with the sort("word") applied to the model output.
+    val expectedVectors = codes.toSeq.sortBy(_._1).map { case (_, v) => Vectors.dense(v) }
+
+    val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+    val model = new Word2Vec()
+      .setVectorSize(3)
+      .setInputCol("text")
+      .setOutputCol("result")
+      .setSeed(42L)
+      .fit(docDF)
+
+    val realVectors = model.getVectors.sort("word").select("vector").map {
+      case Row(v: Vector) => v
+    }.collect()
+
+    // zip truncates to the shorter side, so a missing vector would otherwise go unnoticed.
+    assert(realVectors.length === expectedVectors.length,
+      "Number of returned vectors is different from expected.")
+    realVectors.zip(expectedVectors).foreach {
+      case (real, expected) =>
+        assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.")
+    }
+  }
+
+  // Checks the synonyms returned for "a" and their similarities against the expected values.
+  test("findSynonyms") {
+
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    val sentence = "a b " * 100 + "a c " * 10
+    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+    val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+    val model = new Word2Vec()
+      .setVectorSize(3)
+      .setInputCol("text")
+      .setOutputCol("result")
+      .setSeed(42L)
+      .fit(docDF)
+
+    val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644)
+    val (synonyms, similarity) = model.findSynonyms("a", 2).map {
+      case Row(w: String, sim: Double) => (w, sim)
+    }.collect().unzip
+
+    assert(synonyms.toArray === Array("b", "c"))
+    // zip truncates to the shorter side, so compare the lengths explicitly first.
+    assert(similarity.length === expectedSimilarity.length,
+      "Number of returned similarities is different from expected.")
+    // foreach rather than map: the assertions are run for their side effect only.
+    expectedSimilarity.zip(similarity).foreach {
+      case (expected, actual) =>
+        assert(math.abs((expected - actual) / expected) < 1E-5,
+          "Actual similarity is different from expected.")
+    }
+
+  }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to