Repository: spark
Updated Branches:
  refs/heads/branch-1.1 a65c9ac11 -> e654cfdd0


[SPARK-2852][MLLIB] API consistency for `mllib.feature`

This is part of SPARK-2828:

1. added a Java-friendly `fit` method to `Word2Vec`, with tests
2. changed the annotation on `Normalizer` and `StandardScaler` from `DeveloperApi` to `Experimental`
3. changed the default feature dimension in `HashingTF` to 2^20

Author: Xiangrui Meng <m...@databricks.com>

Closes #1807 from mengxr/feature-api-check and squashes the following commits:

773c1a9 [Xiangrui Meng] change default numFeatures to 2^20 in HashingTF; change annotation from DeveloperApi to Experimental in Normalizer and StandardScaler
883e122 [Xiangrui Meng] add @Experimental to Word2VecModel; add a Java-friendly fit method to Word2Vec, with tests

(cherry picked from commit 25cff1019da9d6cfc486a31d035b372ea5fbdfd2)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e654cfdd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e654cfdd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e654cfdd

Branch: refs/heads/branch-1.1
Commit: e654cfdd02e56fd3aaf6b784dcd25cb9ec35aece
Parents: a65c9ac
Author: Xiangrui Meng <m...@databricks.com>
Authored: Wed Aug 6 14:07:51 2014 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Aug 6 14:08:03 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/feature/HashingTF.scala  |  4 +-
 .../apache/spark/mllib/feature/Normalizer.scala |  6 +-
 .../spark/mllib/feature/StandardScaler.scala    |  6 +-
 .../apache/spark/mllib/feature/Word2Vec.scala   | 19 +++++-
 .../spark/mllib/feature/JavaWord2VecSuite.java  | 66 ++++++++++++++++++++
 5 files changed, 91 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e654cfdd/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index 0f6d580..c534758 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -32,12 +32,12 @@ import org.apache.spark.util.Utils
  * :: Experimental ::
  * Maps a sequence of terms to their term frequencies using the hashing trick.
  *
- * @param numFeatures number of features (default: 1000000)
+ * @param numFeatures number of features (default: 2^20^)
  */
 @Experimental
 class HashingTF(val numFeatures: Int) extends Serializable {
 
-  def this() = this(1000000)
+  def this() = this(1 << 20)
 
   /**
    * Returns the index of the input term.

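For reference, a minimal usage sketch of the new default (not part of this diff). The power-of-two space of 2^20 = 1,048,576 features replaces the old 1,000,000, which presumably reduces index collisions from the hashing trick a little while keeping the size a round power of two:

    import org.apache.spark.mllib.feature.HashingTF

    val tf = new HashingTF()           // numFeatures = 1 << 20 = 1048576
    val doc = Seq("a", "b", "a", "c")
    val vec = tf.transform(doc)        // sparse vector of term frequencies
    // vec.size == 1048576; the index that "a" hashes to carries a count of 2.0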
http://git-wip-us.apache.org/repos/asf/spark/blob/e654cfdd/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index ea9fd0a..3afb477 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -19,11 +19,11 @@ package org.apache.spark.mllib.feature
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
 
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 /**
- * :: DeveloperApi ::
+ * :: Experimental ::
  * Normalizes samples individually to unit L^p^ norm
  *
  * For any 1 <= p < Double.PositiveInfinity, normalizes samples using
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
  *
  * @param p Normalization in L^p^ space, p = 2 by default.
  */
-@DeveloperApi
+@Experimental
 class Normalizer(p: Double) extends VectorTransformer {
 
   def this() = this(2)

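The annotation change above does not alter behavior. For context, a minimal sketch of Normalizer usage (not part of this diff):

    import org.apache.spark.mllib.feature.Normalizer
    import org.apache.spark.mllib.linalg.Vectors

    val normalizer = new Normalizer()  // unit L^2 norm by default (p = 2)
    val unit = normalizer.transform(Vectors.dense(3.0, 4.0))
    // [0.6, 0.8]: each component divided by the Euclidean norm 5.0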
http://git-wip-us.apache.org/repos/asf/spark/blob/e654cfdd/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index cc2d757..e6c9f8f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.feature
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 
 /**
- * :: DeveloperApi ::
+ * :: Experimental ::
 * Standardizes features by removing the mean and scaling to unit variance using column summary
  * statistics on the samples in the training set.
  *
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
 *                 dense output, so this does not work on sparse input and will raise an exception.
  * @param withStd True by default. Scales the data to unit standard deviation.
  */
-@DeveloperApi
+@Experimental
 class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
 
   def this() = this(false, true)

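Likewise, only the annotation changes here. A minimal sketch, assuming the branch-1.1 API in which fit computes column summary statistics over an RDD[Vector] and the fitted scaler itself then transforms vectors (assumes a SparkContext `sc`):

    import org.apache.spark.mllib.feature.StandardScaler
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 10.0),
      Vectors.dense(3.0, 30.0)))
    // withMean produces dense output, so it must not be used on sparse input
    val scaler = new StandardScaler(withMean = true, withStd = true).fit(data)
    val scaled = scaler.transform(data)  // zero mean, unit std per column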
http://git-wip-us.apache.org/repos/asf/spark/blob/e654cfdd/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 3bf44ad..395037e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.mllib.feature
 
+import java.lang.{Iterable => JavaIterable}
+
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
@@ -25,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
@@ -239,7 +243,7 @@ class Word2Vec extends Serializable with Logging {
       a += 1
     }
   }
-  
+
   /**
    * Computes the vector representation of each word in vocabulary.
    * @param dataset an RDD of words
@@ -369,11 +373,22 @@ class Word2Vec extends Serializable with Logging {
 
     new Word2VecModel(word2VecMap.toMap)
   }
+
+  /**
+   * Computes the vector representation of each word in vocabulary (Java version).
+   * @param dataset a JavaRDD of words
+   * @return a Word2VecModel
+   */
+  def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
+    fit(dataset.rdd.map(_.asScala))
+  }
 }
 
 /**
-* Word2Vec model
+ * :: Experimental ::
+ * Word2Vec model
  */
+@Experimental
 class Word2VecModel private[mllib] (
     private val model: Map[String, Array[Float]]) extends Serializable {
 

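The new overload simply unwraps the JavaRDD and converts each JavaIterable to a Scala Iterable via asScala before delegating to the existing fit. A minimal Scala sketch of that underlying path (not part of this diff; assumes a SparkContext `sc`, and repeats words so every word clears the vocabulary frequency cutoff):

    import org.apache.spark.mllib.feature.Word2Vec

    val words = ("a b " * 100 + "a c " * 10).trim.split(" ").toSeq
    val sentences = sc.parallelize(Seq(words, words))
    val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(sentences)
    val synonyms = model.findSynonyms("a", 2)  // Array[(String, Double)]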
http://git-wip-us.apache.org/repos/asf/spark/blob/e654cfdd/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
new file mode 100644
index 0000000..fb7afe8
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.feature;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import com.google.common.base.Strings;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaWord2VecSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaWord2VecSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void word2Vec() {
+    // The tests are to check Java compatibility.
+    String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
+    List<String> words = Lists.newArrayList(sentence.split(" "));
+    List<List<String>> localDoc = Lists.newArrayList(words, words);
+    JavaRDD<List<String>> doc = sc.parallelize(localDoc);
+    Word2Vec word2vec = new Word2Vec()
+      .setVectorSize(10)
+      .setSeed(42L);
+    Word2VecModel model = word2vec.fit(doc);
+    Tuple2<String, Object>[] syms = model.findSynonyms("a", 2);
+    Assert.assertEquals(2, syms.length);
+    Assert.assertEquals("b", syms[0]._1());
+    Assert.assertEquals("c", syms[1]._1());
+  }
+}

