Repository: spark
Updated Branches:
  refs/heads/master d8830c503 -> 56e1bd337


[SPARK-17629][ML] methods to return synonyms directly

## What changes were proposed in this pull request?
Provide methods that return synonyms directly, without wrapping them in a
DataFrame.

In performance-sensitive applications (such as user-facing APIs), the round trip
to and from DataFrames is costly and unnecessary.

The methods are named `findSynonymsArray` to make the return type clear; the
name also implies that the result is a local data structure.

## How was this patch tested?
Updated the Word2Vec tests.

Author: Asher Krim <ak...@hubspot.com>

Closes #16811 from Krimit/w2vFindSynonymsLocal.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56e1bd33
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56e1bd33
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56e1bd33

Branch: refs/heads/master
Commit: 56e1bd337ccb03cb01702e4260e4be59d2aa0ead
Parents: d8830c5
Author: Asher Krim <ak...@hubspot.com>
Authored: Tue Mar 7 20:36:46 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Mar 7 20:36:46 2017 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Word2Vec.scala  | 37 ++++++++++++++++----
 .../apache/spark/ml/feature/Word2VecSuite.scala | 20 +++++++----
 2 files changed, 45 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56e1bd33/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 42e8a66..4ca062c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -227,25 +227,50 @@ class Word2VecModel private[ml] (
 
   /**
    * Find "num" number of words closest in similarity to the given word, not
-   * including the word itself. Returns a dataframe with the words and the
-   * cosine similarities between the synonyms and the given word.
+   * including the word itself.
+   * @return a dataframe with columns "word" and "similarity" of the word and 
the cosine
+   * similarities between the synonyms and the given word vector.
    */
   @Since("1.5.0")
   def findSynonyms(word: String, num: Int): DataFrame = {
     val spark = SparkSession.builder().getOrCreate()
-    spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", 
"similarity")
+    spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", 
"similarity")
   }
 
   /**
-   * Find "num" number of words whose vector representation most similar to 
the supplied vector.
+   * Find "num" number of words whose vector representation is most similar to 
the supplied vector.
    * If the supplied vector is the vector representation of a word in the 
model's vocabulary,
-   * that word will be in the results.  Returns a dataframe with the words and 
the cosine
+   * that word will be in the results.
+   * @return a dataframe with columns "word" and "similarity" of the word and 
the cosine
    * similarities between the synonyms and the given word vector.
    */
   @Since("2.0.0")
   def findSynonyms(vec: Vector, num: Int): DataFrame = {
     val spark = SparkSession.builder().getOrCreate()
-    spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", 
"similarity")
+    spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", 
"similarity")
+  }
+
+  /**
+   * Find "num" number of words whose vector representation is most similar to 
the supplied vector.
+   * If the supplied vector is the vector representation of a word in the 
model's vocabulary,
+   * that word will be in the results.
+   * @return an array of the words and the cosine similarities between the 
synonyms given
+   * word vector.
+   */
+  @Since("2.2.0")
+  def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = {
+    wordVectors.findSynonyms(vec, num)
+  }
+
+  /**
+   * Find "num" number of words closest in similarity to the given word, not
+   * including the word itself.
+   * @return an array of the words and the cosine similarities between the 
synonyms given
+   * word vector.
+   */
+  @Since("2.2.0")
+  def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = {
+    wordVectors.findSynonyms(word, num)
   }
 
   /** @group setParam */

http://git-wip-us.apache.org/repos/asf/spark/blob/56e1bd33/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 613cc3d..2043a16 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -133,14 +133,22 @@ class Word2VecSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
       .setSeed(42L)
       .fit(docDF)
 
-    val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
-    val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map {
+    val expected = Map(("b", 0.2608488929093532), ("c", -0.8271274846926078))
+    val findSynonymsResult = model.findSynonyms("a", 2).rdd.map {
       case Row(w: String, sim: Double) => (w, sim)
-    }.collect().unzip
+    }.collectAsMap()
+
+    expected.foreach {
+      case (expectedSynonym, expectedSimilarity) =>
+        assert(findSynonymsResult.contains(expectedSynonym))
+        assert(expectedSimilarity ~== 
findSynonymsResult.get(expectedSynonym).get absTol 1E-5)
+    }
 
-    assert(synonyms === Array("b", "c"))
-    expectedSimilarity.zip(similarity).foreach {
-      case (expected, actual) => assert(math.abs((expected - actual) / 
expected) < 1E-5)
+    val findSynonymsArrayResult = model.findSynonymsArray("a", 2).toMap
+    findSynonymsResult.foreach {
+      case (expectedSynonym, expectedSimilarity) =>
+        assert(findSynonymsArrayResult.contains(expectedSynonym))
+        assert(expectedSimilarity ~== 
findSynonymsArrayResult.get(expectedSynonym).get absTol 1E-5)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to