GitHub user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/14313#discussion_r72008447
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ---
    @@ -322,46 +322,133 @@ private[sql] class JDBCRDD(
         }
       }
     
    -  // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
    -  // we don't have to potentially poke around in the Metadata once for every
    -  // row.
    -  // Is there a better way to do this?  I'd rather be using a type that
    -  // contains only the tags I define.
    -  abstract class JDBCConversion
    -  case object BooleanConversion extends JDBCConversion
    -  case object DateConversion extends JDBCConversion
    -  case class  DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
    -  case object DoubleConversion extends JDBCConversion
    -  case object FloatConversion extends JDBCConversion
    -  case object IntegerConversion extends JDBCConversion
    -  case object LongConversion extends JDBCConversion
    -  case object BinaryLongConversion extends JDBCConversion
    -  case object StringConversion extends JDBCConversion
    -  case object TimestampConversion extends JDBCConversion
    -  case object BinaryConversion extends JDBCConversion
    -  case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
    +  // A `JDBCConversion` is responsible for converting and setting a value from `ResultSet`
    +  // into a field for `MutableRow`. The last argument `Int` means the index for the
    +  // value to be set in the row and also used for the value to retrieve from `ResultSet`.
    +  private type JDBCConversion = (ResultSet, MutableRow, Int) => Unit
     
       /**
    -   * Maps a StructType to a type tag list.
    +   * Maps a StructType to conversions for each type.
        */
       def getConversions(schema: StructType): Array[JDBCConversion] =
         schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
     
       private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
    -    case BooleanType => BooleanConversion
    -    case DateType => DateConversion
    -    case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
    -    case DoubleType => DoubleConversion
    -    case FloatType => FloatConversion
    -    case IntegerType => IntegerConversion
    -    case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
    -    case StringType => StringConversion
    -    case TimestampType => TimestampConversion
    -    case BinaryType => BinaryConversion
    -    case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
    +    case BooleanType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        row.setBoolean(pos, rs.getBoolean(pos + 1))
    +
    +    case DateType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
    +        val dateVal = rs.getDate(pos + 1)
    +        if (dateVal != null) {
    +          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
    +        } else {
    +          row.update(pos, null)
    +        }
    +
    +    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
    +    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
    +    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
    +    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
    +    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
    +    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
    +    // retrieve it, you will get wrong result 199.99.
    +    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
    +    case DecimalType.Fixed(p, s) =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        val decimal =
    +          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
    +        row.update(pos, decimal)
    +
    +    case DoubleType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        row.setDouble(pos, rs.getDouble(pos + 1))
    +
    +    case FloatType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        row.setFloat(pos, rs.getFloat(pos + 1))
    +
    +    case IntegerType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        row.setInt(pos, rs.getInt(pos + 1))
    +
    +    case LongType if metadata.contains("binarylong") =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        val bytes = rs.getBytes(pos + 1)
    +        var ans = 0L
    +        var j = 0
    +        while (j < bytes.size) {
    +          ans = 256 * ans + (255 & bytes(j))
    +          j = j + 1
    +        }
    +        row.setLong(pos, ans)
    +
    +    case LongType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        row.setLong(pos, rs.getLong(pos + 1))
    +
    +    case StringType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
    +        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
    +
    +    case TimestampType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        val t = rs.getTimestamp(pos + 1)
    +        if (t != null) {
    +          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
    +        } else {
    +          row.update(pos, null)
    +        }
    +
    +    case BinaryType =>
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        row.update(pos, rs.getBytes(pos + 1))
    +
    +    case ArrayType(et, _) =>
    +      val elementConversion = getArrayElementConversion(et, metadata)
    +      (rs: ResultSet, row: MutableRow, pos: Int) =>
    +        row.update(pos, nullSafeConvert(rs.getArray(pos + 1).getArray, elementConversion))
    +
         case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
       }
     
    +  private def getArrayElementConversion(dt: DataType, metadata: Metadata) = dt match {
    --- End diff --
    
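    Could we inline this method (`getArrayElementConversion`) into the `ArrayType` case instead of keeping it as a separate helper?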


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it to, or if the feature is enabled but not working,
please contact infrastructure at infrastruct...@apache.org or file a JIRA
ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to