Repository: spark
Updated Branches:
  refs/heads/master a5849ad9a -> f596ebe4d


[SPARK-24327][SQL] Verify and normalize a partition column name based on the JDBC resolved schema

## What changes were proposed in this pull request?
This PR modifies the JDBC datasource code to verify and normalize a partition column name against the JDBC-resolved schema before building `JDBCRelation`.
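
A minimal usage sketch (illustration only, not part of the patch; the H2 URL, table name, and credentials below are placeholders): with the default `spark.sql.caseSensitive=false`, a partition column given in a different case, or in its dialect-quoted form, is matched against the resolved schema and normalized, while an unknown column now fails early with an `AnalysisException`.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local session and in-memory H2 database, for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("jdbc-partition-sketch").getOrCreate()

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:h2:mem:testdb0;user=testUser;password=testPass")
  .option("dbtable", "TEST.PARTITION")
  .option("partitionColumn", "ThEiD")  // matched case-insensitively and normalized to "THEID"
  .option("lowerBound", 1)
  .option("upperBound", 4)
  .option("numPartitions", 3)
  .load()

// A misspelled column such as "NoExistingColumn" now raises an AnalysisException
// instead of generating partition WHERE clauses against a non-existent column.
```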

Closes #20370

## How was this patch tested?
Added tests in `JDBCSuite`.

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #21379 from maropu/SPARK-24327.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f596ebe4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f596ebe4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f596ebe4

Branch: refs/heads/master
Commit: f596ebe4d3170590b6fce34c179e51ee80c965d3
Parents: a5849ad
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Sun Jun 24 23:14:42 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Sun Jun 24 23:14:42 2018 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/util/Utils.scala     |  2 +-
 .../datasources/jdbc/JDBCRelation.scala         | 76 ++++++++++++++++----
 .../datasources/jdbc/JdbcRelationProvider.scala |  6 +-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   | 51 ++++++++++++-
 4 files changed, 118 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c139db4..a6fd363 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -100,7 +100,7 @@ private[spark] object Utils extends Logging {
    */
   val DEFAULT_MAX_TO_STRING_FIELDS = 25
 
-  private def maxNumToStringFields = {
+  private[spark] def maxNumToStringFields = {
     if (SparkEnv.get != null) {
       SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index b23e5a7..b84543c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.Partition
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * Instructions on how to partition the table among workers.
@@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging {
   * Null value predicate is added to the first partition where clause to include
    * the rows with null value for the partitions column.
    *
+   * @param schema resolved schema of a JDBC table
   * @param partitioning partition information to generate the where clause for each partition
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contains url
    * @return an array of partitions with where clause for each partition
    */
-  def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
+  def columnPartition(
+      schema: StructType,
+      partitioning: JDBCPartitioningInfo,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): Array[Partition] = {
     if (partitioning == null || partitioning.numPartitions <= 1 ||
       partitioning.lowerBound == partitioning.upperBound) {
       return Array[Partition](JDBCPartition(null, 0))
@@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging {
     // Overflow and silliness can happen if you subtract then divide.
     // Here we get a little roundoff, but that's (hopefully) OK.
     val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
-    val column = partitioning.column
+
+    val column = verifyAndGetNormalizedColumnName(
+      schema, partitioning.column, resolver, jdbcOptions)
+
     var i: Int = 0
     var currentValue: Long = lowerBound
     val ans = new ArrayBuffer[Partition]()
@@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging {
     }
     ans.toArray
   }
+
+  // Verify column name based on the JDBC resolved schema
+  private def verifyAndGetNormalizedColumnName(
+      schema: StructType,
+      columnName: String,
+      resolver: Resolver,
+      jdbcOptions: JDBCOptions): String = {
+    val dialect = JdbcDialects.get(jdbcOptions.url)
+    schema.map(_.name).find { fieldName =>
+      resolver(fieldName, columnName) ||
+        resolver(dialect.quoteIdentifier(fieldName), columnName)
+    }.map(dialect.quoteIdentifier).getOrElse {
+      throw new AnalysisException(s"User-defined partition column $columnName not " +
+        s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
+    }
+  }
+
+  /**
+   * Takes a (schema, table) specification and returns the table's Catalyst schema.
+   * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the
+   * custom schema's type.
+   *
+   * @param resolver function used to determine if two identifiers are equal
+   * @param jdbcOptions JDBC options that contains url, table and other information.
+   * @return resolved Catalyst schema of a JDBC table
+   */
+  def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
+    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+    jdbcOptions.customSchema match {
+      case Some(customSchema) => JdbcUtils.getCustomSchema(
+        tableSchema, customSchema, resolver)
+      case None => tableSchema
+    }
+  }
+
+  /**
+   * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema.
+   */
+  def apply(
+      parts: Array[Partition],
+      jdbcOptions: JDBCOptions)(
+      sparkSession: SparkSession): JDBCRelation = {
+    val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
+    JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
+  }
 }
 
 private[sql] case class JDBCRelation(
-    parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
+    override val schema: StructType,
+    parts: Array[Partition],
+    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
   extends BaseRelation
   with PrunedFilteredScan
   with InsertableRelation {
@@ -111,15 +170,6 @@ private[sql] case class JDBCRelation(
 
   override val needConversion: Boolean = false
 
-  override val schema: StructType = {
-    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
-    jdbcOptions.customSchema match {
-      case Some(customSchema) => JdbcUtils.getCustomSchema(
-        tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
-      case None => tableSchema
-    }
-  }
-
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
     filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
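
For reference (not part of the patch), a simplified standalone sketch of the matching rule that `verifyAndGetNormalizedColumnName` above applies, assuming the generic dialect's double-quote identifier quoting; the real code resolves the dialect from the JDBC URL and throws `AnalysisException`:

```scala
// Simplified sketch: a user-supplied partition column matches a resolved field either
// directly or in its dialect-quoted form; the dialect-quoted field name is returned.
object PartitionColumnNormalizationSketch {
  // Double-quote identifier quoting, as in the generic JDBC dialect (assumption).
  private def quoteIdentifier(colName: String): String = s""""$colName""""

  def normalize(fieldNames: Seq[String], columnName: String, caseSensitive: Boolean): String = {
    val resolver: (String, String) => Boolean =
      if (caseSensitive) _ == _ else _ equalsIgnoreCase _
    fieldNames
      .find(f => resolver(f, columnName) || resolver(quoteIdentifier(f), columnName))
      .map(quoteIdentifier)
      .getOrElse(throw new IllegalArgumentException(
        s"User-defined partition column $columnName not found"))
  }

  def main(args: Array[String]): Unit = {
    val fields = Seq("THEID", "THE ID")
    println(normalize(fields, "ThEiD", caseSensitive = false))     // "THEID" (quoted)
    println(normalize(fields, "\"THEID\"", caseSensitive = true))  // "THEID" (quoted)
    println(normalize(fields, "THE ID", caseSensitive = false))    // "THE ID" (quoted)
  }
}
```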

http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index f8c5677..2b488bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider
       JDBCPartitioningInfo(
         partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
     }
-    val parts = JDBCRelation.columnPartition(partitionInfo)
-    JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
+    val resolver = sqlContext.conf.resolver
+    val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
+    val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions)
+    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
   }
 
   override def createRelation(

http://git-wip-us.apache.org/repos/asf/spark/blob/f596ebe4/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index bc2aca6..6ea61f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite
         |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
        """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE 
ID` INTEGER) " +
+      "AS SELECT 1, 1")
+      .executeUpdate()
+    conn.commit()
+
     // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
   }
 
@@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(errMsg.contains("Statement was canceled or the session timed out"))
   }
+
+  test("SPARK-24327 verify and normalize a partition column based on a JDBC 
resolved schema") {
+    def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = {
+      val df = spark.read.format("jdbc")
+        .option("url", urlWithUserAndPass)
+        .option("dbtable", "TEST.PARTITION")
+        .option("partitionColumn", partColName)
+        .option("lowerBound", 1)
+        .option("upperBound", 4)
+        .option("numPartitions", 3)
+        .load()
+
+      val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName)
+      df.logicalPlan match {
+        case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
+          val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
+          assert(whereClauses === Set(
+            s"$quotedPrtColName < 2 or $quotedPrtColName is null",
+            s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3",
+            s"$quotedPrtColName >= 3"))
+      }
+    }
+
+    testJdbcParitionColumn("THEID", "THEID")
+    testJdbcParitionColumn("\"THEID\"", "THEID")
+    withSQLConf("spark.sql.caseSensitive" -> "false") {
+      testJdbcParitionColumn("ThEiD", "THEID")
+    }
+    testJdbcParitionColumn("THE ID", "THE ID")
+
+    def testIncorrectJdbcPartitionColumn(partColName: String): Unit = {
+      val errMsg = intercept[AnalysisException] {
+        testJdbcParitionColumn(partColName, "THEID")
+      }.getMessage
+      assert(errMsg.contains(s"User-defined partition column $partColName not found " +
+        "in the JDBC relation:"))
+    }
+
+    testIncorrectJdbcPartitionColumn("NoExistingColumn")
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD"))
+    }
+  }
 }

