Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22749#discussion_r226301402
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala ---
    @@ -212,21 +183,88 @@ object ExpressionEncoder {
      * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
      * and a `deserializer`.
      *
    - * @param schema The schema after converting `T` to a Spark SQL row.
    - * @param serializer A set of expressions, one for each top-level field that can be used to
    - *                   extract the values from a raw object into an [[InternalRow]].
    - * @param deserializer An expression that will construct an object given an [[InternalRow]].
    + * @param objSerializer An expression that can be used to encode a raw object to the
    + *                      corresponding Spark SQL representation, which can be a primitive
    + *                      column, array, map or struct. This represents how Spark SQL generally
    + *                      serializes an object of type `T`.
    + * @param objDeserializer An expression that will construct an object given a Spark SQL
    + *                        representation. This represents how Spark SQL generally deserializes
    + *                        a serialized value in the Spark SQL representation back to an
    + *                        object of type `T`.
      * @param clsTag A classtag for `T`.
      */
     case class ExpressionEncoder[T](
    -    schema: StructType,
    -    flat: Boolean,
    -    serializer: Seq[Expression],
    -    deserializer: Expression,
    +    objSerializer: Expression,
    +    objDeserializer: Expression,
         clsTag: ClassTag[T])
       extends Encoder[T] {
     
    -  if (flat) require(serializer.size == 1)
    +  /**
    +   * A set of expressions, one for each top-level field that can be used to
    +   * extract the values from a raw object into an [[InternalRow]]:
    +   * 1. If `objSerializer` encodes a raw object to a struct, we directly use the `objSerializer`.
    +   * 2. For other cases, we create a struct to wrap the `objSerializer`.
    +   */
    +  val serializer: Seq[NamedExpression] = {
    +    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
    +    val clsName = Utils.getSimpleName(clsTag.runtimeClass)
    +
    +    if (serializedAsStruct) {
    +      val nullSafeSerializer = objSerializer.transformUp {
    +        case r: BoundReference =>
    +          // For an input object of Product type, we can't encode it to a row if it's null,
    +          // as Spark SQL doesn't allow a top-level row to be null; only its columns can
    +          // be null.
    +          AssertNotNull(r, Seq("top level Product or row object"))
    +      }
    +      nullSafeSerializer match {
    +        case If(_, _, s: CreateNamedStruct) => s
    +        case s: CreateNamedStruct => s
    +        case _ =>
    +          throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
    +      }
    +    } else {
    +      // For other input objects like primitives, arrays, maps, etc., we construct a struct
    +      // to wrap the serializer, which becomes a single column of a row.
    +      CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
    +    }
    +  }.flatten
    +
    +  /**
    +   * Returns an expression that can be used to deserialize an input row to an object of
    +   * type `T` with a compatible schema. Fields of the row will be extracted using
    +   * `UnresolvedAttribute`s of the same names as the constructor arguments.
    +   *
    +   * For complex objects that are encoded to structs, fields of the struct will be extracted
    +   * using `GetColumnByOrdinal` with the corresponding ordinal.
    +   */
    +  val deserializer: Expression = {
    +    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
    +
    +    if (serializedAsStruct) {
    +      // We serialize objects of this kind to a root-level row. The input of the general
    +      // deserializer is a `GetColumnByOrdinal(0)` expression that extracts the first column
    +      // of a row, so we need to transform the attribute accessors accordingly.
    +      objDeserializer.transform {
    +        case UnresolvedExtractValue(GetColumnByOrdinal(0, _),
    +            Literal(part: UTF8String, StringType)) =>
    +          UnresolvedAttribute.quoted(part.toString)
    +        case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) =>
    +          GetColumnByOrdinal(ordinal, dt)
    +        case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n
    +        case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i
    +      }
    +    } else {
    +      // For other input objects like primitives, arrays, maps, etc., we deserialize the
    +      // first column of a row to the object.
    +      objDeserializer
    +    }
    +  }
    +
    +  // The schema after converting `T` to a Spark SQL row. This schema is dependent on the
    +  // given serializer.
    +  val schema: StructType = StructType(serializer.map { s =>
    +    StructField(s.name, s.dataType, s.nullable)
    --- End diff --
    
    nvm, the serializer doesn't need analysis
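    
    To make the asymmetry concrete, a rough sketch (illustrative only, not part of this diff; `Person` is a made-up class, and `toRow`/`fromRow`/`resolveAndBind` are the encoder methods as they currently exist on master):
    
    ```scala
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    
    case class Person(name: String, age: Int)
    
    val enc = ExpressionEncoder[Person]()
    
    // The serializer only contains bound references to the input object, so it
    // can run without being resolved against any schema:
    val row = enc.toRow(Person("a", 1))            // InternalRow [a, 1]
    
    // The deserializer refers to columns via UnresolvedAttribute /
    // GetColumnByOrdinal, so it must be resolved and bound first:
    val back = enc.resolveAndBind().fromRow(row)   // Person("a", 1)
    
    // A non-struct type gets wrapped in a single-column struct named "value":
    ExpressionEncoder[Int]().schema                // one field: "value"
    ```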


---
