Repository: spark
Updated Branches:
  refs/heads/branch-1.6 b5a1f564a -> 7f37c1e45


[SPARK-12579][SQL] Force user-specified JDBC driver to take precedence

Spark SQL's JDBC data source allows users to specify an explicit JDBC driver to 
load (using the `driver` argument), but in the current code it's possible that 
the user-specified driver will not be used when it comes time to actually 
create a JDBC connection.

In a nutshell, the problem is that you might have multiple JDBC drivers on the 
classpath that claim to be able to handle the same subprotocol, so simply 
registering the user-provided driver class with our `DriverRegistry` and 
JDBC's `DriverManager` is not sufficient to ensure that it's actually used when 
creating the JDBC connection.
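
For illustration only (this snippet is not part of the patch), the ambiguity 
comes from the fact that `java.sql.DriverManager` simply tries each registered 
driver in turn, so any driver that accepts the URL's subprotocol can win. A 
quick Scala diagnostic along these lines (the URL is a made-up example) shows 
how several registered drivers may all claim the same URL:

    import java.sql.DriverManager
    import scala.collection.JavaConverters._

    // List every registered driver that claims this (example) URL; if more than
    // one accepts it, DriverManager.getConnection() may pick one the user did
    // not intend.
    val url = "jdbc:postgresql://db.example.com:5432/mydb"
    DriverManager.getDrivers.asScala
      .filter(_.acceptsURL(url))
      .foreach(d => println(s"accepts $url: ${d.getClass.getCanonicalName}"))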

This patch addresses the issue by first registering the user-specified driver 
with the `DriverManager`, then iterating over the `DriverManager`'s loaded 
drivers to find the correct driver and using it to create the connection 
(previously, we simply called `DriverManager.getConnection()` directly).
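
In condensed form (a sketch of the approach rather than the literal patch code; 
the real implementation in `JdbcUtils.createConnectionFactory` in the diff 
below also unwraps Spark's internal `DriverWrapper`, and `connectWith` is a 
name used only in this sketch), the connection is obtained by asking the 
matching driver directly instead of letting `DriverManager` choose:

    import java.sql.{Connection, Driver, DriverManager}
    import java.util.Properties
    import scala.collection.JavaConverters._

    // Sketch: locate the registered driver whose class name matches, then ask it
    // for the connection directly rather than going through DriverManager.
    def connectWith(driverClass: String, url: String, properties: Properties): Connection = {
      val driver: Driver = DriverManager.getDrivers.asScala
        .find(_.getClass.getCanonicalName == driverClass)
        .getOrElse(throw new IllegalStateException(
          s"Did not find registered driver with class $driverClass"))
      driver.connect(url, properties)
    }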

If the user did not specify a JDBC driver, we call `DriverManager.getDriver` 
on the driver side to determine which driver class to use and pass that class 
name to the executors; this guards against corner-case bugs in situations where 
the driver and executor JVMs have different sets of JDBC drivers on their 
classpaths (previously, there was a (rare) potential for 
`DriverManager.getConnection()` to pick different drivers on the driver and 
executors if the user had not explicitly specified a JDBC driver class and the 
classpaths differed).
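
Roughly sketched (again, not the literal patch code; see 
`createConnectionFactory` in the diff below, and `connectionFactory` here is 
only a sketch name), the driver class is resolved once on the Spark driver and 
then captured by the connection-factory closure that ships to the executors:

    import java.sql.{Connection, DriverManager}
    import java.util.Properties

    // Sketch: resolve the driver class once on the Spark driver, then close over
    // the name so executors look up exactly the same class.
    def connectionFactory(url: String, properties: Properties): () => Connection = {
      val driverClass: String =
        Option(properties.getProperty("driver"))                             // user-specified class, if any
          .getOrElse(DriverManager.getDriver(url).getClass.getCanonicalName) // resolved on the driver side
      () => connectWith(driverClass, url, properties)  // connectWith as sketched above
    }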

This patch is inspired by a similar patch that I made to the `spark-redshift` 
library (https://github.com/databricks/spark-redshift/pull/143), which contains 
its own modified fork of some of Spark's JDBC data source code (for 
cross-Spark-version compatibility reasons).

Author: Josh Rosen <joshro...@databricks.com>

Closes #10519 from JoshRosen/jdbc-driver-precedence.

(cherry picked from commit 6c83d938cc61bd5fabaf2157fcc3936364a83f02)
Signed-off-by: Yin Huai <yh...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f37c1e4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f37c1e4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f37c1e4

Branch: refs/heads/branch-1.6
Commit: 7f37c1e45d52b7823d566349e2be21366d73651f
Parents: b5a1f56
Author: Josh Rosen <joshro...@databricks.com>
Authored: Mon Jan 4 10:39:42 2016 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Mon Jan 4 10:46:45 2016 -0800

----------------------------------------------------------------------
 docs/sql-programming-guide.md                   |  4 +--
 .../org/apache/spark/sql/DataFrameWriter.scala  |  2 +-
 .../datasources/jdbc/DefaultSource.scala        |  3 --
 .../datasources/jdbc/DriverRegistry.scala       |  5 ---
 .../execution/datasources/jdbc/JDBCRDD.scala    | 33 +++---------------
 .../datasources/jdbc/JDBCRelation.scala         |  2 --
 .../execution/datasources/jdbc/JdbcUtils.scala  | 35 ++++++++++++++++----
 7 files changed, 34 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f37c1e4/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 3f9a831..b058833 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported:
   <tr>
     <td><code>driver</code></td>
     <td>
-      The class name of the JDBC driver needed to connect to this URL. This class will be loaded
-      on the master and workers before running an JDBC commands to allow the driver to
-      register itself with the JDBC subsystem.
+      The class name of the JDBC driver to use to connect to this URL.
     </td>
   </tr>
   

http://git-wip-us.apache.org/repos/asf/spark/blob/7f37c1e4/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ab36253..9f59c0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     }
     // connectionProperties should override settings in extraOptions
     props.putAll(connectionProperties)
-    val conn = JdbcUtils.createConnection(url, props)
+    val conn = JdbcUtils.createConnectionFactory(url, props)()
 
     try {
       var tableExists = JdbcUtils.tableExists(conn, url, table)

http://git-wip-us.apache.org/repos/asf/spark/blob/7f37c1e4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
index f522303..5ae6cff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
@@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
-    val driver = parameters.getOrElse("driver", null)
     val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
     val partitionColumn = parameters.getOrElse("partitionColumn", null)
     val lowerBound = parameters.getOrElse("lowerBound", null)
     val upperBound = parameters.getOrElse("upperBound", null)
     val numPartitions = parameters.getOrElse("numPartitions", null)
 
-    if (driver != null) DriverRegistry.register(driver)
-
     if (partitionColumn != null
       && (lowerBound == null || upperBound == null || numPartitions == null)) {
       sys.error("Partitioning incompletely specified")

http://git-wip-us.apache.org/repos/asf/spark/blob/7f37c1e4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
index 7ccd61e..65af397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -51,10 +51,5 @@ object DriverRegistry extends Logging {
       }
     }
   }
-
-  def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
-    case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
-    case driver => driver.getClass.getCanonicalName
-  }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7f37c1e4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index c2f2a31..fad482b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp}
+import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
 import java.util.Properties
 
 import org.apache.commons.lang3.StringUtils
@@ -39,7 +39,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par
   override def index: Int = idx
 }
 
-
 private[sql] object JDBCRDD extends Logging {
 
   /**
@@ -118,7 +117,7 @@ private[sql] object JDBCRDD extends Logging {
    */
   def resolveTable(url: String, table: String, properties: Properties): StructType = {
     val dialect = JdbcDialects.get(url)
-    val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)()
+    val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
     try {
       val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0")
       try {
@@ -170,36 +169,13 @@ private[sql] object JDBCRDD extends Logging {
     new StructType(columns map { name => fieldMap(name) })
   }
 
-  /**
-   * Given a driver string and an url, return a function that loads the
-   * specified driver string then returns a connection to the JDBC url.
-   * getConnector is run on the driver code, while the function it returns
-   * is run on the executor.
-   *
-   * @param driver - The class name of the JDBC driver for the given url, or null if the class name
-   *                 is not necessary.
-   * @param url - The JDBC url to connect to.
-   *
-   * @return A function that loads the driver and connects to the url.
-   */
-  def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
-    () => {
-      try {
-        if (driver != null) DriverRegistry.register(driver)
-      } catch {
-        case e: ClassNotFoundException =>
-          logWarning(s"Couldn't find class $driver", e)
-      }
-      DriverManager.getConnection(url, properties)
-    }
-  }
+
 
   /**
    * Build and return JDBCRDD from the given information.
    *
    * @param sc - Your SparkContext.
    * @param schema - The Catalyst schema of the underlying database table.
-   * @param driver - The class name of the JDBC driver for the given url.
    * @param url - The JDBC url to connect to.
   * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
    * @param requiredColumns - The names of the columns to SELECT.
@@ -212,7 +188,6 @@ private[sql] object JDBCRDD extends Logging {
   def scanTable(
       sc: SparkContext,
       schema: StructType,
-      driver: String,
       url: String,
       properties: Properties,
       fqTable: String,
@@ -223,7 +198,7 @@ private[sql] object JDBCRDD extends Logging {
     val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
     new JDBCRDD(
       sc,
-      getConnector(driver, url, properties),
+      JdbcUtils.createConnectionFactory(url, properties),
       pruneSchema(schema, requiredColumns),
       fqTable,
       quotedColumns,

http://git-wip-us.apache.org/repos/asf/spark/blob/7f37c1e4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index f9300dc..375266f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -91,12 +91,10 @@ private[sql] case class JDBCRelation(
   override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
-    val driver: String = DriverRegistry.getDriverClassName(url)
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sqlContext.sparkContext,
       schema,
-      driver,
       url,
       properties,
       table,

http://git-wip-us.apache.org/repos/asf/spark/blob/7f37c1e4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 46f2670..10f6506 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement}
 import java.util.Properties
 
+import scala.collection.JavaConverters._
 import scala.util.Try
 import scala.util.control.NonFatal
 
@@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row}
 object JdbcUtils extends Logging {
 
   /**
-   * Establishes a JDBC connection.
+   * Returns a factory for creating connections to the given JDBC URL.
+   *
+   * @param url the JDBC url to connect to.
+   * @param properties JDBC connection properties.
    */
-  def createConnection(url: String, connectionProperties: Properties): Connection = {
-    JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)()
+  def createConnectionFactory(url: String, properties: Properties): () => Connection = {
+    val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
+    userSpecifiedDriverClass.foreach(DriverRegistry.register)
+    // Performing this part of the logic on the driver guards against the corner-case where the
+    // driver returned for a URL is different on the driver and executors due to classpath
+    // differences.
+    val driverClass: String = userSpecifiedDriverClass.getOrElse {
+      DriverManager.getDriver(url).getClass.getCanonicalName
+    }
+    () => {
+      userSpecifiedDriverClass.foreach(DriverRegistry.register)
+      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
+        case d if d.getClass.getCanonicalName == driverClass => d
+      }.getOrElse {
+        throw new IllegalStateException(
+          s"Did not find registered driver with class $driverClass")
+      }
+      driver.connect(url, properties)
+    }
   }
 
   /**
@@ -242,15 +264,14 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
-      properties: Properties = new Properties()) {
+      properties: Properties) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType
     }
 
     val rddSchema = df.schema
-    val driver: String = DriverRegistry.getDriverClassName(url)
-    val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
+    val getConnection: () => Connection = createConnectionFactory(url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)

