Repository: spark
Updated Branches:
  refs/heads/master f48420fde -> 248067adb


[SPARK-2961][SQL] Use statistics to prune batches within cached partitions

This PR is based on #1883 authored by marmbrus. Key differences:

1. Batch pruning instead of partition pruning

   When #1883 was authored, batched column buffer building (#1880) hadn't been 
introduced. This PR combines the two and provides partition-batch-level 
pruning, which leads to a smaller memory footprint and can generally skip more 
elements. The cost is that the pruning predicates are evaluated more frequently 
(once per batch, i.e. number of partitions times batches per partition).

2. More filters are supported

   Filter predicates consisting of `=`, `<`, `<=`, `>`, `>=`, and their 
conjunctions and disjunctions, are supported.

Author: Cheng Lian <lian.cs....@gmail.com>

Closes #2188 from liancheng/in-mem-batch-pruning and squashes the following 
commits:

68cf019 [Cheng Lian] Marked sqlContext as @transient
4254f6c [Cheng Lian] Enables in-memory partition pruning in 
PartitionBatchPruningSuite
3784105 [Cheng Lian] Overrides InMemoryColumnarTableScan.sqlContext
d2a1d66 [Cheng Lian] Disables in-memory partition pruning by default
062c315 [Cheng Lian] HiveCompatibilitySuite code cleanup
16b77bf [Cheng Lian] Fixed pruning predication conjunctions and disjunctions
16195c5 [Cheng Lian] Enabled both disjunction and conjunction
89950d0 [Cheng Lian] Worked around Scala style check
9c167f6 [Cheng Lian] Minor code cleanup
3c4d5c7 [Cheng Lian] Minor code cleanup
ea59ee5 [Cheng Lian] Renamed PartitionSkippingSuite to 
PartitionBatchPruningSuite
fc517d0 [Cheng Lian] More test cases
1868c18 [Cheng Lian] Code cleanup, bugfix, and adding tests
cb76da4 [Cheng Lian] Added more predicate filters, fixed table scan stats for 
testing purposes
385474a [Cheng Lian] Merge branch 'inMemStats' into in-mem-batch-pruning


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/248067ad
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/248067ad
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/248067ad

Branch: refs/heads/master
Commit: 248067adbe90f93c7d5e23aa61b3072dfdf48a8a
Parents: f48420f
Author: Cheng Lian <lian.cs....@gmail.com>
Authored: Wed Sep 3 18:59:26 2014 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Wed Sep 3 18:59:26 2014 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/AttributeMap.scala |  41 ++
 .../catalyst/expressions/BoundAttribute.scala   |  12 +-
 .../scala/org/apache/spark/sql/SQLConf.scala    |   7 +
 .../spark/sql/columnar/ColumnBuilder.scala      |  10 +-
 .../apache/spark/sql/columnar/ColumnStats.scala | 434 ++++++-------------
 .../columnar/InMemoryColumnarTableScan.scala    | 131 +++++-
 .../sql/columnar/NullableColumnBuilder.scala    |   1 +
 .../spark/sql/execution/SparkStrategies.scala   |   4 +-
 .../spark/sql/columnar/ColumnStatsSuite.scala   |  39 +-
 .../columnar/NullableColumnBuilderSuite.scala   |   2 +-
 .../columnar/PartitionBatchPruningSuite.scala   |  95 ++++
 .../compression/BooleanBitSetSuite.scala        |   4 +-
 .../compression/DictionaryEncodingSuite.scala   |   2 +-
 .../compression/IntegralDeltaSuite.scala        |   2 +-
 .../compression/RunLengthEncodingSuite.scala    |   4 +-
 .../TestCompressibleColumnBuilder.scala         |   4 +-
 .../hive/execution/HiveCompatibilitySuite.scala |  13 +-
 17 files changed, 446 insertions(+), 359 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
new file mode 100644
index 0000000..8364379
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+/**
+ * Builds a map that is keyed by an Attribute's expression id. Using the 
expression id allows values
+ * to be looked up even when the attributes used differ cosmetically (i.e., 
the capitalization
+ * of the name, or the expected nullability).
+ */
+object AttributeMap {
+  def apply[A](kvs: Seq[(Attribute, A)]) =
+    new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+}
+
+class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
+  extends Map[Attribute, A] with Serializable {
+
+  override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
+
+  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
+    (baseMap.map(_._2) + kv).toMap
+
+  override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+
+  override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - 
key).toMap
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 54c6baf..fa80b07 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -38,12 +38,20 @@ case class BoundReference(ordinal: Int, dataType: DataType, 
nullable: Boolean)
 }
 
 object BindReferences extends Logging {
-  def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A 
= {
+
+  def bindReference[A <: Expression](
+      expression: A,
+      input: Seq[Attribute],
+      allowFailures: Boolean = false): A = {
     expression.transform { case a: AttributeReference =>
       attachTree(a, "Binding attribute") {
         val ordinal = input.indexWhere(_.exprId == a.exprId)
         if (ordinal == -1) {
-          sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+          if (allowFailures) {
+            a
+          } else {
+            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+          }
         } else {
           BoundReference(ordinal, a.dataType, a.nullable)
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 64d4935..4137ac7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -26,6 +26,7 @@ import java.util.Properties
 private[spark] object SQLConf {
   val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
   val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
+  val IN_MEMORY_PARTITION_PRUNING = 
"spark.sql.inMemoryColumnarStorage.partitionPruning"
   val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
@@ -124,6 +125,12 @@ trait SQLConf {
   private[spark] def isParquetBinaryAsString: Boolean =
     getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean
 
+  /**
+   * When set to true, partition pruning for in-memory columnar tables is 
enabled.
+   */
+  private[spark] def inMemoryPartitionPruning: Boolean =
+    getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 247337a..b3ec5de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -38,7 +38,7 @@ private[sql] trait ColumnBuilder {
   /**
    * Column statistics information
    */
-  def columnStats: ColumnStats[_, _]
+  def columnStats: ColumnStats
 
   /**
    * Returns the final columnar byte buffer.
@@ -47,7 +47,7 @@ private[sql] trait ColumnBuilder {
 }
 
 private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
-    val columnStats: ColumnStats[T, JvmType],
+    val columnStats: ColumnStats,
     val columnType: ColumnType[T, JvmType])
   extends ColumnBuilder {
 
@@ -81,18 +81,18 @@ private[sql] class BasicColumnBuilder[T <: DataType, 
JvmType](
 
 private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
     columnType: ColumnType[T, JvmType])
-  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], 
columnType)
+  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
   with NullableColumnBuilder
 
 private[sql] abstract class NativeColumnBuilder[T <: NativeType](
-    override val columnStats: NativeColumnStats[T],
+    override val columnStats: ColumnStats,
     override val columnType: NativeColumnType[T])
   extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
   with NullableColumnBuilder
   with AllCompressionSchemes
   with CompressibleColumnBuilder[T]
 
-private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new 
BooleanColumnStats, BOOLEAN)
+private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new 
NoopColumnStats, BOOLEAN)
 
 private[sql] class IntColumnBuilder extends NativeColumnBuilder(new 
IntColumnStats, INT)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 6502110..fc343cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -17,381 +17,193 @@
 
 package org.apache.spark.sql.columnar
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, 
AttributeReference}
 import org.apache.spark.sql.catalyst.types._
 
+private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
+  val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, 
nullable = false)()
+  val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, 
nullable = false)()
+  val nullCount =  AttributeReference(a.name + ".nullCount", IntegerType, 
nullable = false)()
+
+  val schema = Seq(lowerBound, upperBound, nullCount)
+}
+
+private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends 
Serializable {
+  val (forAttribute, schema) = {
+    val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a))
+    (AttributeMap(allStats), 
allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _))
+  }
+}
+
 /**
  * Used to collect statistical information when building in-memory columns.
  *
  * NOTE: we intentionally avoid using `Ordering[T]` to compare values here 
because `Ordering[T]`
  * brings significant performance penalty.
  */
-private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends 
Serializable {
-  /**
-   * Closed lower bound of this column.
-   */
-  def lowerBound: JvmType
-
-  /**
-   * Closed upper bound of this column.
-   */
-  def upperBound: JvmType
-
+private[sql] sealed trait ColumnStats extends Serializable {
   /**
    * Gathers statistics information from `row(ordinal)`.
    */
-  def gatherStats(row: Row, ordinal: Int)
-
-  /**
-   * Returns `true` if `lower <= row(ordinal) <= upper`.
-   */
-  def contains(row: Row, ordinal: Int): Boolean
+  def gatherStats(row: Row, ordinal: Int): Unit
 
   /**
-   * Returns `true` if `row(ordinal) < upper` holds.
+   * Column statistics represented as a single row, currently including closed 
lower bound, closed
+   * upper bound and null count.
    */
-  def isAbove(row: Row, ordinal: Int): Boolean
-
-  /**
-   * Returns `true` if `lower < row(ordinal)` holds.
-   */
-  def isBelow(row: Row, ordinal: Int): Boolean
-
-  /**
-   * Returns `true` if `row(ordinal) <= upper` holds.
-   */
-  def isAtOrAbove(row: Row, ordinal: Int): Boolean
-
-  /**
-   * Returns `true` if `lower <= row(ordinal)` holds.
-   */
-  def isAtOrBelow(row: Row, ordinal: Int): Boolean
-}
-
-private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
-  extends ColumnStats[T, T#JvmType] {
-
-  type JvmType = T#JvmType
-
-  protected var (_lower, _upper) = initialBounds
-
-  def initialBounds: (JvmType, JvmType)
-
-  protected def columnType: NativeColumnType[T]
-
-  override def lowerBound: T#JvmType = _lower
-
-  override def upperBound: T#JvmType = _upper
-
-  override def isAtOrAbove(row: Row, ordinal: Int) = {
-    contains(row, ordinal) || isAbove(row, ordinal)
-  }
-
-  override def isAtOrBelow(row: Row, ordinal: Int) = {
-    contains(row, ordinal) || isBelow(row, ordinal)
-  }
+  def collectedStatistics: Row
 }
 
-private[sql] class NoopColumnStats[T <: DataType, JvmType] extends 
ColumnStats[T, JvmType] {
-  override def isAtOrBelow(row: Row, ordinal: Int) = true
-
-  override def isAtOrAbove(row: Row, ordinal: Int) = true
-
-  override def isBelow(row: Row, ordinal: Int) = true
-
-  override def isAbove(row: Row, ordinal: Int) = true
+private[sql] class NoopColumnStats extends ColumnStats {
 
-  override def contains(row: Row, ordinal: Int) = true
+  override def gatherStats(row: Row, ordinal: Int): Unit = {}
 
-  override def gatherStats(row: Row, ordinal: Int) {}
-
-  override def upperBound = null.asInstanceOf[JvmType]
-
-  override def lowerBound = null.asInstanceOf[JvmType]
+  override def collectedStatistics = Row()
 }
 
-private[sql] abstract class BasicColumnStats[T <: NativeType](
-    protected val columnType: NativeColumnType[T])
-  extends NativeColumnStats[T]
-
-private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
-  override def initialBounds = (true, false)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class ByteColumnStats extends ColumnStats {
+  var upper = Byte.MinValue
+  var lower = Byte.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
-}
-
-private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
-  override def initialBounds = (Byte.MaxValue, Byte.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getByte(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
-  override def initialBounds = (Short.MaxValue, Short.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class ShortColumnStats extends ColumnStats {
+  var upper = Short.MinValue
+  var lower = Short.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
-}
-
-private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
-  override def initialBounds = (Long.MaxValue, Long.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getShort(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
-  override def initialBounds = (Double.MaxValue, Double.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class LongColumnStats extends ColumnStats {
+  var upper = Long.MinValue
+  var lower = Long.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-  }
-}
-
-private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
-  override def initialBounds = (Float.MaxValue, Float.MinValue)
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getLong(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
+}
 
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class DoubleColumnStats extends ColumnStats {
+  var upper = Double.MinValue
+  var lower = Double.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getDouble(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
-}
 
-private[sql] object IntColumnStats {
-  val UNINITIALIZED = 0
-  val INITIALIZED = 1
-  val ASCENDING = 2
-  val DESCENDING = 3
-  val UNORDERED = 4
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-/**
- * Statistical information for `Int` columns. More information is collected 
since `Int` is
- * frequently used. Extra information include:
- *
- * - Ordering state (ascending/descending/unordered), may be used to decide 
whether binary search
- *   is applicable when searching elements.
- * - Maximum delta between adjacent elements, may be used to guide the 
`IntDelta` compression
- *   scheme.
- *
- * (This two kinds of information are not used anywhere yet and might be 
removed later.)
- */
-private[sql] class IntColumnStats extends BasicColumnStats(INT) {
-  import IntColumnStats._
-
-  private var orderedState = UNINITIALIZED
-  private var lastValue: Int = _
-  private var _maxDelta: Int = _
-
-  def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
-  def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
-  def isOrdered = isAscending || isDescending
-  def maxDelta = _maxDelta
-
-  override def initialBounds = (Int.MaxValue, Int.MinValue)
+private[sql] class FloatColumnStats extends ColumnStats {
+  var upper = Float.MinValue
+  var lower = Float.MaxValue
+  var nullCount = 0
 
-  override def isBelow(row: Row, ordinal: Int) = {
-    lowerBound < columnType.getField(row, ordinal)
+  override def gatherStats(row: Row, ordinal: Int) {
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getFloat(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
+    }
   }
 
-  override def isAbove(row: Row, ordinal: Int) = {
-    columnType.getField(row, ordinal) < upperBound
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
+}
 
-  override def contains(row: Row, ordinal: Int) = {
-    val field = columnType.getField(row, ordinal)
-    lowerBound <= field && field <= upperBound
-  }
+private[sql] class IntColumnStats extends ColumnStats {
+  var upper = Int.MinValue
+  var lower = Int.MaxValue
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-
-    if (field > upperBound) _upper = field
-    if (field < lowerBound) _lower = field
-
-    orderedState = orderedState match {
-      case UNINITIALIZED =>
-        lastValue = field
-        INITIALIZED
-
-      case INITIALIZED =>
-        // If all the integers in the column are the same, ordered state is 
set to Ascending.
-        // TODO (lian) Confirm whether this is the standard behaviour.
-        val nextState = if (field >= lastValue) ASCENDING else DESCENDING
-        _maxDelta = math.abs(field - lastValue)
-        lastValue = field
-        nextState
-
-      case ASCENDING if field < lastValue =>
-        UNORDERED
-
-      case DESCENDING if field > lastValue =>
-        UNORDERED
-
-      case state @ (ASCENDING | DESCENDING) =>
-        _maxDelta = _maxDelta.max(field - lastValue)
-        lastValue = field
-        state
-
-      case _ =>
-        orderedState
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getInt(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+    } else {
+      nullCount += 1
     }
   }
+
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
-  override def initialBounds = (null, null)
+private[sql] class StringColumnStats extends ColumnStats {
+  var upper: String = null
+  var lower: String = null
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
-    if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    (upperBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
-    }
-  }
-
-  override def isAbove(row: Row, ordinal: Int) = {
-    (upperBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      field.compareTo(upperBound) < 0
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getString(ordinal)
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+    } else {
+      nullCount += 1
     }
   }
 
-  override def isBelow(row: Row, ordinal: Int) = {
-    (lowerBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) < 0
-    }
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }
 
-private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) {
-  override def initialBounds = (null, null)
+private[sql] class TimestampColumnStats extends ColumnStats {
+  var upper: Timestamp = null
+  var lower: Timestamp = null
+  var nullCount = 0
 
   override def gatherStats(row: Row, ordinal: Int) {
-    val field = columnType.getField(row, ordinal)
-    if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
-    if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
-  }
-
-  override def contains(row: Row, ordinal: Int) = {
-    (upperBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
+    if (!row.isNullAt(ordinal)) {
+      val value = row(ordinal).asInstanceOf[Timestamp]
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+    } else {
+      nullCount += 1
     }
   }
 
-  override def isAbove(row: Row, ordinal: Int) = {
-    (lowerBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      field.compareTo(upperBound) < 0
-    }
-  }
-
-  override def isBelow(row: Row, ordinal: Int) = {
-    (lowerBound ne null) && {
-      val field = columnType.getField(row, ordinal)
-      lowerBound.compareTo(field) < 0
-    }
-  }
+  def collectedStatistics = Row(lower, upper, nullCount)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index cb055cd..dc668e7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.columnar
 
 import java.nio.ByteBuffer
 
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
 
@@ -31,23 +33,27 @@ object InMemoryRelation {
     new InMemoryRelation(child.output, useCompression, batchSize, child)()
 }
 
+private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
+
 private[sql] case class InMemoryRelation(
     output: Seq[Attribute],
     useCompression: Boolean,
     batchSize: Int,
     child: SparkPlan)
-    (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null)
+    (private var _cachedColumnBuffers: RDD[CachedBatch] = null)
   extends LogicalPlan with MultiInstanceRelation {
 
   override lazy val statistics =
     Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
 
+  val partitionStatistics = new PartitionStatistics(output)
+
   // If the cached column buffers were not passed in, we calculate them in the 
constructor.
   // As in Spark, the actual work of caching is lazy.
   if (_cachedColumnBuffers == null) {
     val output = child.output
     val cached = child.execute().mapPartitions { baseIterator =>
-      new Iterator[Array[ByteBuffer]] {
+      new Iterator[CachedBatch] {
         def next() = {
           val columnBuilders = output.map { attribute =>
             val columnType = ColumnType(attribute.dataType)
@@ -68,7 +74,10 @@ private[sql] case class InMemoryRelation(
             rowCount += 1
           }
 
-          columnBuilders.map(_.build())
+          val stats = Row.fromSeq(
+            
columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_
 ++ _))
+
+          CachedBatch(columnBuilders.map(_.build()), stats)
         }
 
         def hasNext = baseIterator.hasNext
@@ -79,7 +88,6 @@ private[sql] case class InMemoryRelation(
     _cachedColumnBuffers = cached
   }
 
-
   override def children = Seq.empty
 
   override def newInstance() = {
@@ -96,13 +104,98 @@ private[sql] case class InMemoryRelation(
 
 private[sql] case class InMemoryColumnarTableScan(
     attributes: Seq[Attribute],
+    predicates: Seq[Expression],
     relation: InMemoryRelation)
   extends LeafNode {
 
+  @transient override val sqlContext = relation.child.sqlContext
+
   override def output: Seq[Attribute] = attributes
 
+  // Returned filter predicate should return false iff it is impossible for the input expression
+  // to evaluate to `true' based on statistics collected about this partition batch.
+  val buildFilter: PartialFunction[Expression, Expression] = {
+    case And(lhs: Expression, rhs: Expression)
+      if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+      buildFilter(lhs) && buildFilter(rhs)
+
+    case Or(lhs: Expression, rhs: Expression)
+      if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+      buildFilter(lhs) || buildFilter(rhs)
+
+    case EqualTo(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l && l <= aStats.upperBound
+
+    case EqualTo(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l && l <= aStats.upperBound
+
+    case LessThan(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound < l
+
+    case LessThan(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l < aStats.upperBound
+
+    case LessThanOrEqual(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l
+
+    case LessThanOrEqual(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l <= aStats.upperBound
+
+    case GreaterThan(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l < aStats.upperBound
+
+    case GreaterThan(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound < l
+
+    case GreaterThanOrEqual(a: AttributeReference, l: Literal) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      l <= aStats.upperBound
+
+    case GreaterThanOrEqual(l: Literal, a: AttributeReference) =>
+      val aStats = relation.partitionStatistics.forAttribute(a)
+      aStats.lowerBound <= l
+  }
+
+  val partitionFilters = {
+    predicates.flatMap { p =>
+      val filter = buildFilter.lift(p)
+      val boundFilter =
+        filter.map(
+          BindReferences.bindReference(
+            _,
+            relation.partitionStatistics.schema,
+            allowFailures = true))
+
+      boundFilter.foreach(_ =>
+        filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f")))
+
+      // If the filter can't be resolved then we are missing required statistics.
+      boundFilter.filter(_.resolved)
+    }
+  }
+
+  val readPartitions = sparkContext.accumulator(0)
+  val readBatches = sparkContext.accumulator(0)
+
+  private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning
+
   override def execute() = {
+    readPartitions.setValue(0)
+    readBatches.setValue(0)
+
     relation.cachedColumnBuffers.mapPartitions { iterator =>
+      val partitionFilter = newPredicate(
+        partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+        relation.partitionStatistics.schema)
+
       // Find the ordinals of the requested columns.  If none are requested, use the first.
       val requestedColumns = if (attributes.isEmpty) {
         Seq(0)
@@ -110,8 +203,26 @@ private[sql] case class InMemoryColumnarTableScan(
         attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
       }
 
-      iterator
-        .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_)))
+      val rows = iterator
+        // Skip pruned batches
+        .filter { cachedBatch =>
+          if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) {
+            def statsString = relation.partitionStatistics.schema
+              .zip(cachedBatch.stats)
+              .map { case (a, s) => s"${a.name}: $s" }
+              .mkString(", ")
+            logInfo(s"Skipping partition based on stats $statsString")
+            false
+          } else {
+            readBatches += 1
+            true
+          }
+        }
+        // Build column accessors
+        .map { cachedBatch =>
+          requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+        }
+        // Extract rows via column accessors
         .flatMap { columnAccessors =>
           val nextRow = new GenericMutableRow(columnAccessors.length)
           new Iterator[Row] {
@@ -127,6 +238,12 @@ private[sql] case class InMemoryColumnarTableScan(
             override def hasNext = columnAccessors.head.hasNext
           }
         }
+
+      if (rows.hasNext) {
+        readPartitions += 1
+      }
+
+      rows
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index f631ee7..a72970e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -49,6 +49,7 @@ private[sql] trait NullableColumnBuilder extends 
ColumnBuilder {
   }
 
   abstract override def appendFrom(row: Row, ordinal: Int) {
+    columnStats.gatherStats(row, ordinal)
     if (row.isNullAt(ordinal)) {
       nulls = ColumnBuilder.ensureFreeSpace(nulls, 4)
       nulls.putInt(pos)

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8dacb84..7943d6e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -243,8 +243,8 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         pruneFilterProject(
           projectList,
           filters,
-          identity[Seq[Expression]], // No filters are pushed down.
-          InMemoryColumnarTableScan(_, mem)) :: Nil
+          identity[Seq[Expression]], // All filters still need to be evaluated.
+          InMemoryColumnarTableScan(_,  filters, mem)) :: Nil
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 5f61fb5..cde91ce 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -19,29 +19,30 @@ package org.apache.spark.sql.columnar
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.catalyst.types._
 
 class ColumnStatsSuite extends FunSuite {
-  testColumnStats(classOf[BooleanColumnStats],   BOOLEAN)
-  testColumnStats(classOf[ByteColumnStats],      BYTE)
-  testColumnStats(classOf[ShortColumnStats],     SHORT)
-  testColumnStats(classOf[IntColumnStats],       INT)
-  testColumnStats(classOf[LongColumnStats],      LONG)
-  testColumnStats(classOf[FloatColumnStats],     FLOAT)
-  testColumnStats(classOf[DoubleColumnStats],    DOUBLE)
-  testColumnStats(classOf[StringColumnStats],    STRING)
-  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP)
-
-  def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
+  testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0))
+  testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0))
+  testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0))
+  testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
+  testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
+  testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+  testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
+  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
+
+  def testColumnStats[T <: NativeType, U <: ColumnStats](
       columnStatsClass: Class[U],
-      columnType: NativeColumnType[T]) {
+      columnType: NativeColumnType[T],
+      initialStatistics: Row) {
 
     val columnStatsName = columnStatsClass.getSimpleName
 
     test(s"$columnStatsName: empty") {
       val columnStats = columnStatsClass.newInstance()
-      assertResult(columnStats.initialBounds, "Wrong initial bounds") {
-        (columnStats.lowerBound, columnStats.upperBound)
+      columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) =>
+        assert(actual === expected)
       }
     }
 
@@ -49,14 +50,16 @@ class ColumnStatsSuite extends FunSuite {
       import ColumnarTestUtils._
 
       val columnStats = columnStatsClass.newInstance()
-      val rows = Seq.fill(10)(makeRandomRow(columnType))
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
       rows.foreach(columnStats.gatherStats(_, 0))
 
-      val values = rows.map(_.head.asInstanceOf[T#JvmType])
+      val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType])
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+      val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
-      assertResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
+      assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+      assertResult(10, "Wrong null count")(stats(2))
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index dc813fe..a772625 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.execution.SparkSqlSerializer
 
 class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
-  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
   with NullableColumnBuilder
 
 object TestNullableColumnBuilder {

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
new file mode 100644
index 0000000..5d2fd49
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class IntegerData(i: Int)
+
+class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
+  val originalColumnBatchSize = columnBatchSize
+  val originalInMemoryPartitionPruning = inMemoryPartitionPruning
+
+  override protected def beforeAll() {
+    // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
+    setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
+    val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
+    rawData.registerTempTable("intData")
+
+    // Enable in-memory partition pruning
+    setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+  }
+
+  override protected def afterAll() {
+    setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+    setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+  }
+
+  before {
+    cacheTable("intData")
+  }
+
+  after {
+    uncacheTable("intData")
+  }
+
+  // Comparisons
+  checkBatchPruning("i = 1", Seq(1), 1, 1)
+  checkBatchPruning("1 = i", Seq(1), 1, 1)
+  checkBatchPruning("i < 12", 1 to 11, 1, 2)
+  checkBatchPruning("i <= 11", 1 to 11, 1, 2)
+  checkBatchPruning("i > 88", 89 to 100, 1, 2)
+  checkBatchPruning("i >= 89", 89 to 100, 1, 2)
+  checkBatchPruning("12 > i", 1 to 11, 1, 2)
+  checkBatchPruning("11 >= i", 1 to 11, 1, 2)
+  checkBatchPruning("88 < i", 89 to 100, 1, 2)
+  checkBatchPruning("89 <= i", 89 to 100, 1, 2)
+
+  // Conjunction and disjunction
+  checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
+  checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
+  checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
+
+  // With unsupported predicate
+  checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
+  checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10)
+
+  def checkBatchPruning(
+      filter: String,
+      expectedQueryResult: Seq[Int],
+      expectedReadPartitions: Int,
+      expectedReadBatches: Int) {
+
+    test(filter) {
+      val query = sql(s"SELECT * FROM intData WHERE $filter")
+      assertResult(expectedQueryResult.toArray, "Wrong query result") {
+        query.collect().map(_.head).toArray
+      }
+
+      val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
+        case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
+      }.head
+
+      assert(readBatches === expectedReadBatches, "Wrong number of read batches")
+      assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 5fba004..e01cc8b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression
 import org.scalatest.FunSuite
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats}
+import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 
 class BooleanBitSetSuite extends FunSuite {
@@ -31,7 +31,7 @@ class BooleanBitSetSuite extends FunSuite {
     // Tests encoder
     // -------------
 
-    val builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet)
+    val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
     val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN))
     val values = rows.map(_.head)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index d8ae2a2..d2969d9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -31,7 +31,7 @@ class DictionaryEncodingSuite extends FunSuite {
   testDictionaryEncoding(new StringColumnStats, STRING)
 
   def testDictionaryEncoding[T <: NativeType](
-      columnStats: NativeColumnStats[T],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[T]) {
 
     val typeName = columnType.getClass.getSimpleName.stripSuffix("$")

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 17619dc..322f447 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -29,7 +29,7 @@ class IntegralDeltaSuite extends FunSuite {
   testIntegralDelta(new LongColumnStats, LONG, LongDelta)
 
   def testIntegralDelta[I <: IntegralType](
-      columnStats: NativeColumnStats[I],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[I],
       scheme: IntegralDelta[I]) {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 40115be..218c09a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 
 class RunLengthEncodingSuite extends FunSuite {
-  testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
+  testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
   testRunLengthEncoding(new ByteColumnStats,    BYTE)
   testRunLengthEncoding(new ShortColumnStats,   SHORT)
   testRunLengthEncoding(new IntColumnStats,     INT)
@@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite {
   testRunLengthEncoding(new StringColumnStats,  STRING)
 
   def testRunLengthEncoding[T <: NativeType](
-      columnStats: NativeColumnStats[T],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[T]) {
 
     val typeName = columnType.getClass.getSimpleName.stripSuffix("$")

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index 72c19fa..7db723d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.types.NativeType
 import org.apache.spark.sql.columnar._
 
 class TestCompressibleColumnBuilder[T <: NativeType](
-    override val columnStats: NativeColumnStats[T],
+    override val columnStats: ColumnStats,
     override val columnType: NativeColumnType[T],
     override val schemes: Seq[CompressionScheme])
   extends NativeColumnBuilder(columnStats, columnType)
@@ -33,7 +33,7 @@ class TestCompressibleColumnBuilder[T <: NativeType](
 
 object TestCompressibleColumnBuilder {
   def apply[T <: NativeType](
-      columnStats: NativeColumnStats[T],
+      columnStats: ColumnStats,
       columnType: NativeColumnType[T],
       scheme: CompressionScheme) = {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/248067ad/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index b589994..ab487d6 100644
--- 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -35,26 +35,29 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
 
   private val originalTimeZone = TimeZone.getDefault
   private val originalLocale = Locale.getDefault
-  private val originalUseCompression = TestHive.useCompression
+  private val originalColumnBatchSize = TestHive.columnBatchSize
+  private val originalInMemoryPartitionPruning = TestHive.inMemoryPartitionPruning
 
   def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
 
   override def beforeAll() {
-    // Enable in-memory columnar caching
     TestHive.cacheTables = true
     // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
     TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
     // Add Locale setting
     Locale.setDefault(Locale.US)
-    // Enable in-memory columnar compression
-    TestHive.setConf(SQLConf.COMPRESS_CACHED, "true")
+    // Set a relatively small column batch size for testing purposes
+    TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5")
+    // Enable in-memory partition pruning for testing purposes
+    TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
   }
 
   override def afterAll() {
     TestHive.cacheTables = false
     TimeZone.setDefault(originalTimeZone)
     Locale.setDefault(originalLocale)
-    TestHive.setConf(SQLConf.COMPRESS_CACHED, originalUseCompression.toString)
+    TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+    TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
   }
 
   /** A list of tests deemed out of scope currently and thus completely disregarded. */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to