Repository: spark
Updated Branches:
  refs/heads/master c19680be1 -> d4107196d


[SPARK-18004][SQL] Make sure the date or timestamp related predicate can be pushed down to Oracle correctly

## What changes were proposed in this pull request?

Move the `compileValue` method from JDBCRDD to JdbcDialect, and override it in
OracleDialect so that timestamp and date literals in the WHERE clause are
rewritten into Oracle-specific syntax.
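
For illustration only (not part of the patch), here is a minimal sketch of the effect, mirroring the integration test added below; `jdbcUrl` and `props` are assumed to point at an Oracle instance:

```scala
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.functions.col

// Oracle-backed DataFrame with date and timestamp predicates (illustrative).
val df = spark.read.jdbc(jdbcUrl, "test_date_timestamp_pushdown", props)
  .filter(col("date_type").lt(Date.valueOf("2017-06-22")))
  .filter(col("timestamp_type").lt(Timestamp.valueOf("2017-06-22 21:30:07")))

// Before this change, the pushed-down WHERE clause used plain quoted literals,
// which Oracle may reject or misparse depending on NLS settings:
//   "date_type" < '2017-06-22' AND "timestamp_type" < '2017-06-22 21:30:07'
// With this change, OracleDialect emits JDBC escape syntax instead:
//   "date_type" < {d '2017-06-22'} AND "timestamp_type" < {ts '2017-06-22 21:30:07'}
```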

## How was this patch tested?

An integration test has been added.

Author: Rui Zha <zrdt...@gmail.com>
Author: Zharui <zrdt...@gmail.com>

Closes #18451 from SharpRay/extend-compileValue-to-dialects.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4107196
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4107196
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4107196

Branch: refs/heads/master
Commit: d4107196d59638845bd19da6aab074424d90ddaf
Parents: c19680b
Author: Rui Zha <zrdt...@gmail.com>
Authored: Sun Jul 2 17:37:47 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Sun Jul 2 17:37:47 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/jdbc/OracleIntegrationSuite.scala | 45 ++++++++++++++++++++
 .../execution/datasources/jdbc/JDBCRDD.scala    | 35 +++++----------
 .../apache/spark/sql/jdbc/JdbcDialects.scala    | 27 +++++++++++-
 .../apache/spark/sql/jdbc/OracleDialect.scala   | 15 ++++++-
 4 files changed, 95 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
----------------------------------------------------------------------
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index b2f0969..e14810a 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -223,4 +223,49 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
     val types = rows(0).toSeq.map(x => x.getClass.toString)
     assert(types(1).equals("class java.sql.Timestamp"))
   }
+
+  test("SPARK-18004: Make sure date or timestamp related predicate is pushed 
down correctly") {
+    val props = new Properties()
+    props.put("oracle.jdbc.mapDateToTimestamp", "false")
+
+    val schema = StructType(Seq(
+      StructField("date_type", DateType, true),
+      StructField("timestamp_type", TimestampType, true)
+    ))
+
+    val tableName = "test_date_timestamp_pushdown"
+    val dateVal = Date.valueOf("2017-06-22")
+    val timestampVal = Timestamp.valueOf("2017-06-22 21:30:07")
+
+    val data = spark.sparkContext.parallelize(Seq(
+      Row(dateVal, timestampVal)
+    ))
+
+    val dfWrite = spark.createDataFrame(data, schema)
+    dfWrite.write.jdbc(jdbcUrl, tableName, props)
+
+    val dfRead = spark.read.jdbc(jdbcUrl, tableName, props)
+
+    val millis = System.currentTimeMillis()
+    val dt = new java.sql.Date(millis)
+    val ts = new java.sql.Timestamp(millis)
+
+    // Query the Oracle table with date and timestamp predicates,
+    // which should be pushed down to Oracle.
+    val df = dfRead.filter(dfRead.col("date_type").lt(dt))
+      .filter(dfRead.col("timestamp_type").lt(ts))
+
+    val metadata = df.queryExecution.sparkPlan.metadata
+    // The "PushedFilters" part should be exist in Datafrome's
+    // physical plan and the existence of right literals in
+    // "PushedFilters" is used to prove that the predicates
+    // pushing down have been effective.
+    assert(metadata.get("PushedFilters").ne(None))
+    assert(metadata("PushedFilters").contains(dt.toString))
+    assert(metadata("PushedFilters").contains(ts.toString))
+
+    val row = df.collect()(0)
+    assert(row.getDate(0).equals(dateVal))
+    assert(row.getTimestamp(1).equals(timestampVal))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 2bdc432..0f53b5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -17,12 +17,10 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
+import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}
 
 import scala.util.control.NonFatal
 
-import org.apache.commons.lang3.StringUtils
-
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
@@ -87,20 +85,6 @@ object JDBCRDD extends Logging {
   }
 
   /**
-   * Converts value to SQL expression.
-   */
-  private def compileValue(value: Any): Any = value match {
-    case stringValue: String => s"'${escapeSql(stringValue)}'"
-    case timestampValue: Timestamp => "'" + timestampValue + "'"
-    case dateValue: Date => "'" + dateValue + "'"
-    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
-    case _ => value
-  }
-
-  private def escapeSql(value: String): String =
-    if (value == null) null else StringUtils.replace(value, "'", "''")
-
-  /**
    * Turns a single Filter into a String representing a SQL expression.
    * Returns None for an unhandled filter.
    */
@@ -108,15 +92,16 @@ object JDBCRDD extends Logging {
     def quote(colName: String): String = dialect.quoteIdentifier(colName)
 
     Option(f match {
-      case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}"
       case EqualNullSafe(attr, value) =>
         val col = quote(attr)
-        s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " +
-          s"${compileValue(value)} IS NULL) OR ($col IS NULL AND ${compileValue(value)} IS NULL))"
-      case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}"
-      case GreaterThan(attr, value) => s"${quote(attr)} > ${compileValue(value)}"
-      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${compileValue(value)}"
-      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${compileValue(value)}"
+        s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " +
+          s"${dialect.compileValue(value)} IS NULL) OR " +
+          s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))"
+      case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}"
+      case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}"
+      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}"
+      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}"
       case IsNull(attr) => s"${quote(attr)} IS NULL"
       case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
       case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
@@ -124,7 +109,7 @@ object JDBCRDD extends Logging {
       case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
       case In(attr, value) if value.isEmpty =>
         s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})"
+      case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})"
       case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
       case Or(f1, f2) =>
         // We can't compile Or filter unless both sub-filters are compiled successfully.

http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index a86a86d..7c38ed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.Connection
+import java.sql.{Connection, Date, Timestamp}
+
+import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
 import org.apache.spark.sql.types._
@@ -124,6 +126,29 @@ abstract class JdbcDialect extends Serializable {
   }
 
   /**
+   * Escape special characters in SQL string literals.
+   * @param value The string to be escaped.
+   * @return Escaped string.
+   */
+  @Since("2.3.0")
+  protected[jdbc] def escapeSql(value: String): String =
+    if (value == null) null else StringUtils.replace(value, "'", "''")
+
+  /**
+   * Converts value to SQL expression.
+   * @param value The value to be converted.
+   * @return Converted value.
+   */
+  @Since("2.3.0")
+  def compileValue(value: Any): Any = value match {
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "'" + timestampValue + "'"
+    case dateValue: Date => "'" + dateValue + "'"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }
+
+  /**
    * Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
    * Some[true] : TRUNCATE TABLE causes cascading.
    * Some[false] : TRUNCATE TABLE does not cause cascading.
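
Since `compileValue` is now part of the public `JdbcDialect` API, other dialects can override it as well. A minimal sketch (not part of this commit; the dialect name, URL prefix, and literal format are hypothetical):

```scala
import java.sql.Date

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Hypothetical dialect that renders java.sql.Date values as ANSI DATE literals
// and defers every other value to the default compileValue added in this patch.
case object MyDateLiteralDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def compileValue(value: Any): Any = value match {
    case dateValue: Date => s"DATE '$dateValue'"
    case other => super.compileValue(other)
  }
}

// Register it so JDBC reads against matching URLs pick it up.
JdbcDialects.registerDialect(MyDateLiteralDialect)
```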

http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 20e634c..3b44c1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.Types
+import java.sql.{Date, Timestamp, Types}
 
 import org.apache.spark.sql.types._
 
@@ -64,5 +64,18 @@ private case object OracleDialect extends JdbcDialect {
     case _ => None
   }
 
+  override def compileValue(value: Any): Any = value match {
+    // The JDBC drivers support date literals in SQL statements written in the
+    // format: {d 'yyyy-mm-dd'} and timestamp literals in SQL statements written
+    // in the format: {ts 'yyyy-mm-dd hh:mm:ss.f...'}. For details, see
+    // 'Oracle Database JDBC Developer’s Guide and Reference, 11g Release 1 (11.1)'
+    // Appendix A Reference Information.
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "{ts '" + timestampValue + "'}"
+    case dateValue: Date => "{d '" + dateValue + "'}"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }
+
   override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }
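
For reference, a standalone JDBC sketch (not part of the patch; connection details are placeholders) showing that the escape syntax emitted by `OracleDialect` is resolved by the driver itself, so the query does not depend on the session's NLS date settings:

```scala
import java.sql.DriverManager

// Placeholder connection details; only the literal syntax matters here.
val conn = DriverManager.getConnection("jdbc:oracle:thin:@//localhost:1521/xe", "user", "password")
val stmt = conn.createStatement()

// With JDBC escape processing (enabled by default), the Oracle driver translates
// {d '...'} and {ts '...'} into proper DATE/TIMESTAMP literals before execution.
val rs = stmt.executeQuery(
  "SELECT * FROM test_date_timestamp_pushdown " +
    "WHERE date_type < {d '2017-06-22'} AND timestamp_type < {ts '2017-06-22 21:30:07'}")
```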

