Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/22749#discussion_r226301402 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala --- @@ -212,21 +183,88 @@ object ExpressionEncoder { * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer` * and a `deserializer`. * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param serializer A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object into an [[InternalRow]]. - * @param deserializer An expression that will construct an object given an [[InternalRow]]. + * @param objSerializer An expression that can be used to encode a raw object to corresponding + * Spark SQL representation that can be a primitive column, array, map or a + * struct. This represents how Spark SQL generally serializes an object of + * type `T`. + * @param objDeserializer An expression that will construct an object given a Spark SQL + * representation. This represents how Spark SQL generally deserializes + * a serialized value in Spark SQL representation back to an object of + * type `T`. * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( - schema: StructType, - flat: Boolean, - serializer: Seq[Expression], - deserializer: Expression, + objSerializer: Expression, + objDeserializer: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(serializer.size == 1) + /** + * A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]: + * 1. If `serializer` encodes a raw object to a struct, we directly use the `serializer`. + * 2. For other cases, we create a struct to wrap the `serializer`. 
+ */ + val serializer: Seq[NamedExpression] = { + val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType] + val clsName = Utils.getSimpleName(clsTag.runtimeClass) + + if (serializedAsStruct) { + val nullSafeSerializer = objSerializer.transformUp { + case r: BoundReference => + // For input object of Product type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. + AssertNotNull(r, Seq("top level Product or row object")) + } + nullSafeSerializer match { + case If(_, _, s: CreateNamedStruct) => s + case s: CreateNamedStruct => s + case _ => + throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer") + } + } else { + // For other input objects like primitive, array, map, etc., we construct a struct to wrap + // the serializer which is a column of a row. + CreateNamedStruct(Literal("value") :: objSerializer :: Nil) + } + }.flatten + + /** + * Returns an expression that can be used to deserialize an input row to an object of type `T` + * with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute` + * of the same name as the constructor arguments. + * + * For complex objects that are encoded to structs, fields of the struct will be extracted using + * `GetColumnByOrdinal` with corresponding ordinal. + */ + val deserializer: Expression = { + val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType] + + if (serializedAsStruct) { + // We serialize this kind of object to a root-level row. The input of general deserializer + // is a `GetColumnByOrdinal(0)` expression to extract the first column of a row. We need to + // transform attribute accessors. 
+ objDeserializer.transform { + case UnresolvedExtractValue(GetColumnByOrdinal(0, _), + Literal(part: UTF8String, StringType)) => + UnresolvedAttribute.quoted(part.toString) + case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) => + GetColumnByOrdinal(ordinal, dt) + case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n + case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i + } + } else { + // For other input objects like primitive, array, map, etc., we deserialize the first column + // of a row to the object. + objDeserializer + } + } + + // The schema after converting `T` to a Spark SQL row. This schema is dependent on the given + // serializer. + val schema: StructType = StructType(serializer.map { s => + StructField(s.name, s.dataType, s.nullable) --- End diff -- nvm, the serializer doesn't need analysis
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org