Repository: spark
Updated Branches:
  refs/heads/master b52c7f9fc -> 512a2f191


[SPARK-6615][MLLIB] Python API for Word2Vec

This is a sub-task of SPARK-6254.
Wrap the missing methods of `Word2Vec` and `Word2VecModel` in the Python API: `Word2Vec.setMinCount` and `Word2VecModel.getVectors`.

Author: lewuathe <lewua...@me.com>

Closes #5296 from Lewuathe/SPARK-6615 and squashes the following commits:

f14c304 [lewuathe] Reorder tests
1d326b9 [lewuathe] Merge master
e2bedfb [lewuathe] Modify test cases
afb866d [lewuathe] [SPARK-6615] Python API for Word2Vec


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/512a2f19
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/512a2f19
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/512a2f19

Branch: refs/heads/master
Commit: 512a2f191a6b53699373b6588f316b4437050425
Parents: b52c7f9
Author: lewuathe <lewua...@me.com>
Authored: Fri Apr 3 09:49:50 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Apr 3 09:49:50 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala |  8 +++-
 python/pyspark/mllib/feature.py                 | 18 +++++++-
 python/pyspark/mllib/tests.py                   | 45 +++++++++++++++++---
 3 files changed, 64 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/512a2f19/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 5995d6d..6c386ca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -476,13 +476,15 @@ private[python] class PythonMLLibAPI extends Serializable {
       learningRate: Double,
       numPartitions: Int,
       numIterations: Int,
-      seed: Long): Word2VecModelWrapper = {
+      seed: Long,
+      minCount: Int): Word2VecModelWrapper = {
     val word2vec = new Word2Vec()
       .setVectorSize(vectorSize)
       .setLearningRate(learningRate)
       .setNumPartitions(numPartitions)
       .setNumIterations(numIterations)
       .setSeed(seed)
+      .setMinCount(minCount)
     try {
       val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
       new Word2VecModelWrapper(model)
@@ -516,6 +518,10 @@ private[python] class PythonMLLibAPI extends Serializable {
       val words = result.map(_._1)
       List(words, similarity).map(_.asInstanceOf[Object]).asJava
     }
+
+    def getVectors: JMap[String, JList[Float]] = {
+      model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/512a2f19/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 4bfe301..3cda120 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -337,6 +337,12 @@ class Word2VecModel(JavaVectorTransformer):
         words, similarity = self.call("findSynonyms", word, num)
         return zip(words, similarity)
 
+    def getVectors(self):
+        """
+        Returns a map of words to their vector representations.
+        """
+        return self.call("getVectors")
+
 
 class Word2Vec(object):
     """
@@ -379,6 +385,7 @@ class Word2Vec(object):
         self.numPartitions = 1
         self.numIterations = 1
         self.seed = random.randint(0, sys.maxint)
+        self.minCount = 5
 
     def setVectorSize(self, vectorSize):
         """
@@ -417,6 +424,14 @@ class Word2Vec(object):
         self.seed = seed
         return self
 
+    def setMinCount(self, minCount):
+        """
+        Sets minCount, the minimum number of times a token must appear
+        to be included in the word2vec model's vocabulary (default: 5).
+        """
+        self.minCount = minCount
+        return self
+
     def fit(self, data):
         """
         Computes the vector representation of each word in vocabulary.
@@ -428,7 +443,8 @@ class Word2Vec(object):
             raise TypeError("data should be an RDD of list of string")
         jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
                                float(self.learningRate), int(self.numPartitions),
-                               int(self.numIterations), long(self.seed))
+                               int(self.numIterations), long(self.seed),
+                               int(self.minCount))
         return Word2VecModel(jmodel)
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/512a2f19/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 6e9c68e..dd3b66c 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -42,6 +42,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import Word2Vec
 from pyspark.mllib.feature import IDF
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import SQLContext
@@ -630,6 +631,12 @@ class ChiSqTestTests(PySparkTestCase):
         self.assertIsNotNone(chi[1000])
 
 
+class SerDeTest(PySparkTestCase):
+    def test_to_java_object_rdd(self):  # SPARK-6660
+        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
+        self.assertEqual(_to_java_object_rdd(data).count(), 10)
+
+
 class FeatureTest(PySparkTestCase):
     def test_idf_model(self):
         data = [
@@ -643,11 +650,39 @@ class FeatureTest(PySparkTestCase):
         self.assertEqual(len(idf), 11)
 
 
-class SerDeTest(PySparkTestCase):
-    def test_to_java_object_rdd(self):  # SPARK-6660
-        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
-        self.assertEqual(_to_java_object_rdd(data).count(), 10)
-
+class Word2VecTests(PySparkTestCase):
+    def test_word2vec_setters(self):
+        data = [
+            ["I", "have", "a", "pen"],
+            ["I", "like", "soccer", "very", "much"],
+            ["I", "live", "in", "Tokyo"]
+        ]
+        model = Word2Vec() \
+            .setVectorSize(2) \
+            .setLearningRate(0.01) \
+            .setNumPartitions(2) \
+            .setNumIterations(10) \
+            .setSeed(1024) \
+            .setMinCount(3)
+        self.assertEquals(model.vectorSize, 2)
+        self.assertTrue(model.learningRate < 0.02)
+        self.assertEquals(model.numPartitions, 2)
+        self.assertEquals(model.numIterations, 10)
+        self.assertEquals(model.seed, 1024)
+        self.assertEquals(model.minCount, 3)
+
+    def test_word2vec_get_vectors(self):
+        data = [
+            ["a", "b", "c", "d", "e", "f", "g"],
+            ["a", "b", "c", "d", "e", "f"],
+            ["a", "b", "c", "d", "e"],
+            ["a", "b", "c", "d"],
+            ["a", "b", "c"],
+            ["a", "b"],
+            ["a"]
+        ]
+        model = Word2Vec().fit(self.sc.parallelize(data))
+        self.assertEquals(len(model.getVectors()), 3)
 
 if __name__ == "__main__":
     if not _have_scipy:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to