Repository: spark
Updated Branches:
  refs/heads/master 03c27435a -> 3b2b785ec


[SPARK-16675][SQL] Avoid per-record type dispatch in JDBC when writing

## What changes were proposed in this pull request?

Currently, `JdbcUtils.savePartition` performs type-based dispatch for each row 
in order to write the appropriate values.

Instead, the appropriate setters for `PreparedStatement` can be created once 
according to the schema, and then applied to each row. This approach is similar 
to the one used in `CatalystWriteSupport`.

This PR builds such setters up front so that the per-record type dispatch is avoided.
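
Below is a condensed sketch of the approach (not the full patch; only a few types are 
shown and names mirror those in `JdbcUtils`): the setter for each column is resolved 
once from the schema, and the per-row loop then only indexes into the resulting array.

```scala
import java.sql.PreparedStatement

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

object SetterSketch {
  // Binds one field of a `Row` to one parameter of a `PreparedStatement`.
  // JDBC parameters are 1-based while `Row` fields are 0-based, hence `pos + 1`.
  type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  // The type dispatch happens here, once per column.
  private def makeSetter(dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType => (stmt, row, pos) => stmt.setInt(pos + 1, row.getInt(pos))
    case LongType    => (stmt, row, pos) => stmt.setLong(pos + 1, row.getLong(pos))
    case StringType  => (stmt, row, pos) => stmt.setString(pos + 1, row.getString(pos))
    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(s"Can't translate non-null value for field $pos")
  }

  // Built once per partition; writing a row is then setters(i)(stmt, row, i) per column.
  def makeSetters(schema: StructType): Array[JDBCValueSetter] =
    schema.fields.map(f => makeSetter(f.dataType))
}
```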

## How was this patch tested?

Existing tests should cover this.
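
As a manual sanity check, a small end-to-end write such as the one below (e.g. pasted 
into `spark-shell`) exercises `JdbcUtils.savePartition` and therefore the new setters. 
This is only a sketch: it assumes a local Spark session and an H2 driver on the 
classpath, and the JDBC URL and table name are illustrative.

```scala
import java.util.Properties

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("jdbc-write-check").getOrCreate()
import spark.implicits._

// A tiny DataFrame covering a few of the supported column types.
val df = Seq((1, 1.5, "a"), (2, 2.5, "b")).toDF("id", "score", "name")

// Writing through the JDBC data source runs savePartition on each partition,
// which now applies the pre-built setters per row.
df.write.jdbc("jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1", "TEST_WRITE", new Properties())

spark.stop()
```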

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #14323 from HyukjinKwon/SPARK-16675.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3b2b785e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3b2b785e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3b2b785e

Branch: refs/heads/master
Commit: 3b2b785ece4394ca332377647a6305ea493f411b
Parents: 03c2743
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Tue Jul 26 17:14:58 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Jul 26 17:14:58 2016 +0800

----------------------------------------------------------------------
 .../execution/datasources/jdbc/JDBCRDD.scala    |  22 ++--
 .../execution/datasources/jdbc/JdbcUtils.scala  | 102 ++++++++++++++-----
 2 files changed, 88 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3b2b785e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 4c98430..e267e77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -322,19 +322,19 @@ private[sql] class JDBCRDD(
     }
   }
 
-  // A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
-  // into a field for `MutableRow`. The last argument `Int` means the index for the
-  // value to be set in the row and also used for the value to retrieve from `ResultSet`.
-  private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
+  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
+  // the row and also used for the value in `ResultSet`.
+  private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
 
   /**
-   * Creates `JDBCValueSetter`s according to [[StructType]], which can set
+   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
    * each value from `ResultSet` to each field of [[MutableRow]] correctly.
    */
-  def makeSetters(schema: StructType): Array[JDBCValueSetter] =
-    schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
+  def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
 
-  private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
+  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
     case BooleanType =>
       (rs: ResultSet, row: MutableRow, pos: Int) =>
         row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -489,15 +489,15 @@ private[sql] class JDBCRDD(
     stmt.setFetchSize(fetchSize)
     val rs = stmt.executeQuery()
 
-    val setters: Array[JDBCValueSetter] = makeSetters(schema)
+    val getters: Array[JDBCValueGetter] = makeGetters(schema)
     val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
 
     def getNext(): InternalRow = {
       if (rs.next()) {
         inputMetrics.incRecordsRead(1)
         var i = 0
-        while (i < setters.length) {
-          setters(i).apply(rs, mutableRow, i)
+        while (i < getters.length) {
+          getters(i).apply(rs, mutableRow, i)
           if (rs.wasNull) mutableRow.setNullAt(i)
           i = i + 1
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/3b2b785e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index cb474cb..81d38e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -154,6 +154,79 @@ object JdbcUtils extends Logging {
       throw new IllegalArgumentException(s"Can't get JDBC type for 
${dt.simpleString}"))
   }
 
+  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
+  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
+  // in the SQL statement and also used for the value in `Row`.
+  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
+
+  private def makeSetter(
+      conn: Connection,
+      dialect: JdbcDialect,
+      dataType: DataType): JDBCValueSetter = dataType match {
+    case IntegerType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getInt(pos))
+
+    case LongType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setLong(pos + 1, row.getLong(pos))
+
+    case DoubleType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDouble(pos + 1, row.getDouble(pos))
+
+    case FloatType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setFloat(pos + 1, row.getFloat(pos))
+
+    case ShortType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getShort(pos))
+
+    case ByteType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getByte(pos))
+
+    case BooleanType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBoolean(pos + 1, row.getBoolean(pos))
+
+    case StringType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setString(pos + 1, row.getString(pos))
+
+    case BinaryType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
+
+    case TimestampType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
+
+    case DateType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
+
+    case t: DecimalType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
+
+    case ArrayType(et, _) =>
+      // remove type length parameters from end of type name
+      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+        .toLowerCase.split("\\(")(0)
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        val array = conn.createArrayOf(
+          typeName,
+          row.getSeq[AnyRef](pos).toArray)
+        stmt.setArray(pos + 1, array)
+
+    case _ =>
+      (_: PreparedStatement, _: Row, pos: Int) =>
+        throw new IllegalArgumentException(
+          s"Can't translate non-null value for field $pos")
+  }
+
   /**
    * Saves a partition of a DataFrame to the JDBC database.  This is done in
    * a single database transaction (unless isolation level is "NONE")
@@ -215,6 +288,9 @@ object JdbcUtils extends Logging {
         conn.setTransactionIsolation(finalIsolationLevel)
       }
       val stmt = insertStatement(conn, table, rddSchema, dialect)
+      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
+          .map(makeSetter(conn, dialect, _)).toArray
+
       try {
         var rowCount = 0
         while (iterator.hasNext) {
@@ -225,30 +301,7 @@ object JdbcUtils extends Logging {
             if (row.isNullAt(i)) {
               stmt.setNull(i + 1, nullTypes(i))
             } else {
-              rddSchema.fields(i).dataType match {
-                case IntegerType => stmt.setInt(i + 1, row.getInt(i))
-                case LongType => stmt.setLong(i + 1, row.getLong(i))
-                case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
-                case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
-                case ShortType => stmt.setInt(i + 1, row.getShort(i))
-                case ByteType => stmt.setInt(i + 1, row.getByte(i))
-                case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
-                case StringType => stmt.setString(i + 1, row.getString(i))
-                case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
-                case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
-                case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
-                case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
-                case ArrayType(et, _) =>
-                  // remove type length parameters from end of type name
-                  val typeName = getJdbcType(et, dialect).databaseTypeDefinition
-                    .toLowerCase.split("\\(")(0)
-                  val array = conn.createArrayOf(
-                    typeName,
-                    row.getSeq[AnyRef](i).toArray)
-                  stmt.setArray(i + 1, array)
-                case _ => throw new IllegalArgumentException(
-                  s"Can't translate non-null value for field $i")
-              }
+              setters(i).apply(stmt, row, i)
             }
             i = i + 1
           }
@@ -333,5 +386,4 @@ object JdbcUtils extends Logging {
      getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
     )
   }
-
 }

