Repository: spark Updated Branches: refs/heads/master 1c3e402e6 -> 149b3ee2d
[SPARK-7242][SQL][MLLIB] Frequent items for DataFrames Finding frequent items with possibly false positives, using the algorithm described in `http://www.cs.umd.edu/~samir/498/karp.pdf`. The public API is: ``` df.stat.freqItems(cols: Array[String], support: Double = 0.01): DataFrame ``` The output is a local DataFrame having the input column names with `_freqItems` appended to them. This is a single-pass algorithm that may return false positives, but no false negatives. cc mengxr rxin Let's get the implementation in; I can add the Python API in a follow-up PR. Author: Burak Yavuz <brk...@gmail.com> Closes #5799 from brkyvz/freq-items and squashes the following commits: a6ec82c [Burak Yavuz] addressed comments v? 39b1bba [Burak Yavuz] removed toSeq 0915e23 [Burak Yavuz] addressed comments v2.1 3a5c177 [Burak Yavuz] addressed comments v2.0 482e741 [Burak Yavuz] removed old import 38e784d [Burak Yavuz] addressed comments v1.0 8279d4d [Burak Yavuz] added default value for support 3d82168 [Burak Yavuz] made base implementation Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/149b3ee2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/149b3ee2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/149b3ee2 Branch: refs/heads/master Commit: 149b3ee2dac992355adbe44e989570726c1f35d0 Parents: 1c3e402 Author: Burak Yavuz <brk...@gmail.com> Authored: Thu Apr 30 16:40:32 2015 -0700 Committer: Reynold Xin <r...@databricks.com> Committed: Thu Apr 30 16:40:32 2015 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/sql/DataFrame.scala | 11 ++ .../spark/sql/DataFrameStatFunctions.scala | 68 +++++++++++ .../sql/execution/stat/FrequentItems.scala | 121 +++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 14 ++- .../apache/spark/sql/DataFrameStatSuite.scala | 47 +++++++ 5 files changed, 256 insertions(+), 5 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/149b3ee2/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 2669300..7be2a01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -331,6 +331,17 @@ class DataFrame private[sql]( def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) /** + * Returns a [[DataFrameStatFunctions]] for working with statistic functions. + * {{{ + * // Finding frequent items in column with name 'a'. + * df.stat.freqItems(Array("a")) + * }}} + * + * @group dfops + */ + def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this) + + /** * Cartesian join with another [[DataFrame]]. * * Note that cartesian joins are very expensive without an extra filter that can be pushed down. http://git-wip-us.apache.org/repos/asf/spark/blob/149b3ee2/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala new file mode 100644 index 0000000..42e5cbc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.stat.FrequentItems + +/** + * :: Experimental :: + * Statistic functions for [[DataFrame]]s. + */ +@Experimental +final class DataFrameStatFunctions private[sql](df: DataFrame) { + + /** + * Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. + * The `support` should be greater than 1e-4. + * + * @param cols the names of the columns to search frequent items in. + * @param support The minimum frequency for an item to be considered `frequent`. Should be greater + * than 1e-4. + * @return A Local DataFrame with the Array of frequent items for each column. + */ + def freqItems(cols: Array[String], support: Double): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, support) + } + + /** + * Runs `freqItems` with a default `support` of 1%. + * + * @param cols the names of the columns to search frequent items in. + * @return A Local DataFrame with the Array of frequent items for each column. 
+ */ + def freqItems(cols: Array[String]): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, 0.01) + } + + /** + * Python friendly implementation for `freqItems` + */ + def freqItems(cols: List[String], support: Double): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, support) + } + + /** + * Python friendly implementation for `freqItems` with a default `support` of 1%. + */ + def freqItems(cols: List[String]): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, 0.01) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/149b3ee2/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala new file mode 100644 index 0000000..5ae7e10 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -0,0 +1,121 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.stat + +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.Logging +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.{ArrayType, StructField, StructType} + +private[sql] object FrequentItems extends Logging { + + /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ + private class FreqItemCounter(size: Int) extends Serializable { + val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long] + + /** + * Add a new example to the counts if it exists, otherwise deduct the count + * from existing items. + */ + def add(key: Any, count: Long): this.type = { + if (baseMap.contains(key)) { + baseMap(key) += count + } else { + if (baseMap.size < size) { + baseMap += key -> count + } else { + // TODO: Make this more efficient... A flatMap? + baseMap.retain((k, v) => v > count) + baseMap.transform((k, v) => v - count) + } + } + this + } + + /** + * Merge two maps of counts. + * @param other The map containing the counts for that partition + */ + def merge(other: FreqItemCounter): this.type = { + other.baseMap.foreach { case (k, v) => + add(k, v) + } + this + } + } + + /** + * Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. + * The `support` should be greater than 1e-4. + * For Internal use only. + * + * @param df The input DataFrame + * @param cols the names of the columns to search frequent items in + * @param support The minimum frequency for an item to be considered `frequent`. Should be greater + * than 1e-4. + * @return A Local DataFrame with the Array of frequent items for each column. 
+ */ + private[sql] def singlePassFreqItems( + df: DataFrame, + cols: Seq[String], + support: Double): DataFrame = { + require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") + val numCols = cols.length + // number of max items to keep counts for + val sizeOfMap = (1 / support).toInt + val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap)) + val originalSchema = df.schema + val colInfo = cols.map { name => + val index = originalSchema.fieldIndex(name) + (name, originalSchema.fields(index).dataType) + } + + val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)( + seqOp = (counts, row) => { + var i = 0 + while (i < numCols) { + val thisMap = counts(i) + val key = row.get(i) + thisMap.add(key, 1L) + i += 1 + } + counts + }, + combOp = (baseCounts, counts) => { + var i = 0 + while (i < numCols) { + baseCounts(i).merge(counts(i)) + i += 1 + } + baseCounts + } + ) + val justItems = freqItems.map(m => m.baseMap.keys.toSeq) + val resultRow = Row(justItems:_*) + // append frequent Items to the column name for easy debugging + val outputCols = colInfo.map { v => + StructField(v._1 + "_freqItems", ArrayType(v._2, false)) + } + val schema = StructType(outputCols).toAttributes + new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow))) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/149b3ee2/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index e5c9504..966d879 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -22,10 +22,7 @@ import com.google.common.primitives.Ints; import org.apache.spark.api.java.JavaRDD; import 
org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.TestData$; +import org.apache.spark.sql.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; @@ -178,5 +175,12 @@ public class JavaDataFrameSuite { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } } - + + @Test + public void testFrequentItems() { + DataFrame df = context.table("testData2"); + String[] cols = new String[]{"a"}; + DataFrame results = df.stat().freqItems(cols, 0.2); + Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); + } } http://git-wip-us.apache.org/repos/asf/spark/blob/149b3ee2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala new file mode 100644 index 0000000..bb1d29c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite +import org.scalatest.Matchers._ + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +class DataFrameStatSuite extends FunSuite { + + val sqlCtx = TestSQLContext + + test("Frequent Items") { + def toLetter(i: Int): String = (i + 96).toChar.toString + val rows = Array.tabulate(1000) { i => + if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) + } + val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles") + + val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) + val items = results.collect().head + items.getSeq[Int](0) should contain (1) + items.getSeq[String](1) should contain (toLetter(1)) + + val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) + val items2 = singleColResults.collect().head + items2.getSeq[Double](0) should contain (-1.0) + + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org