Repository: spark
Updated Branches:
  refs/heads/master 5ab9fcfb0 -> afe35f051


[SPARK-8455] [ML] Implement n-gram feature transformer

Implementation of n-gram feature transformer for ML.

Author: Feynman Liang <fli...@databricks.com>

Closes #6887 from feynmanliang/ngram-featurizer and squashes the following 
commits:

d2c839f [Feynman Liang] Make n > input length yield empty output
9fadd36 [Feynman Liang] Add empty and corner test cases, fix names and spaces
fe93873 [Feynman Liang] Implement n-gram feature transformer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/afe35f05
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/afe35f05
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/afe35f05

Branch: refs/heads/master
Commit: afe35f0519bc7dcb85010a7eedcff854d4fc313a
Parents: 5ab9fcf
Author: Feynman Liang <fli...@databricks.com>
Authored: Mon Jun 22 14:15:35 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Jun 22 14:15:35 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/NGram.scala     | 69 ++++++++++++++
 .../apache/spark/ml/feature/NGramSuite.scala    | 94 ++++++++++++++++++++
 2 files changed, 163 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/afe35f05/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
new file mode 100644
index 0000000..8de10eb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
+
/**
 * :: Experimental ::
 * A feature transformer that converts the input array of strings into an array of n-grams. Null
 * values in the input array are ignored.
 * It returns an array of n-grams where each n-gram is represented by a space-separated string of
 * words.
 *
 * When the input is empty, an empty array is returned.
 * When the input array length is less than n (number of elements per n-gram), no n-grams are
 * returned.
 */
@Experimental
class NGram(override val uid: String)
  extends UnaryTransformer[Seq[String], Seq[String], NGram] {

  def this() = this(Identifiable.randomUID("ngram"))

  /**
   * Number of elements per n-gram (the exact n-gram length), >= 1.
   * Note: this is not a minimum — only windows of exactly `n` tokens are produced;
   * inputs shorter than `n` yield no n-grams at all.
   * Default: 2, bigram features
   * @group param
   */
  val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)",
    ParamValidators.gtEq(1))

  /** @group setParam */
  def setN(value: Int): this.type = set(n, value)

  /** @group getParam */
  def getN: Int = $(n)

  setDefault(n -> 2)

  // Slide a window of exactly $(n) tokens across the input; withPartial(false) drops the
  // trailing shorter windows, which is what makes inputs of length < n produce an empty
  // result instead of partial n-grams. Each window is joined with single spaces.
  override protected def createTransformFunc: Seq[String] => Seq[String] = {
    _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  // Output n-grams are always non-null strings, hence containsNull = false.
  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

http://git-wip-us.apache.org/repos/asf/spark/blob/afe35f05/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
new file mode 100644
index 0000000..ab97e3d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
// Test fixture: one row of input tokens paired with the n-grams expected from the transformer.
// @BeanInfo generates JavaBean-style getters so Spark SQL reflection can build a DataFrame
// schema from this case class.
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
+
class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
  import org.apache.spark.ml.feature.NGramSuite._

  /** Builds a one-row DataFrame pairing input tokens with the expected n-grams. */
  private def toDF(tokens: Array[String], expected: Array[String]): DataFrame =
    sqlContext.createDataFrame(Seq(NGramTestData(tokens, expected)))

  /** Creates an NGram transformer wired to the fixture's input/output columns. */
  private def makeNGram(): NGram =
    new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")

  test("default behavior yields bigram features") {
    val dataset = toDF(
      Array("Test", "for", "ngram", "."),
      Array("Test for", "for ngram", "ngram ."))
    testNGram(makeNGram(), dataset)
  }

  test("NGramLength=4 yields length 4 n-grams") {
    val dataset = toDF(
      Array("a", "b", "c", "d", "e"),
      Array("a b c d", "b c d e"))
    testNGram(makeNGram().setN(4), dataset)
  }

  test("empty input yields empty output") {
    testNGram(makeNGram().setN(4), toDF(Array(), Array()))
  }

  test("input array < n yields empty output") {
    testNGram(makeNGram().setN(6), toDF(Array("a", "b", "c", "d", "e"), Array()))
  }
}
+
object NGramSuite extends SparkFunSuite {

  /**
   * Applies `t` to `dataset` and verifies, row by row, that the produced "nGrams"
   * column equals the expected value stored in the "wantedNGrams" column.
   */
  def testNGram(t: NGram, dataset: DataFrame): Unit = {
    val rows = t.transform(dataset)
      .select("nGrams", "wantedNGrams")
      .collect()
    rows.foreach { row =>
      val actualNGrams = row.get(0)
      val wantedNGrams = row.get(1)
      assert(actualNGrams === wantedNGrams)
    }
  }
}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to