Repository: spark Updated Branches: refs/heads/master d8830c503 -> 56e1bd337
[SPARK-17629][ML] methods to return synonyms directly ## What changes were proposed in this pull request? Provide methods to return synonyms directly, without wrapping them in a DataFrame. In performance-sensitive applications (such as user-facing APIs), the round trip to and from DataFrames is costly and unnecessary. The methods are named ``findSynonymsArray`` to make the return type clear, which also implies a local data structure. ## How was this patch tested? Updated Word2Vec tests. Author: Asher Krim <ak...@hubspot.com> Closes #16811 from Krimit/w2vFindSynonymsLocal. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56e1bd33 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56e1bd33 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56e1bd33 Branch: refs/heads/master Commit: 56e1bd337ccb03cb01702e4260e4be59d2aa0ead Parents: d8830c5 Author: Asher Krim <ak...@hubspot.com> Authored: Tue Mar 7 20:36:46 2017 -0800 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Tue Mar 7 20:36:46 2017 -0800 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/Word2Vec.scala | 37 ++++++++++++++++---- .../apache/spark/ml/feature/Word2VecSuite.scala | 20 +++++++---- 2 files changed, 45 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/56e1bd33/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 42e8a66..4ca062c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -227,25 +227,50 @@ class Word2VecModel private[ml] ( /** * Find "num" number of words closest in similarity to the given word, not - * including the word itself. Returns a dataframe with the words and the - * cosine similarities between the synonyms and the given word. + * including the word itself. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine + * similarities between the synonyms and the given word vector. */ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words whose vector representation most similar to the supplied vector. + * Find "num" number of words whose vector representation is most similar to the supplied vector. * If the supplied vector is the vector representation of a word in the model's vocabulary, - * that word will be in the results. 
Returns a dataframe with the words and the cosine + * that word will be in the results. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine * similarities between the synonyms and the given word vector. */ @Since("2.0.0") def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity") + } + + /** + * Find "num" number of words whose vector representation is most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. + */ + @Since("2.2.0") + def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(vec, num) + } + + /** + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. 
+ */ + @Since("2.2.0") + def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(word, num) } /** @group setParam */ http://git-wip-us.apache.org/repos/asf/spark/blob/56e1bd33/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 613cc3d..2043a16 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -133,14 +133,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) - val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map { + val expected = Map(("b", 0.2608488929093532), ("c", -0.8271274846926078)) + val findSynonymsResult = model.findSynonyms("a", 2).rdd.map { case Row(w: String, sim: Double) => (w, sim) - }.collect().unzip + }.collectAsMap() + + expected.foreach { + case (expectedSynonym, expectedSimilarity) => + assert(findSynonymsResult.contains(expectedSynonym)) + assert(expectedSimilarity ~== findSynonymsResult.get(expectedSynonym).get absTol 1E-5) + } - assert(synonyms === Array("b", "c")) - expectedSimilarity.zip(similarity).foreach { - case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + val findSynonymsArrayResult = model.findSynonymsArray("a", 2).toMap + findSynonymsResult.foreach { + case (expectedSynonym, expectedSimilarity) => + assert(findSynonymsArrayResult.contains(expectedSynonym)) + assert(expectedSimilarity ~== findSynonymsArrayResult.get(expectedSynonym).get absTol 1E-5) } } --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org