Repository: spark
Updated Branches:
  refs/heads/branch-1.5 5dd0c5cd6 -> 8e32db9a5


[SPARK-9182] [SQL] Filters are not passed through to jdbc source

This PR fixes a bug where filters could not be pushed down to the JDBC source
because a `Cast` introduced during analysis broke the pattern matching.

When a column is compared against a literal of a different type, the analyzer
usually wraps the column in a cast, so the predicate no longer matches the
patterns written against a bare `Attribute` and the filter fails to be pushed
down.
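
To illustrate, here is a minimal, self-contained Scala sketch of the failure
mode (the tiny expression ADT and the `translateOld`/`translateNew` helpers
below are illustrative stand-ins, not the actual Catalyst classes or the exact
code in `DataSourceStrategy.selectFilters`):

    // Toy model of the Catalyst expression tree (illustrative only).
    object CastPushdownSketch extends App {
      sealed trait Expression
      case class Attribute(name: String) extends Expression
      case class Literal(value: Any) extends Expression
      case class Cast(child: Expression, to: String) extends Expression
      case class EqualTo(left: Expression, right: Expression) extends Expression

      // Before the fix: only a bare Attribute on one side is matched.
      def translateOld(p: Expression): Option[String] = p match {
        case EqualTo(a: Attribute, Literal(v)) => Some(s"${a.name} = $v")
        case _ => None // a Cast-wrapped attribute falls through: no pushdown
      }

      // Comparing an INT column with a string literal makes the analyzer wrap
      // the column in a cast, roughly EqualTo(Cast(A, string), Literal("15")).
      val pred = EqualTo(Cast(Attribute("A"), "string"), Literal("15"))
      assert(translateOld(pred).isEmpty) // the filter is silently dropped

      // After the fix: an extra case unwraps the Cast; the real code also
      // re-casts the literal to the attribute's type (elided here).
      def translateNew(p: Expression): Option[String] = p match {
        case EqualTo(a: Attribute, Literal(v)) => Some(s"${a.name} = $v")
        case EqualTo(Cast(a: Attribute, _), Literal(v)) => Some(s"${a.name} = $v")
        case _ => None
      }
      assert(translateNew(pred).contains("A = 15")) // now pushed down
    }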

Author: Yijie Shen <henry.yijies...@gmail.com>

Closes #8049 from yjshen/jdbc_pushdown.

(cherry picked from commit 9d0822455ddc8d765440d58c463367a4d67ef456)
Signed-off-by: Cheng Lian <l...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8e32db9a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8e32db9a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8e32db9a

Branch: refs/heads/branch-1.5
Commit: 8e32db9a5b082f45d161ba1fc88732a6ba166ac1
Parents: 5dd0c5c
Author: Yijie Shen <henry.yijies...@gmail.com>
Authored: Wed Aug 12 19:54:00 2015 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Wed Aug 12 19:54:31 2015 +0800

----------------------------------------------------------------------
 .../datasources/DataSourceStrategy.scala        | 30 +++++++++++++++--
 .../execution/datasources/jdbc/JDBCRDD.scala    |  2 +-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   | 34 ++++++++++++++++++++
 3 files changed, 63 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8e32db9a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2a4c40d..9eea2b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType}
 import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -343,11 +343,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
    * and convert them.
    */
   protected[sql] def selectFilters(filters: Seq[Expression]) = {
+    import CatalystTypeConverters._
+
     def translate(predicate: Expression): Option[Filter] = predicate match {
       case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
         Some(sources.EqualTo(a.name, v))
       case expressions.EqualTo(Literal(v, _), a: Attribute) =>
         Some(sources.EqualTo(a.name, v))
+      case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
         Some(sources.EqualNullSafe(a.name, v))
@@ -358,21 +364,41 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         Some(sources.GreaterThan(a.name, v))
       case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
         Some(sources.LessThan(a.name, v))
+      case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.LessThan(a: Attribute, Literal(v, _)) =>
         Some(sources.LessThan(a.name, v))
       case expressions.LessThan(Literal(v, _), a: Attribute) =>
         Some(sources.GreaterThan(a.name, v))
+      case expressions.LessThan(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
         Some(sources.GreaterThanOrEqual(a.name, v))
       case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
         Some(sources.LessThanOrEqual(a.name, v))
+      case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.GreaterThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.LessThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
         Some(sources.LessThanOrEqual(a.name, v))
       case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
         Some(sources.GreaterThanOrEqual(a.name, v))
+      case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.LessThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.GreaterThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.InSet(a: Attribute, set) =>
         Some(sources.In(a.name, set.toArray))

http://git-wip-us.apache.org/repos/asf/spark/blob/8e32db9a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 8eab6a0..281943e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -284,7 +284,7 @@ private[sql] class JDBCRDD(
   /**
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
    */
-  private val filterWhereClause: String = {
+  val filterWhereClause: String = {
     val filterStrings = filters map compileFilter filter (_ != null)
     if (filterStrings.size > 0) {
       val sb = new StringBuilder("WHERE ")

http://git-wip-us.apache.org/repos/asf/spark/blob/8e32db9a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 42f2449..b9cfae5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -25,6 +25,8 @@ import org.h2.jdbc.JdbcSQLException
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
+import org.apache.spark.sql.execution.PhysicalRDD
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -148,6 +150,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
        |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b 
DECIMAL(4, 0))").
+      executeUpdate()
+    conn.prepareStatement("insert into test.decimals values (12345.67, 
1234)").executeUpdate()
+    conn.prepareStatement("insert into test.decimals values (34567.89, 
1428)").executeUpdate()
+    conn.commit()
+    sql(
+      s"""
+         |CREATE TEMPORARY TABLE decimals
+         |USING org.apache.spark.sql.jdbc
+         |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', 
password 'testPass')
+      """.stripMargin.replaceAll("\n", " "))
+
     conn.prepareStatement(
       s"""
        |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20),
@@ -445,4 +459,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
     assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
   }
 
+  test("SPARK-9182: filters are not passed through to jdbc source") {
+    def checkPushedFilter(query: String, filterStr: String): Unit = {
+      val rddOpt = sql(query).queryExecution.executedPlan.collectFirst {
+        case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd
+      }
+      assert(rddOpt.isDefined)
+      val pushedFilterStr = rddOpt.get.filterWhereClause
+      assert(pushedFilterStr.contains(filterStr),
+        s"Expected to push [$filterStr], actually we pushed 
[$pushedFilterStr]")
+    }
+
+    checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 
'fred'")
+    checkPushedFilter("select * from inttypes where A > '15'", "A > 15")
+    checkPushedFilter("select * from inttypes where C <= 20", "C <= 20")
+    checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00")
+    checkPushedFilter("select * from decimals where A > 1000 AND A < 2000",
+      "A > 1000.00 AND A < 2000.00")
+    checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 
2000.00 AND B > 20")
+    checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 
1998-09-10")
+  }
 }

