Repository: spark
Updated Branches:
  refs/heads/master b2e4b314d -> ebf8b0b48


[SPARK-10978][SQL] Allow data sources to eliminate filters

This PR adds a new method `unhandledFilters` to `BaseRelation`. Data sources
that implement this method properly can avoid the overhead of the defensive
filtering that Spark SQL would otherwise apply to the rows they return.
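
As a rough sketch of how a data source can take advantage of the new hook (the
relation name, schema, and the policy of handling only `GreaterThan` filters are
invented for this illustration; `SimpleTextRelation` in the diff below follows
the same pattern):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

// Hypothetical relation that evaluates `GreaterThan` filters itself, so Spark SQL
// no longer needs to re-apply them to the rows this scan returns.
class ExampleRelation(val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan {

  override def schema: StructType =
    StructType(StructField("a", IntegerType, nullable = false) :: Nil)

  // Everything except `GreaterThan` is declared unhandled; Spark SQL keeps a
  // defensive Filter operator for exactly those.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filterNot(_.isInstanceOf[GreaterThan])

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Honor the handled `GreaterThan` filters while producing rows (sketch only).
    val lowerBounds = filters.collect { case GreaterThan("a", v: Int) => v }
    val lo = if (lowerBounds.isEmpty) Int.MinValue else lowerBounds.max
    sqlContext.sparkContext.parallelize(1 to 10).filter(_ > lo).map(Row(_))
  }
}

With such a relation, Spark SQL only re-checks the filters returned by
`unhandledFilters` after the scan.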

Author: Cheng Lian <l...@databricks.com>

Closes #9399 from liancheng/spark-10978.unhandled-filters.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ebf8b0b4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ebf8b0b4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ebf8b0b4

Branch: refs/heads/master
Commit: ebf8b0b48deaad64f7ca27051caee763451e2623
Parents: b2e4b31
Author: Cheng Lian <l...@databricks.com>
Authored: Tue Nov 3 10:07:45 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Nov 3 10:07:45 2015 -0800

----------------------------------------------------------------------
 .../datasources/DataSourceStrategy.scala        | 131 +++++++++++++++----
 .../apache/spark/sql/sources/interfaces.scala   |   9 ++
 .../parquet/ParquetFilterSuite.scala            |   2 +-
 .../spark/sql/sources/FilteredScanSuite.scala   | 129 +++++++++++++-----
 .../SimpleTextHadoopFsRelationSuite.scala       |  47 ++++++-
 .../spark/sql/sources/SimpleTextRelation.scala  |  65 ++++++++-
 6 files changed, 315 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ebf8b0b4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 6585986..7265d6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -43,7 +43,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         l,
         projects,
         filters,
-        (a, f) => toCatalystRDD(l, a, t.buildScan(a, f))) :: Nil
+        (requestedColumns, allPredicates, _) =>
+          toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil
 
     case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _)) =>
       pruneFilterProject(
@@ -266,47 +267,81 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       relation,
       projects,
       filterPredicates,
-      (requestedColumns, pushedFilters) => {
-        scanBuilder(requestedColumns, selectFilters(pushedFilters).toArray)
+      (requestedColumns, _, pushedFilters) => {
+        scanBuilder(requestedColumns, pushedFilters.toArray)
       })
   }
 
-  // Based on Catalyst expressions.
+  // Based on Catalyst expressions. The `scanBuilder` function accepts three arguments:
+  //
+  //  1. A `Seq[Attribute]`, containing all required column attributes. Used to handle relation
+  //     traits that support column pruning (e.g. `PrunedScan` and `PrunedFilteredScan`).
+  //
+  //  2. A `Seq[Expression]`, containing all gathered Catalyst filter expressions, only used for
+  //     `CatalystScan`.
+  //
+  //  3. A `Seq[Filter]`, containing all data source `Filter`s that are converted from (possibly a
+  //     subset of) Catalyst filter expressions and can be handled by `relation`.  Used to handle
+  //     relation traits (`CatalystScan` excluded) that support filter push-down (e.g.
+  //     `PrunedFilteredScan` and `HadoopFsRelation`).
+  //
+  // Note that 2 and 3 shouldn't be used together.
   protected def pruneFilterProjectRaw(
-      relation: LogicalRelation,
-      projects: Seq[NamedExpression],
-      filterPredicates: Seq[Expression],
-      scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[InternalRow]) = {
+    relation: LogicalRelation,
+    projects: Seq[NamedExpression],
+    filterPredicates: Seq[Expression],
+    scanBuilder: (Seq[Attribute], Seq[Expression], Seq[Filter]) => RDD[InternalRow]) = {
 
     val projectSet = AttributeSet(projects.flatMap(_.references))
     val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
-    val filterCondition = filterPredicates.reduceLeftOption(expressions.And)
 
-    val pushedFilters = filterPredicates.map { _ transform {
+    val candidatePredicates = filterPredicates.map { _ transform {
       case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
     }}
 
+    val (unhandledPredicates, pushedFilters) =
+      selectFilters(relation.relation, candidatePredicates)
+
+    // A set of column attributes that are only referenced by pushed down filters.  We can eliminate
+    // them from requested columns.
+    val handledSet = {
+      val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains)
+      val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references))
+      AttributeSet(handledPredicates.flatMap(_.references)) --
+        (projectSet ++ unhandledSet).map(relation.attributeMap)
+    }
+
+    // Combines all Catalyst filter `Expression`s that are either not convertible to data source
+    // `Filter`s or cannot be handled by `relation`.
+    val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And)
+
     if (projects.map(_.toAttribute) == projects &&
         projectSet.size == projects.size &&
         filterSet.subsetOf(projectSet)) {
       // When it is possible to just use column pruning to get the right projection and
       // when the columns of this projection are enough to evaluate all filter conditions,
       // just do a scan followed by a filter, with no extra project.
-      val requestedColumns =
-        projects.asInstanceOf[Seq[Attribute]] // Safe due to if above.
-          .map(relation.attributeMap)            // Match original case of attributes.
+      val requestedColumns = projects
+        // Safe due to if above.
+        .asInstanceOf[Seq[Attribute]]
+        // Match original case of attributes.
+        .map(relation.attributeMap)
+        // Don't request columns that are only referenced by pushed filters.
+        .filterNot(handledSet.contains)
 
       val scan = execution.PhysicalRDD.createFromDataSource(
         projects.map(_.toAttribute),
-        scanBuilder(requestedColumns, pushedFilters),
+        scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation)
       filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
     } else {
-      val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
+      // Don't request columns that are only referenced by pushed filters.
+      val requestedColumns =
+        (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq
 
       val scan = execution.PhysicalRDD.createFromDataSource(
         requestedColumns,
-        scanBuilder(requestedColumns, pushedFilters),
+        scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation)
       execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
     }
@@ -334,11 +369,12 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
   }
 
   /**
-   * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s,
-   * and convert them.
+   * Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
+   *
+   * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
    */
-  protected[sql] def selectFilters(filters: Seq[Expression]) = {
-    def translate(predicate: Expression): Option[Filter] = predicate match {
+  protected[sql] def translateFilter(predicate: Expression): Option[Filter] = {
+    predicate match {
       case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
         Some(sources.EqualTo(a.name, convertToScala(v, t)))
       case expressions.EqualTo(Literal(v, t), a: Attribute) =>
@@ -387,16 +423,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         Some(sources.IsNotNull(a.name))
 
       case expressions.And(left, right) =>
-        (translate(left) ++ translate(right)).reduceOption(sources.And)
+        (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And)
 
       case expressions.Or(left, right) =>
         for {
-          leftFilter <- translate(left)
-          rightFilter <- translate(right)
+          leftFilter <- translateFilter(left)
+          rightFilter <- translateFilter(right)
         } yield sources.Or(leftFilter, rightFilter)
 
       case expressions.Not(child) =>
-        translate(child).map(sources.Not)
+        translateFilter(child).map(sources.Not)
 
       case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
         Some(sources.StringStartsWith(a.name, v.toString))
@@ -409,7 +445,52 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
 
       case _ => None
     }
+  }
+
+  /**
+   * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
+   * and can be handled by `relation`.
+   *
+   * @return A pair of `Seq[Expression]` and `Seq[Filter]`. The first element contains all Catalyst
+   *         predicate [[Expression]]s that are either not convertible or cannot be handled by
+   *         `relation`. The second element contains all converted data source [[Filter]]s that can
+   *        be handled by `relation`.
+   */
+  protected[sql] def selectFilters(
+    relation: BaseRelation,
+    predicates: Seq[Expression]): (Seq[Expression], Seq[Filter]) = {
+
+    // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
+    // called `predicate`s, while all data source filters of type `sources.Filter` are simply called
+    // `filter`s.
+
+    val translated: Seq[(Expression, Filter)] =
+      for {
+        predicate <- predicates
+        filter <- translateFilter(predicate)
+      } yield predicate -> filter
+
+    // A map from original Catalyst expressions to corresponding translated data source filters.
+    val translatedMap: Map[Expression, Filter] = translated.toMap
+
+    // Catalyst predicate expressions that cannot be translated to data source filters.
+    val unrecognizedPredicates = predicates.filterNot(translatedMap.contains)
+
+    // Data source filters that cannot be handled by `relation`
+    val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet
+
+    val (unhandled, handled) = translated.partition {
+      case (predicate, filter) =>
+        unhandledFilters.contains(filter)
+    }
+
+    // Catalyst predicate expressions that can be translated to data source filters, but cannot be
+    // handled by `relation`.
+    val (unhandledPredicates, _) = unhandled.unzip
+
+    // Translated data source filters that can be handled by `relation`
+    val (_, handledFilters) = handled.unzip
 
-    filters.flatMap(translate)
+    (unrecognizedPredicates ++ unhandledPredicates, handledFilters)
   }
 }
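
To make the `selectFilters` contract above concrete, here is a simplified,
self-contained sketch (the toy relation policy and the string stand-ins for
Catalyst predicates are assumptions for illustration only): convertible
predicates are paired with their translated filters, the relation reports which
translated filters it cannot handle, and everything unconvertible or unhandled
is kept for the post-scan Filter operator while the rest is pushed down.

import org.apache.spark.sql.sources._

object SelectFiltersSketch extends App {
  // Toy stand-in for a relation's `unhandledFilters`: decline everything that
  // is not a `GreaterThan`.
  def unhandledByRelation(filters: Seq[Filter]): Set[Filter] =
    filters.filterNot(_.isInstanceOf[GreaterThan]).toSet

  // Pretend these pairs came out of `translateFilter`; a predicate such as a
  // UDF call would not appear here at all because it is not convertible.
  val translated: Seq[(String, Filter)] = Seq(
    "a > 1"         -> GreaterThan("a", 1),
    "b = 2"         -> EqualTo("b", 2),
    "c IS NOT NULL" -> IsNotNull("c"))

  val unhandled = unhandledByRelation(translated.map(_._2))
  val (kept, pushed) = translated.partition { case (_, f) => unhandled.contains(f) }

  println(kept.map(_._1))   // Seq("b = 2", "c IS NOT NULL") -- re-checked by Spark SQL after the scan
  println(pushed.map(_._2)) // Seq(GreaterThan("a", 1))       -- handed to the data source
}

In the real code path, `pruneFilterProjectRaw` additionally drops columns that
are referenced only by the pushed filters from the requested columns.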

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf8b0b4/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 7a55351..e296d63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -233,6 +233,15 @@ abstract class BaseRelation {
    * @since 1.4.0
    */
   def needConversion: Boolean = true
+
+  /**
+   * Given an array of [[Filter]]s, returns an array of [[Filter]]s that this data source relation
+   * cannot handle.  Spark SQL will apply all returned [[Filter]]s against rows returned by this
+   * data source relation.
+   *
+   * @since 1.6.0
+   */
+  def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
 }
 
 /**
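
The default implementation above returns every filter, which preserves the old
behavior: Spark SQL assumes nothing was handled and re-applies all pushed
filters itself. Below is a hedged sketch of a relation opting in by declining
only filters that touch a particular column, mirroring `SimpleFilteredScan` in
the test suite further down (the names are hypothetical, and a real
implementation would also mix in a scan trait such as `PrunedFilteredScan`):

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical relation that can evaluate filters on column "a" but not on "b".
def columnLimitedRelation(ctx: SQLContext): BaseRelation = new BaseRelation {
  override val sqlContext: SQLContext = ctx
  override val schema: StructType =
    StructType(StructField("a", IntegerType) :: StructField("b", IntegerType) :: Nil)

  // Report anything touching "b" as unhandled; Spark SQL re-applies those
  // filters after the scan, while filters on "a" are trusted to the source.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filter {
      case EqualTo("b", _)     => true
      case GreaterThan("b", _) => true
      case LessThan("b", _)    => true
      case _                   => false
    }
}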

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf8b0b4/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index f88ddc7..c24c9f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -59,7 +59,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       }.flatten
       assert(analyzedPredicate.nonEmpty)
 
-      val selectedFilters = DataSourceStrategy.selectFilters(analyzedPredicate)
+      val selectedFilters = analyzedPredicate.flatMap(DataSourceStrategy.translateFilter)
       assert(selectedFilters.nonEmpty)
 
       selectedFilters.foreach { pred =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf8b0b4/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 68ce37c..7541e72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.sources
 
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+
 import scala.language.existentials
 
 import org.apache.spark.rdd.RDD
@@ -44,16 +46,39 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
       StructField("b", IntegerType, nullable = false) ::
       StructField("c", StringType, nullable = false) :: Nil)
 
+  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+    def unhandled(filter: Filter): Boolean = {
+      filter match {
+        case EqualTo(col, v) => col == "b"
+        case EqualNullSafe(col, v) => col == "b"
+        case LessThan(col, v: Int) => col == "b"
+        case LessThanOrEqual(col, v: Int) => col == "b"
+        case GreaterThan(col, v: Int) => col == "b"
+        case GreaterThanOrEqual(col, v: Int) => col == "b"
+        case In(col, values) => col == "b"
+        case IsNull(col) => col == "b"
+        case IsNotNull(col) => col == "b"
+        case Not(pred) => unhandled(pred)
+        case And(left, right) => unhandled(left) || unhandled(right)
+        case Or(left, right) => unhandled(left) || unhandled(right)
+        case _ => false
+      }
+    }
+
+    filters.filter(unhandled)
+  }
+
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
     val rowBuilders = requiredColumns.map {
       case "a" => (i: Int) => Seq(i)
       case "b" => (i: Int) => Seq(i * 2)
       case "c" => (i: Int) =>
         val c = (i - 1 + 'a').toChar.toString
-        Seq(c * 5 + c.toUpperCase() * 5)
+        Seq(c * 5 + c.toUpperCase * 5)
     }
 
     FiltersPushed.list = filters
+    ColumnsRequired.set = requiredColumns.toSet
 
     // Predicate test on integer column
     def translateFilterOnA(filter: Filter): Int => Boolean = filter match {
@@ -86,9 +111,8 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
     }
 
     def eval(a: Int) = {
-      val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5
-      !filters.map(translateFilterOnA(_)(a)).contains(false) &&
-        !filters.map(translateFilterOnC(_)(c)).contains(false)
+      val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5
+      filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c))
     }
 
     sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i =>
@@ -101,6 +125,11 @@ object FiltersPushed {
   var list: Seq[Filter] = Nil
 }
 
+// Used together with `SimpleFilteredScan` to check pushed columns.
+object ColumnsRequired {
+  var set: Set[String] = Set.empty
+}
+
 class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
   protected override lazy val sql = caseInsensitiveContext.sql _
 
@@ -115,12 +144,15 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
         |  to '10'
         |)
       """.stripMargin)
+
+    // UDF for testing filter push-down
+    caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3)
   }
 
   sqlTest(
     "SELECT * FROM oneToTenFiltered",
     (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5
-      + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq)
+      + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq)
 
   sqlTest(
     "SELECT a, b FROM oneToTenFiltered",
@@ -202,49 +234,64 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
     "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'",
     Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5)))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1, Set("a", "b", 
"c"))
+  testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1, Set("a"))
+  testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1, Set("b"))
+  testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1, Set("a", 
"b"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1, Set("a", "b", 
"c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1, Set("a", "b", 
"c"))
+
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9, Set("a", "b", 
"c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9, Set("a", "b", 
"c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9, Set("a", "b", 
"c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9, Set("a", "b", 
"c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0, Set("a", "b", 
"c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2, Set("a", "b", 
"c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0, Set("a", "b", 
"c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2, Set("a", "b", 
"c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8, 
Set("a", "b", "c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3, 
Set("a", "b", "c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0, Set("a", "b", 
"c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10, Set("a", "b", 
"c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3, 
Set("a", "b", "c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4, 
Set("a", "b", "c"))
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5, Set("a", 
"b", "c"))
 
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5)
+  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1, 
Set("a", "b", "c"))
+  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0, 
Set("a", "b", "c"))
 
-  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1)
-  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0)
+  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1, 
Set("a", "b", "c"))
+  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0, 
Set("a", "b", "c"))
 
-  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1)
-  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0)
+  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1, 
Set("a", "b", "c"))
+  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0, 
Set("a", "b", "c"))
 
-  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
-  testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)
+  testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, 
Set("c"))
+  testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 
'foo')", 1, Set("c"))
 
-  testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
-  testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 
'foo')", 1)
+  // Columns only referenced by UDF filter must be required, as UDF filters can't be pushed down.
+  testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c"))
 
-  def testPushDown(sqlString: String, expectedCount: Int): Unit = {
+  // A query with an unconvertible filter, an unhandled filter, and a handled filter.
+  testPushDown(
+    """SELECT a
+      |  FROM oneToTenFiltered
+      | WHERE udf_gt3(b)
+      |   AND b < 16
+      |   AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo')
+    """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b"))
+
+  def testPushDown(
+      sqlString: String,
+      expectedCount: Int,
+      requiredColumnNames: Set[String]): Unit = {
     test(s"PushDown Returns $expectedCount: $sqlString") {
       val queryExecution = sql(sqlString).queryExecution
       val rawPlan = queryExecution.executedPlan.collect {
@@ -254,6 +301,17 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
         case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
       }
       val rawCount = rawPlan.execute().count()
+      assert(ColumnsRequired.set === requiredColumnNames)
+
+      assert {
+        val table = caseInsensitiveContext.table("oneToTenFiltered")
+        val relation = table.queryExecution.logical.collectFirst {
+          case LogicalRelation(r, _) => r
+        }.get
+
+        // `relation` should be able to handle all pushed filters
+        relation.unhandledFilters(FiltersPushed.list.toArray).isEmpty
+      }
 
       if (rawCount != expectedCount) {
         fail(
@@ -264,4 +322,3 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
     }
   }
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf8b0b4/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index a3a1244..d945408 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -18,11 +18,16 @@
 package org.apache.spark.sql.sources
 
 import org.apache.hadoop.fs.Path
-
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.execution.PhysicalRDD
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
 class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
+  import testImplicits._
+
   override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
 
   // We have a very limited number of supported types at here since it is just for a
@@ -64,4 +69,44 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
           .load(file.getCanonicalPath))
     }
   }
+
+  private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName)
+  private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName)
+
+  test("unhandledFilters") {
+    withTempPath { dir =>
+
+      val path = dir.getCanonicalPath
+      writer.save(s"$path/p=0")
+      writer.save(s"$path/p=1")
+
+      val isOdd = udf((_: Int) % 2 == 1)
+      val df = reader.load(path)
+        .filter(
+          // This filter is inconvertible
+          isOdd('a) &&
+            // This filter is convertible but unhandled
+            'a > 1 &&
+            // This filter is convertible and handled
+            'b > "val_1" &&
+            // This filter references a partition column, won't be pushed down
+            'p === 1
+        ).select('a, 'p)
+      val rawScan = df.queryExecution.executedPlan collect {
+        case p: PhysicalRDD => p
+      } match {
+        case Seq(p) => p
+      }
+
+      val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType)
+
+      assertResult(Set((2, 1), (3, 1))) {
+        rawScan.execute().collect()
+          .map { CatalystTypeConverters.convertToScala(_, outputSchema) }
+          .map { case Row(a, p) => (a, p) }.toSet
+      }
+
+      checkAnswer(df, Row(3, 1))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf8b0b4/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index aeaaa3e..da09e1b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.sources
 
 import java.text.NumberFormat
-import java.util.UUID
 
 import com.google.common.base.Objects
 import org.apache.hadoop.fs.{FileStatus, Path}
@@ -26,12 +25,12 @@ import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions}
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext, sources}
 
 /**
  * A simple example [[HadoopFsRelationProvider]].
@@ -124,6 +123,53 @@ class SimpleTextRelation(
     }
   }
 
+  override def buildScan(
+      requiredColumns: Array[String],
+      filters: Array[Filter],
+      inputFiles: Array[FileStatus]): RDD[Row] = {
+
+    val fields = this.dataSchema.map(_.dataType)
+    val inputAttributes = this.dataSchema.toAttributes
+    val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name))
+    val dataSchema = this.dataSchema
+
+    val inputPaths = inputFiles.map(_.getPath).mkString(",")
+    sparkContext.textFile(inputPaths).mapPartitions { iterator =>
+      // Constructs a filter predicate to simulate filter push-down
+      val predicate = {
+        val filterCondition: Expression = filters.collect {
+          // According to `unhandledFilters`, `SimpleTextRelation` only handles `GreaterThan` filter
+          case sources.GreaterThan(column, value) =>
+            val dataType = dataSchema(column).dataType
+            val literal = Literal.create(value, dataType)
+            val attribute = inputAttributes.find(_.name == column).get
+            expressions.GreaterThan(attribute, literal)
+        }.reduceOption(expressions.And).getOrElse(Literal(true))
+        InterpretedPredicate.create(filterCondition, inputAttributes)
+      }
+
+      // Uses a simple projection to simulate column pruning
+      val projection = new InterpretedMutableProjection(outputAttributes, inputAttributes)
+      val toScala = {
+        val requiredSchema = StructType.fromAttributes(outputAttributes)
+        CatalystTypeConverters.createToScalaConverter(requiredSchema)
+      }
+
+      iterator.map { record =>
+        new GenericInternalRow(record.split(",", -1).zip(fields).map {
+          case (v, dataType) =>
+            val value = if (v == "") null else v
+            // `Cast`ed values are always of internal types (e.g. UTF8String instead of String)
+            Cast(Literal(value), dataType).eval()
+        })
+      }.filter { row =>
+        predicate(row)
+      }.map { row =>
+        toScala(projection(row)).asInstanceOf[Row]
+      }
+    }
+  }
+
   override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
     job.setOutputFormatClass(classOf[TextOutputFormat[_, _]])
 
@@ -134,6 +180,15 @@ class SimpleTextRelation(
       new SimpleTextOutputWriter(path, context)
     }
   }
+
+  // `SimpleTextRelation` only handles `GreaterThan` filter.  This is used to test filter push-down
+  // and `BaseRelation.unhandledFilters()`.
+  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+    filters.filter {
+      case _: GreaterThan => false
+      case _ => true
+    }
+  }
 }
 
 /**

