GitHub user JoshRosen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/14907#discussion_r77099804
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ---
    @@ -322,195 +237,15 @@ private[jdbc] class JDBCRDD(
         }
       }
     
    -  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
    -  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
    -  // the row and also used for the value in `ResultSet`.
    -  private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
    -
    -  /**
    -   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
    -   * each value from `ResultSet` to each field of [[MutableRow]] correctly.
    -   */
    -  def makeGetters(schema: StructType): Array[JDBCValueGetter] =
    -    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
    -
    -  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
    -    case BooleanType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        row.setBoolean(pos, rs.getBoolean(pos + 1))
    -
    -    case DateType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
    -        val dateVal = rs.getDate(pos + 1)
    -        if (dateVal != null) {
    -          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
    -        } else {
    -          row.update(pos, null)
    -        }
    -
    -    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
    -    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
    -    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
    -    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
    -    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
    -    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
    -    // retrieve it, you will get wrong result 199.99.
    -    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
    -    case DecimalType.Fixed(p, s) =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        val decimal =
    -          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
    -        row.update(pos, decimal)
    -
    -    case DoubleType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        row.setDouble(pos, rs.getDouble(pos + 1))
    -
    -    case FloatType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        row.setFloat(pos, rs.getFloat(pos + 1))
    -
    -    case IntegerType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        row.setInt(pos, rs.getInt(pos + 1))
    -
    -    case LongType if metadata.contains("binarylong") =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        val bytes = rs.getBytes(pos + 1)
    -        var ans = 0L
    -        var j = 0
    -        while (j < bytes.size) {
    -          ans = 256 * ans + (255 & bytes(j))
    -          j = j + 1
    -        }
    -        row.setLong(pos, ans)
    -
    -    case LongType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        row.setLong(pos, rs.getLong(pos + 1))
    -
    -    case ShortType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        row.setShort(pos, rs.getShort(pos + 1))
    -
    -    case StringType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
    -        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
    -
    -    case TimestampType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        val t = rs.getTimestamp(pos + 1)
    -        if (t != null) {
    -          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
    -        } else {
    -          row.update(pos, null)
    -        }
    -
    -    case BinaryType =>
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        row.update(pos, rs.getBytes(pos + 1))
    -
    -    case ArrayType(et, _) =>
    -      val elementConversion = et match {
    -        case TimestampType =>
    -          (array: Object) =>
    -            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
    -              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
    -            }
    -
    -        case StringType =>
    -          (array: Object) =>
    -            array.asInstanceOf[Array[java.lang.String]]
    -              .map(UTF8String.fromString)
    -
    -        case DateType =>
    -          (array: Object) =>
    -            array.asInstanceOf[Array[java.sql.Date]].map { date =>
    -              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
    -            }
    -
    -        case dt: DecimalType =>
    -          (array: Object) =>
    -            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
    -              nullSafeConvert[java.math.BigDecimal](
    -                decimal, d => Decimal(d, dt.precision, dt.scale))
    -            }
    -
    -        case LongType if metadata.contains("binarylong") =>
    -          throw new IllegalArgumentException(s"Unsupported array element " +
    -            s"type ${dt.simpleString} based on binary")
    -
    -        case ArrayType(_, _) =>
    -          throw new IllegalArgumentException("Nested arrays unsupported")
    -
    -        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
    -      }
    -
    -      (rs: ResultSet, row: MutableRow, pos: Int) =>
    -        val array = nullSafeConvert[Object](
    -          rs.getArray(pos + 1).getArray,
    -          array => new GenericArrayData(elementConversion.apply(array)))
    -        row.update(pos, array)
    -
    -    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
    -  }
    -
       /**
        * Runs the SQL query against the JDBC driver.
        *
        */
    -  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
    -    new Iterator[InternalRow] {
    +  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
         var closed = false
    -    var finished = false
    -    var gotNext = false
    -    var nextValue: InternalRow = null
    -
    -    context.addTaskCompletionListener{ context => close() }
    -    val inputMetrics = context.taskMetrics().inputMetrics
    -    val part = thePart.asInstanceOf[JDBCPartition]
    -    val conn = getConnection()
    -    val dialect = JdbcDialects.get(url)
    -    import scala.collection.JavaConverters._
    -    dialect.beforeFetch(conn, properties.asScala.toMap)
    -
    -    // H2's JDBC driver does not support the setSchema() method.  We pass a
    -    // fully-qualified table name in the SELECT statement.  I don't know how to
    -    // talk about a table in a completely portable way.
    -
    -    val myWhereClause = getWhereClause(part)
    -
    -    val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
    -    val stmt = conn.prepareStatement(sqlText,
    -        ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    -    val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
    -    require(fetchSize >= 0,
    -      s"Invalid value `${fetchSize.toString}` for parameter " +
    -      s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
    -      "the JDBC driver ignores the value and does the estimates.")
    -    stmt.setFetchSize(fetchSize)
    -    val rs = stmt.executeQuery()
    -
    -    val getters: Array[JDBCValueGetter] = makeGetters(schema)
    -    val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
    -
    -    def getNext(): InternalRow = {
    -      if (rs.next()) {
    -        inputMetrics.incRecordsRead(1)
    -        var i = 0
    -        while (i < getters.length) {
    -          getters(i).apply(rs, mutableRow, i)
    -          if (rs.wasNull) mutableRow.setNullAt(i)
    -          i = i + 1
    -        }
    -        mutableRow
    -      } else {
    -        finished = true
    -        null.asInstanceOf[InternalRow]
    -      }
    -    }
    +    var rs: ResultSet = null
    +    var stmt: PreparedStatement = null
    +    var conn: Connection = null
     
         def close() {
    --- End diff --
    
    Here, `close()` tears down the result set, statement, and connection. While in principle closing the connection should be sufficient to close the statement (which in turn should be sufficient to close the result set), not all JDBC drivers adhere to this contract, so we can't simplify this code at all.
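    
    For readers skimming the diff, here is a minimal sketch of the defensive pattern being defended (the standalone helper name, signature, and `println` logging are illustrative only, not the actual Spark code): each resource gets its own null check and try/catch, innermost first, so a driver that throws on one close or fails to cascade closes can't leak the others.
    
    ```scala
    import java.sql.{Connection, PreparedStatement, ResultSet}
    
    // Illustrative helper only: close each JDBC resource independently,
    // swallowing per-resource failures so the remaining resources still close.
    def closeQuietly(rs: ResultSet, stmt: PreparedStatement, conn: Connection): Unit = {
      if (rs != null) {
        try {
          rs.close()
        } catch {
          case e: Exception => println(s"Exception closing resultset: $e")
        }
      }
      if (stmt != null) {
        // isClosed guards against drivers that already closed the statement with the result set.
        try {
          if (!stmt.isClosed) stmt.close()
        } catch {
          case e: Exception => println(s"Exception closing statement: $e")
        }
      }
      if (conn != null) {
        // Closing only the connection is not enough on drivers that don't cascade closes.
        try {
          if (!conn.isClosed) conn.close()
        } catch {
          case e: Exception => println(s"Exception closing connection: $e")
        }
      }
    }
    ```
    
    Giving each close its own try/catch, rather than one shared try/finally chain, is what keeps the teardown robust to a single misbehaving close call.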

