Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21762#discussion_r202502621
  
    --- Diff: 
external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala 
---
    @@ -0,0 +1,348 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.avro
    +
    +import java.nio.ByteBuffer
    +
    +import scala.collection.JavaConverters._
    +import scala.collection.mutable.ArrayBuffer
    +
    +import org.apache.avro.{Schema, SchemaBuilder}
    +import org.apache.avro.Schema.Type._
    +import org.apache.avro.generic._
    +import org.apache.avro.util.Utf8
    +
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, 
UnsafeArrayData}
    +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, 
DateTimeUtils, GenericArrayData}
    +import org.apache.spark.sql.types._
    +import org.apache.spark.unsafe.types.UTF8String
    +
    +/**
    + * A deserializer to deserialize data in avro format to data in catalyst 
format.
    + */
    +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
    +  private val converter: Any => Any = rootCatalystType match {
    +    // A shortcut for empty schema.
    +    case st: StructType if st.isEmpty =>
    +      (data: Any) => InternalRow.empty
    +
    +    case st: StructType =>
    +      val resultRow = new SpecificInternalRow(st.map(_.dataType))
    +      val fieldUpdater = new RowUpdater(resultRow)
    +      val writer = getRecordWriter(rootAvroType, st, Nil)
    +      (data: Any) => {
    +        val record = data.asInstanceOf[GenericRecord]
    +        writer(fieldUpdater, record)
    +        resultRow
    +      }
    +
    +    case _ =>
    +      val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
    +      val fieldUpdater = new RowUpdater(tmpRow)
    +      val writer = newWriter(rootAvroType, rootCatalystType, Nil)
    +      (data: Any) => {
    +        writer(fieldUpdater, 0, data)
    +        tmpRow.get(0, rootCatalystType)
    +      }
    +  }
    +
    +  def deserialize(data: Any): Any = converter(data)
    +
    +  /**
    +   * Creates a writer to write avro values to Catalyst values at the 
given ordinal with the given
    +   * updater.
    +   */
    +  private def newWriter(
    +      avroType: Schema,
    +      catalystType: DataType,
    +      path: List[String]): (CatalystDataUpdater, Int, Any) => Unit =
    +    (avroType.getType, catalystType) match {
    +      case (NULL, NullType) => (updater, ordinal, _) =>
    +        updater.setNullAt(ordinal)
    +
    +      // TODO: we can avoid boxing if future version of avro provide 
primitive accessors.
    +      case (BOOLEAN, BooleanType) => (updater, ordinal, value) =>
    +        updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
    +
    +      case (INT, IntegerType) => (updater, ordinal, value) =>
    +        updater.setInt(ordinal, value.asInstanceOf[Int])
    +
    +      case (LONG, LongType) => (updater, ordinal, value) =>
    +        updater.setLong(ordinal, value.asInstanceOf[Long])
    +
    +      case (LONG, TimestampType) => (updater, ordinal, value) =>
    +        updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
    +
    +      case (LONG, DateType) => (updater, ordinal, value) =>
    +        updater.setInt(ordinal, (value.asInstanceOf[Long] / 
DateTimeUtils.MILLIS_PER_DAY).toInt)
    +
    +      case (FLOAT, FloatType) => (updater, ordinal, value) =>
    +        updater.setFloat(ordinal, value.asInstanceOf[Float])
    +
    +      case (DOUBLE, DoubleType) => (updater, ordinal, value) =>
    +        updater.setDouble(ordinal, value.asInstanceOf[Double])
    +
    +      case (STRING, StringType) => (updater, ordinal, value) =>
    +        val str = value match {
    +          case s: String => UTF8String.fromString(s)
    +          case s: Utf8 =>
    +            val bytes = new Array[Byte](s.getByteLength)
    +            System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength)
    +            UTF8String.fromBytes(bytes)
    +        }
    +        updater.set(ordinal, str)
    +
    +      case (ENUM, StringType) => (updater, ordinal, value) =>
    +        updater.set(ordinal, UTF8String.fromString(value.toString))
    +
    +      case (FIXED, BinaryType) => (updater, ordinal, value) =>
    +        updater.set(ordinal, 
value.asInstanceOf[GenericFixed].bytes().clone())
    +
    +      case (BYTES, BinaryType) => (updater, ordinal, value) =>
    +        val bytes = value match {
    +          case b: ByteBuffer =>
    +            val bytes = new Array[Byte](b.remaining)
    +            b.get(bytes)
    +            bytes
    +          case b: Array[Byte] => b
    +          case other => throw new RuntimeException(s"$other is not a valid 
avro binary.")
    +
    +        }
    +        updater.set(ordinal, bytes)
    +
    +      case (RECORD, st: StructType) =>
    +        val writeRecord = getRecordWriter(avroType, st, path)
    +        (updater, ordinal, value) =>
    +          val row = new SpecificInternalRow(st)
    +          writeRecord(new RowUpdater(row), 
value.asInstanceOf[GenericRecord])
    +          updater.set(ordinal, row)
    +
    +      case (ARRAY, ArrayType(elementType, containsNull)) =>
    +        val elementWriter = newWriter(avroType.getElementType, 
elementType, path)
    +        (updater, ordinal, value) =>
    +          val array = value.asInstanceOf[GenericData.Array[Any]]
    +          val len = array.size()
    +          val result = createArrayData(elementType, len)
    +          val elementUpdater = new ArrayDataUpdater(result)
    +
    +          var i = 0
    +          while (i < len) {
    +            val element = array.get(i)
    +            if (element == null) {
    +              if (!containsNull) {
    +                throw new RuntimeException(s"Array value at path 
${path.mkString(".")} is not " +
    +                  "allowed to be null")
    +              } else {
    +                elementUpdater.setNullAt(i)
    +              }
    +            } else {
    +              elementWriter(elementUpdater, i, element)
    +            }
    +            i += 1
    +          }
    +
    +          updater.set(ordinal, result)
    +
    +      case (MAP, MapType(keyType, valueType, valueContainsNull)) if 
keyType == StringType =>
    +        val keyWriter = newWriter(SchemaBuilder.builder().stringType(), 
StringType, path)
    +        val valueWriter = newWriter(avroType.getValueType, valueType, path)
    +        (updater, ordinal, value) =>
    +          val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]]
    +          val keyArray = createArrayData(keyType, map.size())
    +          val keyUpdater = new ArrayDataUpdater(keyArray)
    +          val valueArray = createArrayData(valueType, map.size())
    +          val valueUpdater = new ArrayDataUpdater(valueArray)
    +          val iter = map.entrySet().iterator()
    +          var i = 0
    +          while (iter.hasNext) {
    +            val entry = iter.next()
    +            assert(entry.getKey != null)
    +            keyWriter(keyUpdater, i, entry.getKey)
    +            if (entry.getValue == null) {
    +              if (!valueContainsNull) {
    +                throw new RuntimeException(s"Map value at path 
${path.mkString(".")} is not " +
    +                  "allowed to be null")
    +              } else {
    +                valueUpdater.setNullAt(i)
    +              }
    +            } else {
    +              valueWriter(valueUpdater, i, entry.getValue)
    +            }
    +            i += 1
    +          }
    +
    +          updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
    +
    +      case (UNION, _) =>
    +        val allTypes = avroType.getTypes.asScala
    +        val nonNullTypes = allTypes.filter(_.getType != NULL)
    +        if (nonNullTypes.nonEmpty) {
    +          if (nonNullTypes.length == 1) {
    +            newWriter(nonNullTypes.head, catalystType, path)
    +          } else {
    +            nonNullTypes.map(_.getType) match {
    +              case Seq(a, b) if Set(a, b) == Set(INT, LONG) && 
catalystType == LongType =>
    +                (updater, ordinal, value) => value match {
    +                  case null => updater.setNullAt(ordinal)
    +                  case l: java.lang.Long => updater.setLong(ordinal, l)
    +                  case i: java.lang.Integer => updater.setLong(ordinal, 
i.longValue())
    +                }
    +
    +              case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && 
catalystType == DoubleType =>
    +                (updater, ordinal, value) => value match {
    +                  case null => updater.setNullAt(ordinal)
    +                  case d: java.lang.Double => updater.setDouble(ordinal, d)
    +                  case f: java.lang.Float => updater.setDouble(ordinal, 
f.doubleValue())
    +                }
    +
    +              case _ =>
    +                catalystType match {
    +                  case st: StructType if st.length == nonNullTypes.size =>
    +                    val fieldWriters = nonNullTypes.zip(st.fields).map {
    +                      case (schema, field) => newWriter(schema, 
field.dataType, path :+ field.name)
    +                    }.toArray
    +                    (updater, ordinal, value) => {
    +                      val row = new SpecificInternalRow(st)
    +                      val fieldUpdater = new RowUpdater(row)
    +                      val i = GenericData.get().resolveUnion(avroType, 
value)
    +                      fieldWriters(i)(fieldUpdater, i, value)
    +                      updater.set(ordinal, row)
    +                    }
    +
    +                  case _ =>
    +                    throw new IncompatibleSchemaException(
    +                      s"Cannot convert Avro to catalyst because schema at 
path " +
    +                        s"${path.mkString(".")} is not compatible " +
    +                        s"(avroType = $avroType, sqlType = 
$catalystType).\n" +
    +                        s"Source Avro schema: $rootAvroType.\n" +
    +                        s"Target Catalyst type: $rootCatalystType")
    +                }
    +            }
    +          }
    +        } else {
    +          (updater, ordinal, value) => updater.setNullAt(ordinal)
    +        }
    +
    +      case _ =>
    +        throw new IncompatibleSchemaException(
    +          s"Cannot convert Avro to catalyst because schema at path 
${path.mkString(".")} " +
    +            s"is not compatible (avroType = $avroType, sqlType = 
$catalystType).\n" +
    +            s"Source Avro schema: $rootAvroType.\n" +
    +            s"Target Catalyst type: $rootCatalystType")
    +    }
    +
    +  private def getRecordWriter(
    +      avroType: Schema,
    +      sqlType: StructType,
    +      path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = {
    +    val validFieldIndexes = ArrayBuffer.empty[Int]
    +    val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => 
Unit]
    +
    +    val length = sqlType.length
    +    var i = 0
    +    while (i < length) {
    +      val sqlField = sqlType.fields(i)
    +      val avroField = avroType.getField(sqlField.name)
    +      if (avroField != null) {
    +        validFieldIndexes += avroField.pos()
    +
    +        val baseWriter = newWriter(avroField.schema(), sqlField.dataType, 
path :+ sqlField.name)
    +        val ordinal = i
    +        val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) 
=> {
    +          if (value == null) {
    +            fieldUpdater.setNullAt(ordinal)
    +          } else {
    +            baseWriter(fieldUpdater, ordinal, value)
    +          }
    +        }
    +        fieldWriters += fieldWriter
    +      } else if (!sqlField.nullable) {
    +        throw new IncompatibleSchemaException(
    +          s"""
    +             |Cannot find non-nullable field 
${path.mkString(".")}.${sqlField.name} in Avro schema.
    +             |Source Avro schema: $rootAvroType.
    +             |Target Catalyst type: $rootCatalystType.
    +           """.stripMargin)
    +      }
    +      i += 1
    +    }
    +
    +    (fieldUpdater, record) => {
    +      var i = 0
    +      while (i < validFieldIndexes.length) {
    +        fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i)))
    +        i += 1
    +      }
    +    }
    +  }
    +
    +  private def createArrayData(elementType: DataType, length: Int): 
ArrayData = elementType match {
    +    case BooleanType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Boolean](length))
    +    case ByteType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Byte](length))
    +    case ShortType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Short](length))
    +    case IntegerType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Int](length))
    +    case LongType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Long](length))
    +    case FloatType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Float](length))
    +    case DoubleType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Double](length))
    +    case _ => new GenericArrayData(new Array[Any](length))
    +  }
    +
    +  /**
    +   * A base interface for updating values inside catalyst data structure 
like `InternalRow` and
    +   * `ArrayData`.
    +   */
    +  sealed trait CatalystDataUpdater {
    +    def set(ordinal: Int, value: Any): Unit
    +
    +    def setNullAt(ordinal: Int): Unit = set(ordinal, null)
    +    def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, 
value)
    +    def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
    +    def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
    +    def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
    +    def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
    +    def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
    +    def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
    --- End diff --
    
    It seems we don't need these default implementations.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to