Repository: spark
Updated Branches:
  refs/heads/master c19680be1 -> d4107196d
[SPARK-18004][SQL] Make sure the date or timestamp related predicate can be pushed down to Oracle correctly

## What changes were proposed in this pull request?

Move the `compileValue` method from JDBCRDD to JdbcDialect, and override `compileValue` in OracleDialect to rewrite the Oracle-specific timestamp and date literals in the WHERE clause.

## How was this patch tested?

An integration test has been added.

Author: Rui Zha <zrdt...@gmail.com>
Author: Zharui <zrdt...@gmail.com>

Closes #18451 from SharpRay/extend-compileValue-to-dialects.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4107196
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4107196
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4107196

Branch: refs/heads/master
Commit: d4107196d59638845bd19da6aab074424d90ddaf
Parents: c19680b
Author: Rui Zha <zrdt...@gmail.com>
Authored: Sun Jul 2 17:37:47 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Sun Jul 2 17:37:47 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/jdbc/OracleIntegrationSuite.scala | 45 ++++++++++++++++++++
 .../execution/datasources/jdbc/JDBCRDD.scala    | 35 +++++----------
 .../apache/spark/sql/jdbc/JdbcDialects.scala    | 27 +++++++++++-
 .../apache/spark/sql/jdbc/OracleDialect.scala   | 15 ++++++-
 4 files changed, 95 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
----------------------------------------------------------------------
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index b2f0969..e14810a 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -223,4 +223,49 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
     val types = rows(0).toSeq.map(x => x.getClass.toString)
     assert(types(1).equals("class java.sql.Timestamp"))
   }
+
+  test("SPARK-18004: Make sure date or timestamp related predicate is pushed down correctly") {
+    val props = new Properties()
+    props.put("oracle.jdbc.mapDateToTimestamp", "false")
+
+    val schema = StructType(Seq(
+      StructField("date_type", DateType, true),
+      StructField("timestamp_type", TimestampType, true)
+    ))
+
+    val tableName = "test_date_timestamp_pushdown"
+    val dateVal = Date.valueOf("2017-06-22")
+    val timestampVal = Timestamp.valueOf("2017-06-22 21:30:07")
+
+    val data = spark.sparkContext.parallelize(Seq(
+      Row(dateVal, timestampVal)
+    ))
+
+    val dfWrite = spark.createDataFrame(data, schema)
+    dfWrite.write.jdbc(jdbcUrl, tableName, props)
+
+    val dfRead = spark.read.jdbc(jdbcUrl, tableName, props)
+
+    val millis = System.currentTimeMillis()
+    val dt = new java.sql.Date(millis)
+    val ts = new java.sql.Timestamp(millis)
+
+    // Query the Oracle table with date and timestamp predicates,
+    // which should be pushed down to Oracle.
+    val df = dfRead.filter(dfRead.col("date_type").lt(dt))
+      .filter(dfRead.col("timestamp_type").lt(ts))
+
+    val metadata = df.queryExecution.sparkPlan.metadata
+    // The "PushedFilters" part should exist in the DataFrame's
+    // physical plan, and the presence of the right literals in
+    // "PushedFilters" proves that the predicates have actually
+    // been pushed down.
+    assert(metadata.get("PushedFilters").ne(None))
+    assert(metadata("PushedFilters").contains(dt.toString))
+    assert(metadata("PushedFilters").contains(ts.toString))
+
+    val row = df.collect()(0)
+    assert(row.getDate(0).equals(dateVal))
+    assert(row.getTimestamp(1).equals(timestampVal))
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 2bdc432..0f53b5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -17,12 +17,10 @@ package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
+import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}
 
 import scala.util.control.NonFatal
 
-import org.apache.commons.lang3.StringUtils
-
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
@@ -87,20 +85,6 @@ object JDBCRDD extends Logging {
   }
 
   /**
-   * Converts value to SQL expression.
-   */
-  private def compileValue(value: Any): Any = value match {
-    case stringValue: String => s"'${escapeSql(stringValue)}'"
-    case timestampValue: Timestamp => "'" + timestampValue + "'"
-    case dateValue: Date => "'" + dateValue + "'"
-    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
-    case _ => value
-  }
-
-  private def escapeSql(value: String): String =
-    if (value == null) null else StringUtils.replace(value, "'", "''")
-
-  /**
    * Turns a single Filter into a String representing a SQL expression.
    * Returns None for an unhandled filter.
   */
@@ -108,15 +92,16 @@ object JDBCRDD extends Logging {
     def quote(colName: String): String = dialect.quoteIdentifier(colName)
 
     Option(f match {
-      case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}"
       case EqualNullSafe(attr, value) =>
         val col = quote(attr)
-        s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " +
-          s"${compileValue(value)} IS NULL) OR ($col IS NULL AND ${compileValue(value)} IS NULL))"
-      case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}"
-      case GreaterThan(attr, value) => s"${quote(attr)} > ${compileValue(value)}"
-      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${compileValue(value)}"
-      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${compileValue(value)}"
+        s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " +
+          s"${dialect.compileValue(value)} IS NULL) OR " +
+          s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))"
+      case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}"
+      case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}"
+      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}"
+      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}"
       case IsNull(attr) => s"${quote(attr)} IS NULL"
       case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
       case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
@@ -124,7 +109,7 @@
       case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
       case In(attr, value) if value.isEmpty =>
         s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})"
+      case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})"
      case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
      case Or(f1, f2) =>
        // We can't compile Or filter unless both sub-filters are compiled successfully.


http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index a86a86d..7c38ed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -17,7 +17,9 @@ package org.apache.spark.sql.jdbc
 
-import java.sql.Connection
+import java.sql.{Connection, Date, Timestamp}
+
+import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
 import org.apache.spark.sql.types._
@@ -124,6 +126,29 @@ abstract class JdbcDialect extends Serializable {
   }
 
   /**
+   * Escape special characters in SQL string literals.
+   * @param value The string to be escaped.
+   * @return Escaped string.
+   */
+  @Since("2.3.0")
+  protected[jdbc] def escapeSql(value: String): String =
+    if (value == null) null else StringUtils.replace(value, "'", "''")
+
+  /**
+   * Converts value to SQL expression.
+   * @param value The value to be converted.
+   * @return Converted value.
+   */
+  @Since("2.3.0")
+  def compileValue(value: Any): Any = value match {
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "'" + timestampValue + "'"
+    case dateValue: Date => "'" + dateValue + "'"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }
+
+  /**
    * Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
    * Some[true] : TRUNCATE TABLE causes cascading.
    * Some[false] : TRUNCATE TABLE does not cause cascading.


http://git-wip-us.apache.org/repos/asf/spark/blob/d4107196/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 20e634c..3b44c1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc
 
-import java.sql.Types
+import java.sql.{Date, Timestamp, Types}
 
 import org.apache.spark.sql.types._
 
@@ -64,5 +64,18 @@ private case object OracleDialect extends JdbcDialect {
     case _ => None
   }
 
+  override def compileValue(value: Any): Any = value match {
+    // The JDBC drivers support date literals in SQL statements written in the
+    // format: {d 'yyyy-mm-dd'} and timestamp literals in SQL statements written
+    // in the format: {ts 'yyyy-mm-dd hh:mm:ss.f...'}. For details, see
+    // 'Oracle Database JDBC Developer's Guide and Reference, 11g Release 1 (11.1)'
+    // Appendix A Reference Information.
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "{ts '" + timestampValue + "'}"
+    case dateValue: Date => "{d '" + dateValue + "'}"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }
+
   override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }
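As a rough illustration of what the new hook produces, the sketch below calls the dialect through the public JdbcDialects.get entry point and prints the WHERE fragment Spark would push down for the two predicates used in the integration test. It assumes a Spark build containing this patch (where compileValue is public on JdbcDialect) is on the classpath; the object name and the Oracle JDBC URL are made up for the example.

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.jdbc.JdbcDialects

object CompileValueDemo {
  def main(args: Array[String]): Unit = {
    // JdbcDialects.get resolves OracleDialect for any jdbc:oracle:* URL
    // (the host and service name here are illustrative only).
    val dialect = JdbcDialects.get("jdbc:oracle:thin:@//localhost:1521/xe")

    // Same literal values as the integration test's dateVal/timestampVal.
    val dt = Date.valueOf("2017-06-22")
    val ts = Timestamp.valueOf("2017-06-22 21:30:07")

    // With this patch the Oracle dialect emits JDBC escape literals such as
    // {d '2017-06-22'} and {ts '2017-06-22 21:30:07.0'} instead of the plain
    // quoted strings produced by the default JdbcDialect.compileValue.
    val whereClause =
      s"WHERE ${dialect.quoteIdentifier("date_type")} < ${dialect.compileValue(dt)} " +
        s"AND ${dialect.quoteIdentifier("timestamp_type")} < ${dialect.compileValue(ts)}"

    println(whereClause)
    // Expected output, roughly:
    // WHERE "date_type" < {d '2017-06-22'} AND "timestamp_type" < {ts '2017-06-22 21:30:07.0'}
  }
}

The escape-literal form is what lets Oracle compare the pushed-down values against DATE and TIMESTAMP columns directly, which is the point of overriding compileValue in OracleDialect rather than keeping the generic quoted-string rendering.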