Repository: spark
Updated Branches:
  refs/heads/master 574141a29 -> c048929c6


[SPARK-10978][SQL][FOLLOW-UP] More comprehensive tests for PR #9399

This PR adds test cases that exercise various column pruning and filter
push-down scenarios.
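
For context, the machinery these tests exercise is the data sources scan API: a
relation mixing in PrunedFilteredScan receives the pruned column list plus the
candidate filters, and can report back, via BaseRelation.unhandledFilters, which
of those filters it cannot evaluate itself. A minimal sketch of that contract,
assuming a hypothetical relation that only evaluates GreaterThan filters (the
class name below is illustrative, not part of Spark):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.sources.{BaseRelation, Filter, GreaterThan, PrunedFilteredScan}

    // Hypothetical relation sketch; only the two members below matter here.
    abstract class ExamplePushDownRelation extends BaseRelation with PrunedFilteredScan {
      // Filters returned here are re-evaluated by Spark on top of the scan;
      // anything not returned is assumed to be fully handled by the source.
      override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
        filters.filterNot(_.isInstanceOf[GreaterThan])

      // Only `requiredColumns` need to be produced; `filters` is a best-effort
      // hint and may be ignored without affecting correctness.
      def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
    }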

Author: Cheng Lian <l...@databricks.com>

Closes #9468 from liancheng/spark-10978.follow-up.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c048929c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c048929c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c048929c

Branch: refs/heads/master
Commit: c048929c6a9f7ce57f384037cd6c0bf5751c447a
Parents: 574141a
Author: Cheng Lian <l...@databricks.com>
Authored: Fri Nov 6 11:11:36 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Fri Nov 6 11:11:36 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/sources/FilteredScanSuite.scala   |  21 +-
 .../SimpleTextHadoopFsRelationSuite.scala       | 335 +++++++++++++++++--
 .../spark/sql/sources/SimpleTextRelation.scala  |  11 +
 3 files changed, 321 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c048929c/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 7541e72..2cad964 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -17,16 +17,15 @@
 
 package org.apache.spark.sql.sources
 
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-
 import scala.language.existentials
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
+import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-
+import org.apache.spark.unsafe.types.UTF8String
 
 class FilteredScanSource extends RelationProvider {
   override def createRelation(
@@ -130,7 +129,7 @@ object ColumnsRequired {
   var set: Set[String] = Set.empty
 }
 
-class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
+class FilteredScanSuite extends DataSourceTest with SharedSQLContext with PredicateHelper {
   protected override lazy val sql = caseInsensitiveContext.sql _
 
   override def beforeAll(): Unit = {
@@ -144,9 +143,6 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
         |  to '10'
         |)
       """.stripMargin)
-
-    // UDF for testing filter push-down
-    caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3)
   }
 
   sqlTest(
@@ -276,14 +272,15 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
  testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c"))
  testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c"))
 
-  // Columns only referenced by UDF filter must be required, as UDF filters can't be pushed down.
-  testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c"))
+  // Filters referencing multiple columns are not convertible, so all referenced columns must be
+  // required.
+  testPushDown("SELECT c FROM oneToTenFiltered WHERE A + b > 9", 10, Set("a", "b", "c"))
 
-  // A query with an unconvertible filter, an unhandled filter, and a handled filter.
+  // A query with an inconvertible filter, an unhandled filter, and a handled filter.
   testPushDown(
     """SELECT a
       |  FROM oneToTenFiltered
-      | WHERE udf_gt3(b)
+      | WHERE a + b > 9
       |   AND b < 16
       |   AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo')
     """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b"))

http://git-wip-us.apache.org/repos/asf/spark/blob/c048929c/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index d945408..9251a69 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -17,15 +17,21 @@
 
 package org.apache.spark.sql.sources
 
+import java.io.File
+
 import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.execution.PhysicalRDD
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, PredicateHelper}
+import org.apache.spark.sql.execution.{LogicalRDD, PhysicalRDD}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Column, DataFrame, Row, execution}
+import org.apache.spark.util.Utils
 
-class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
+class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper {
   import testImplicits._
 
  override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
@@ -70,43 +76,304 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
     }
   }
 
-  private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName)
-  private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName)
-
-  test("unhandledFilters") {
-    withTempPath { dir =>
-
-      val path = dir.getCanonicalPath
-      writer.save(s"$path/p=0")
-      writer.save(s"$path/p=1")
-
-      val isOdd = udf((_: Int) % 2 == 1)
-      val df = reader.load(path)
-        .filter(
-          // This filter is inconvertible
-          isOdd('a) &&
-            // This filter is convertible but unhandled
-            'a > 1 &&
-            // This filter is convertible and handled
-            'b > "val_1" &&
-            // This filter references a partiiton column, won't be pushed down
-            'p === 1
-        ).select('a, 'p)
-      val rawScan = df.queryExecution.executedPlan collect {
+  private var tempPath: File = _
+
+  private var partitionedDF: DataFrame = _
+
+  private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil)
+
+  protected override def beforeAll(): Unit = {
+    this.tempPath = Utils.createTempDir()
+
+    val df = sqlContext.range(10).select(
+      'id cast IntegerType as 'a,
+      ('id cast IntegerType) * 2 as 'b,
+      concat(lit("val_"), 'id) as 'c
+    )
+
+    partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=0")
+    partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=1")
+
+    partitionedDF = partitionedReader.load(tempPath.getCanonicalPath)
+  }
+
+  override protected def afterAll(): Unit = {
+    Utils.deleteRecursively(tempPath)
+  }
+
+  private def partitionedWriter(df: DataFrame) =
+    df.write.option("dataSchema", partitionedDataSchema.json).format(dataSourceName)
+
+  private def partitionedReader =
+    sqlContext.read.option("dataSchema", partitionedDataSchema.json).format(dataSourceName)
+
+  /**
+   * Constructs test cases that test column pruning and filter push-down.
+   *
+   * For filter push-down, the following filters are not pushed down:
+   *
+   * 1. Partitioning filters don't participate in filter push-down; they are handled separately in
+   *    `DataSourceStrategy`.
+   *
+   * 2. Catalyst filter `Expression`s that cannot be converted to data source `Filter`s are not
+   *    pushed down (e.g. UDFs and filters referencing multiple columns).
+   *
+   * 3. Catalyst filter `Expression`s that can be converted to data source `Filter`s but cannot be
+   *    handled by the underlying data source are not pushed down (e.g. those returned from
+   *    `BaseRelation.unhandledFilters()`).
+   *
+   *    Note that for [[SimpleTextRelation]], all data source [[Filter]]s other than [[GreaterThan]]
+   *    are unhandled. We made this assumption in [[SimpleTextRelation.unhandledFilters()]] only
+   *    for testing purposes.
+   *
+   * @param projections Projection list of the query
+   * @param filter Filter condition of the query
+   * @param requiredColumns Expected names of required columns
+   * @param pushedFilters Expected data source [[Filter]]s that are pushed down
+   * @param inconvertibleFilters Expected Catalyst filter [[Expression]]s that cannot be converted
+   *        to data source [[Filter]]s
+   * @param unhandledFilters Expected Catalyst filter [[Expression]]s that can be converted to data
+   *        source [[Filter]]s but cannot be handled by the data source relation
+   * @param partitioningFilters Expected Catalyst filter [[Expression]]s that reference partition
+   *        columns
+   * @param expectedRawScanAnswer Expected query result of the raw table scan returned by the data
+   *        source relation
+   * @param expectedAnswer Expected query result of the full query
+   */
+  def testPruningAndFiltering(
+      projections: Seq[Column],
+      filter: Column,
+      requiredColumns: Seq[String],
+      pushedFilters: Seq[Filter],
+      inconvertibleFilters: Seq[Column],
+      unhandledFilters: Seq[Column],
+      partitioningFilters: Seq[Column])(
+      expectedRawScanAnswer: => Seq[Row])(
+      expectedAnswer: => Seq[Row]): Unit = {
+    test(s"pruning and filtering: df.select(${projections.mkString(", 
")}).where($filter)") {
+      val df = partitionedDF.where(filter).select(projections: _*)
+      val queryExecution = df.queryExecution
+      val executedPlan = queryExecution.executedPlan
+
+      val rawScan = executedPlan.collect {
         case p: PhysicalRDD => p
       } match {
-        case Seq(p) => p
+        case Seq(scan) => scan
+        case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
       }
 
-      val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType)
+      markup("Checking raw scan answer")
+      checkAnswer(
+        DataFrame(sqlContext, LogicalRDD(rawScan.output, rawScan.rdd)(sqlContext)),
+        expectedRawScanAnswer)
 
-      assertResult(Set((2, 1), (3, 1))) {
-        rawScan.execute().collect()
-          .map { CatalystTypeConverters.convertToScala(_, outputSchema) }
-          .map { case Row(a, p) => (a, p) }.toSet
+      markup("Checking full query answer")
+      checkAnswer(df, expectedAnswer)
+
+      markup("Checking required columns")
+      assert(requiredColumns === SimpleTextRelation.requiredColumns)
+
+      val nonPushedFilters = {
+        val boundFilters = executedPlan.collect {
+          case f: execution.Filter => f
+        } match {
+          case Nil => Nil
+          case Seq(f) => splitConjunctivePredicates(f.condition)
+          case _ => fail(s"More than one Filter node found\n$queryExecution")
+        }
+
+        // Unbind these bound filters so that we can easily compare them with expected results.
+        boundFilters.map {
+          _.transform { case a: AttributeReference => UnresolvedAttribute(a.name) }
+        }.toSet
       }
 
-      checkAnswer(df, Row(3, 1))
+      markup("Checking pushed filters")
+      assert(SimpleTextRelation.pushedFilters === pushedFilters.toSet)
+
+      val expectedInconvertibleFilters = inconvertibleFilters.map(_.expr).toSet
+      val expectedUnhandledFilters = unhandledFilters.map(_.expr).toSet
+      val expectedPartitioningFilters = partitioningFilters.map(_.expr).toSet
+
+      markup("Checking unhandled and inconvertible filters")
+      assert(expectedInconvertibleFilters ++ expectedUnhandledFilters === nonPushedFilters)
+
+      markup("Checking partitioning filters")
+      val actualPartitioningFilters = splitConjunctivePredicates(filter.expr).filter {
+        _.references.contains(UnresolvedAttribute("p"))
+      }.toSet
+
+      // Partitioning filters are handled separately and don't participate in filter push-down, so
+      // they shouldn't be part of the non-pushed filters.
+      assert(expectedPartitioningFilters.intersect(nonPushedFilters).isEmpty)
+      assert(expectedPartitioningFilters === actualPartitioningFilters)
     }
   }
+
+  testPruningAndFiltering(
+    projections = Seq('*),
+    filter = 'p > 0,
+    requiredColumns = Seq("a", "b", "c"),
+    pushedFilters = Nil,
+    inconvertibleFilters = Nil,
+    unhandledFilters = Nil,
+    partitioningFilters = Seq('p > 0)
+  ) {
+    Seq(
+      Row(0, 0, "val_0", 1),
+      Row(1, 2, "val_1", 1),
+      Row(2, 4, "val_2", 1),
+      Row(3, 6, "val_3", 1),
+      Row(4, 8, "val_4", 1),
+      Row(5, 10, "val_5", 1),
+      Row(6, 12, "val_6", 1),
+      Row(7, 14, "val_7", 1),
+      Row(8, 16, "val_8", 1),
+      Row(9, 18, "val_9", 1))
+  } {
+    Seq(
+      Row(0, 0, "val_0", 1),
+      Row(1, 2, "val_1", 1),
+      Row(2, 4, "val_2", 1),
+      Row(3, 6, "val_3", 1),
+      Row(4, 8, "val_4", 1),
+      Row(5, 10, "val_5", 1),
+      Row(6, 12, "val_6", 1),
+      Row(7, 14, "val_7", 1),
+      Row(8, 16, "val_8", 1),
+      Row(9, 18, "val_9", 1))
+  }
+
+  testPruningAndFiltering(
+    projections = Seq('c, 'p),
+    filter = 'a < 3 && 'p > 0,
+    requiredColumns = Seq("c", "a"),
+    pushedFilters = Nil,
+    inconvertibleFilters = Nil,
+    unhandledFilters = Seq('a < 3),
+    partitioningFilters = Seq('p > 0)
+  ) {
+    Seq(
+      Row("val_0", 1, 0),
+      Row("val_1", 1, 1),
+      Row("val_2", 1, 2),
+      Row("val_3", 1, 3),
+      Row("val_4", 1, 4),
+      Row("val_5", 1, 5),
+      Row("val_6", 1, 6),
+      Row("val_7", 1, 7),
+      Row("val_8", 1, 8),
+      Row("val_9", 1, 9))
+  } {
+    Seq(
+      Row("val_0", 1),
+      Row("val_1", 1),
+      Row("val_2", 1))
+  }
+
+  testPruningAndFiltering(
+    projections = Seq('*),
+    filter = 'a > 8,
+    requiredColumns = Seq("a", "b", "c"),
+    pushedFilters = Seq(GreaterThan("a", 8)),
+    inconvertibleFilters = Nil,
+    unhandledFilters = Nil,
+    partitioningFilters = Nil
+  ) {
+    Seq(
+      Row(9, 18, "val_9", 0),
+      Row(9, 18, "val_9", 1))
+  } {
+    Seq(
+      Row(9, 18, "val_9", 0),
+      Row(9, 18, "val_9", 1))
+  }
+
+  testPruningAndFiltering(
+    projections = Seq('b, 'p),
+    filter = 'a > 8,
+    requiredColumns = Seq("b"),
+    pushedFilters = Seq(GreaterThan("a", 8)),
+    inconvertibleFilters = Nil,
+    unhandledFilters = Nil,
+    partitioningFilters = Nil
+  ) {
+    Seq(
+      Row(18, 0),
+      Row(18, 1))
+  } {
+    Seq(
+      Row(18, 0),
+      Row(18, 1))
+  }
+
+  testPruningAndFiltering(
+    projections = Seq('b, 'p),
+    filter = 'a > 8 && 'p > 0,
+    requiredColumns = Seq("b"),
+    pushedFilters = Seq(GreaterThan("a", 8)),
+    inconvertibleFilters = Nil,
+    unhandledFilters = Nil,
+    partitioningFilters = Seq('p > 0)
+  ) {
+    Seq(
+      Row(18, 1))
+  } {
+    Seq(
+      Row(18, 1))
+  }
+
+  testPruningAndFiltering(
+    projections = Seq('b, 'p),
+    filter = 'c > "val_7" && 'b < 18 && 'p > 0,
+    requiredColumns = Seq("b"),
+    pushedFilters = Seq(GreaterThan("c", "val_7")),
+    inconvertibleFilters = Nil,
+    unhandledFilters = Seq('b < 18),
+    partitioningFilters = Seq('p > 0)
+  ) {
+    Seq(
+      Row(16, 1),
+      Row(18, 1))
+  } {
+    Seq(
+      Row(16, 1))
+  }
+
+  testPruningAndFiltering(
+    projections = Seq('b, 'p),
+    filter = 'a % 2 === 0 && 'c > "val_7" && 'b < 18 && 'p > 0,
+    requiredColumns = Seq("b", "a"),
+    pushedFilters = Seq(GreaterThan("c", "val_7")),
+    inconvertibleFilters = Seq('a % 2 === 0),
+    unhandledFilters = Seq('b < 18),
+    partitioningFilters = Seq('p > 0)
+  ) {
+    Seq(
+      Row(16, 1, 8),
+      Row(18, 1, 9))
+  } {
+    Seq(
+      Row(16, 1))
+  }
+
+  testPruningAndFiltering(
+    projections = Seq('b, 'p),
+    filter = 'a > 7 && 'a < 9,
+    requiredColumns = Seq("b", "a"),
+    pushedFilters = Seq(GreaterThan("a", 7)),
+    inconvertibleFilters = Nil,
+    unhandledFilters = Seq('a < 9),
+    partitioningFilters = Nil
+  ) {
+    Seq(
+      Row(16, 0, 8),
+      Row(16, 1, 8),
+      Row(18, 0, 9),
+      Row(18, 1, 9))
+  } {
+    Seq(
+      Row(16, 0),
+      Row(16, 1))
+  }
 }
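
To make the expectations above concrete, the next-to-last case ('a % 2 === 0 &&
'c > "val_7" && 'b < 18 && 'p > 0) decomposes as follows under
SimpleTextRelation's only-GreaterThan assumption. This is a reading of the test
parameters above, not additional test code:

    import org.apache.spark.sql.sources.{GreaterThan, LessThan}

    // 'a % 2 === 0 -> inconvertible (no sources.Filter form), evaluated by Spark after the scan
    // 'c > "val_7" -> convertible and handled, so it is pushed down
    // 'b < 18      -> convertible but unhandled, re-evaluated by Spark on top of the scan
    // 'p > 0       -> partition pruning in DataSourceStrategy, never part of push-down
    val pushed = Seq(GreaterThan("c", "val_7"))   // reaches the data source
    val reEvaluated = Seq(LessThan("b", 18))      // convertible, but the source declines it
    // requiredColumns ends up as Seq("b", "a"): "b" for the projection, "a" for the
    // inconvertible filter; "c" is not needed because its filter is fully handled, and "p"
    // comes from the directory layout rather than the data files.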

http://git-wip-us.apache.org/repos/asf/spark/blob/c048929c/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index da09e1b..bdc48a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -128,6 +128,9 @@ class SimpleTextRelation(
       filters: Array[Filter],
       inputFiles: Array[FileStatus]): RDD[Row] = {
 
+    SimpleTextRelation.requiredColumns = requiredColumns
+    SimpleTextRelation.pushedFilters = filters.toSet
+
     val fields = this.dataSchema.map(_.dataType)
     val inputAttributes = this.dataSchema.toAttributes
    val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name))
@@ -191,6 +194,14 @@ class SimpleTextRelation(
   }
 }
 
+object SimpleTextRelation {
+  // Used to test column pruning
+  var requiredColumns: Seq[String] = Nil
+
+  // Used to test filter push-down
+  var pushedFilters: Set[Filter] = Set.empty
+}
+
 /**
  * A simple example [[HadoopFsRelationProvider]].
  */
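
The companion object turns SimpleTextRelation into a test probe: each buildScan
call records which columns and which data source filters actually reached the
relation, and the suite above asserts on them after running each query. A hedged
sketch of that pattern in isolation (the path is hypothetical, the surrounding
test harness is assumed, and the expected values assume the GreaterThan filter
is fully handled by the relation, as in the suite above):

    import org.apache.spark.sql.sources.GreaterThan

    // Trigger a scan with a projection and a pushable filter...
    sqlContext.read
      .format(classOf[SimpleTextSource].getCanonicalName)
      .option("dataSchema", dataSchema.json)    // schema of the previously written data
      .load("/tmp/simple-text-example")         // hypothetical path
      .filter("a > 8")
      .select("b")
      .collect()

    // ...then inspect what buildScan actually received.
    assert(SimpleTextRelation.requiredColumns == Seq("b"))
    assert(SimpleTextRelation.pushedFilters == Set(GreaterThan("a", 8)))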

