This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ddf2da74f52 [SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders ddf2da74f52 is described below commit ddf2da74f527ee00af99fe3928015149f9477734 Author: Herman van Hovell <her...@databricks.com> AuthorDate: Tue Jan 17 10:52:28 2023 -0800 [SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders ### What changes were proposed in this pull request? This PR makes `RowEncoder` produce an `AgnosticEncoder`. The expression generation for these encoders is moved to `ScalaReflection` (this will be moved out in a subsequent PR). The generated serializer and deserializer expressions will slightly change for both schema and type based encoders. These are not semantically different from the old expressions. Concretely the following changes have been introduced: - There is more type validation in maps/arrays/seqs for type based encoders. This should be a positive change, since it prevents users from passing wrong data through erasure hacks. - Array/Seq serialization is a bit more strict. In the old scenario it was possible to pass in sequences/arrays with the wrong type and/or nullability. ### Why are the changes needed? For the Spark Connect Scala Client we also want to be able to use `Row` based results. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? This is a refactoring so mostly existing tests. I have added tests to the catalyst tests that triggered failures downstream (typed arrays in `WrappedArray` & `Seq[_]` change in Scala 2.13). Closes #39627 from hvanhovell/SPARK-41993-v2. 
Authored-by: Herman van Hovell <her...@databricks.com> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- .../spark/sql/catalyst/JavaTypeInference.scala | 4 +- .../spark/sql/catalyst/ScalaReflection.scala | 317 ++++++++++++------ .../spark/sql/catalyst/SerializerBuildHelper.scala | 25 +- .../sql/catalyst/encoders/AgnosticEncoder.scala | 128 ++++++-- .../sql/catalyst/encoders/ExpressionEncoder.scala | 5 +- .../spark/sql/catalyst/encoders/RowEncoder.scala | 354 ++++----------------- .../sql/catalyst/expressions/objects/objects.scala | 87 +++-- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 9 +- .../catalyst/encoders/ExpressionEncoderSuite.scala | 2 + .../sql/catalyst/encoders/RowEncoderSuite.scala | 24 ++ .../catalyst/expressions/CodeGenerationSuite.scala | 2 +- .../expressions/ObjectExpressionsSuite.scala | 9 +- 12 files changed, 462 insertions(+), 504 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 827807055ce..81f363dda36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -423,10 +423,10 @@ object JavaTypeInference { case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject) case c if c == classOf[java.math.BigInteger] => - createSerializerForJavaBigInteger(inputObject) + createSerializerForBigInteger(inputObject) case c if c == classOf[java.math.BigDecimal] => - createSerializerForJavaBigDecimal(inputObject) + createSerializerForBigDecimal(inputObject) case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject) case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e02e42cea1a..42208cd1098 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst import javax.lang.model.SourceVersion import scala.annotation.tailrec +import scala.language.existentials import scala.reflect.ClassTag import scala.reflect.internal.Symbols import scala.util.{Failure, Success} @@ -27,12 +28,13 @@ import scala.util.{Failure, Success} import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{expressions => exprs} import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ import org.apache.spark.sql.catalyst.SerializerBuildHelper._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ -import org.apache.spark.sql.catalyst.expressions.{Expression, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors @@ -82,12 +84,24 @@ object ScalaReflection extends ScalaReflection { } } - // TODO this name is slightly misleading. This returns the input - // data type we expect to see during serialization. - private[catalyst] def dataTypeFor(enc: AgnosticEncoder[_]): DataType = { + /** + * Return the data type we expect to see when deserializing a value with encoder `enc`. 
+ */ + private[catalyst] def externalDataTypeFor(enc: AgnosticEncoder[_]): DataType = { + externalDataTypeFor(enc, lenientSerialization = false) + } + + private[catalyst] def lenientExternalDataTypeFor(enc: AgnosticEncoder[_]): DataType = + externalDataTypeFor(enc, enc.lenientSerialization) + + private def externalDataTypeFor( + enc: AgnosticEncoder[_], + lenientSerialization: Boolean): DataType = { // DataType can be native. if (isNativeEncoder(enc)) { enc.dataType + } else if (lenientSerialization) { + ObjectType(classOf[java.lang.Object]) } else { ObjectType(enc.clsTag.runtimeClass) } @@ -123,7 +137,7 @@ object ScalaReflection extends ScalaReflection { case NullEncoder => true case CalendarIntervalEncoder => true case BinaryEncoder => true - case SparkDecimalEncoder => true + case _: SparkDecimalEncoder => true case _ => false } @@ -155,11 +169,19 @@ object ScalaReflection extends ScalaReflection { val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName) // Assumes we are deserializing the first column of a row. 
val input = GetColumnByOrdinal(0, enc.dataType) - val deserializer = deserializerFor( - enc, - upCastToExpectedType(input, enc.dataType, walkedTypePath), - walkedTypePath) - expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) + enc match { + case RowEncoder(fields) => + val children = fields.zipWithIndex.map { case (f, i) => + deserializerFor(f.enc, GetStructField(input, i), walkedTypePath) + } + CreateExternalRow(children, enc.schema) + case _ => + val deserializer = deserializerFor( + enc, + upCastToExpectedType(input, enc.dataType, walkedTypePath), + walkedTypePath) + expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) + } } /** @@ -178,19 +200,7 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: WalkedTypePath): Expression = enc match { case _ if isNativeEncoder(enc) => path - case BoxedBooleanEncoder => - createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) - case BoxedByteEncoder => - createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) - case BoxedShortEncoder => - createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) - case BoxedIntEncoder => - createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) - case BoxedLongEncoder => - createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) - case BoxedFloatEncoder => - createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) - case BoxedDoubleEncoder => + case _: BoxedLeafEncoder[_, _] => createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) case JavaEnumEncoder(tag) => val toString = createDeserializerForString(path, returnNullable = false) @@ -204,9 +214,9 @@ object ScalaReflection extends ScalaReflection { returnNullable = false) case StringEncoder => createDeserializerForString(path, returnNullable = false) - case ScalaDecimalEncoder => + case _: ScalaDecimalEncoder => createDeserializerForScalaBigDecimal(path, returnNullable = false) - case 
JavaDecimalEncoder => + case _: JavaDecimalEncoder => createDeserializerForJavaBigDecimal(path, returnNullable = false) case ScalaBigIntEncoder => createDeserializerForScalaBigInt(path) @@ -216,13 +226,13 @@ object ScalaReflection extends ScalaReflection { createDeserializerForDuration(path) case YearMonthIntervalEncoder => createDeserializerForPeriod(path) - case DateEncoder => + case _: DateEncoder => createDeserializerForSqlDate(path) - case LocalDateEncoder => + case _: LocalDateEncoder => createDeserializerForLocalDate(path) - case TimestampEncoder => + case _: TimestampEncoder => createDeserializerForSqlTimestamp(path) - case InstantEncoder => + case _: InstantEncoder => createDeserializerForInstant(path) case LocalDateTimeEncoder => createDeserializerForLocalDateTime(path) @@ -232,39 +242,29 @@ object ScalaReflection extends ScalaReflection { case OptionEncoder(valueEnc) => val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName) val deserializer = deserializerFor(valueEnc, path, newTypePath) - WrapOption(deserializer, dataTypeFor(valueEnc)) - - case ArrayEncoder(elementEnc: AgnosticEncoder[_]) => - val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. 
- deserializerForWithNullSafetyAndUpcast( - element, - elementEnc.dataType, - nullable = elementEnc.nullable, - newTypePath, - deserializerFor(elementEnc, _, newTypePath)) - } + WrapOption(deserializer, externalDataTypeFor(valueEnc)) + + case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) => Invoke( - UnresolvedMapObjects(mapFunction, path), + deserializeArray( + path, + elementEnc, + containsNull, + None, + walkedTypePath), toArrayMethodName(elementEnc), ObjectType(enc.clsTag.runtimeClass), returnNullable = false) - case IterableEncoder(clsTag, elementEnc) => - val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. - deserializerForWithNullSafetyAndUpcast( - element, - elementEnc.dataType, - nullable = elementEnc.nullable, - newTypePath, - deserializerFor(elementEnc, _, newTypePath)) - } - UnresolvedMapObjects(mapFunction, path, Some(clsTag.runtimeClass)) + case IterableEncoder(clsTag, elementEnc, containsNull, _) => + deserializeArray( + path, + elementEnc, + containsNull, + Option(clsTag.runtimeClass), + walkedTypePath) - case MapEncoder(tag, keyEncoder, valueEncoder) => + case MapEncoder(tag, keyEncoder, valueEncoder, _) => val newTypePath = walkedTypePath.recordMap( keyEncoder.clsTag.runtimeClass.getName, valueEncoder.clsTag.runtimeClass.getName) @@ -298,6 +298,39 @@ object ScalaReflection extends ScalaReflection { IsNull(path), expressions.Literal.create(null, dt), NewInstance(cls, arguments, dt, propagateNull = false)) + + case RowEncoder(fields) => + val convertedFields = fields.zipWithIndex.map { case (f, i) => + val newTypePath = walkedTypePath.recordField( + f.enc.clsTag.runtimeClass.getName, + f.name) + exprs.If( + Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil), + exprs.Literal.create(null, externalDataTypeFor(f.enc)), + deserializerFor(f.enc, GetStructField(path, i), 
newTypePath)) + } + exprs.If(IsNull(path), + exprs.Literal.create(null, externalDataTypeFor(enc)), + CreateExternalRow(convertedFields, enc.schema)) + } + + private def deserializeArray( + path: Expression, + elementEnc: AgnosticEncoder[_], + containsNull: Boolean, + cls: Option[Class[_]], + walkedTypePath: WalkedTypePath): Expression = { + val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expects. + deserializerForWithNullSafetyAndUpcast( + element, + elementEnc.dataType, + nullable = containsNull, + newTypePath, + deserializerFor(elementEnc, _, newTypePath)) + } + UnresolvedMapObjects(mapFunction, path, cls) } /** @@ -306,7 +339,7 @@ object ScalaReflection extends ScalaReflection { * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. */ def serializerFor(enc: AgnosticEncoder[_]): Expression = { - val input = BoundReference(0, dataTypeFor(enc), nullable = enc.nullable) + val input = BoundReference(0, lenientExternalDataTypeFor(enc), nullable = enc.nullable) serializerFor(enc, input) } @@ -327,45 +360,52 @@ object ScalaReflection extends ScalaReflection { case JavaEnumEncoder(_) => createSerializerForJavaEnum(input) case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input) case StringEncoder => createSerializerForString(input) - case ScalaDecimalEncoder => createSerializerForScalaBigDecimal(input) - case JavaDecimalEncoder => createSerializerForJavaBigDecimal(input) - case ScalaBigIntEncoder => createSerializerForScalaBigInt(input) - case JavaBigIntEncoder => createSerializerForJavaBigInteger(input) + case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt) + case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt) + case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt) + case ScalaBigIntEncoder => 
createSerializerForBigInteger(input) + case JavaBigIntEncoder => createSerializerForBigInteger(input) case DayTimeIntervalEncoder => createSerializerForJavaDuration(input) case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input) - case DateEncoder => createSerializerForSqlDate(input) - case LocalDateEncoder => createSerializerForJavaLocalDate(input) - case TimestampEncoder => createSerializerForSqlTimestamp(input) - case InstantEncoder => createSerializerForJavaInstant(input) + case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input) + case DateEncoder(false) => createSerializerForSqlDate(input) + case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input) + case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input) + case TimestampEncoder(false) => createSerializerForSqlTimestamp(input) + case InstantEncoder(false) => createSerializerForJavaInstant(input) case LocalDateTimeEncoder => createSerializerForLocalDateTime(input) case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass) case OptionEncoder(valueEnc) => - serializerFor(valueEnc, UnwrapOption(dataTypeFor(valueEnc), input)) + serializerFor(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input)) - case ArrayEncoder(elementEncoder) => - serializerForArray(isArray = true, elementEncoder, input) + case ArrayEncoder(elementEncoder, containsNull) => + if (elementEncoder.isPrimitive) { + createSerializerForPrimitiveArray(input, elementEncoder.dataType) + } else { + serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false) + } - case IterableEncoder(ctag, elementEncoder) => + case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) => val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) { // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. 
// Note that the property of `Set` is only kept when manipulating the data as domain object. - Invoke(input, "toSeq", ObjectType(classOf[Seq[_]])) + Invoke(input, "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) } else { input } - serializerForArray(isArray = false, elementEncoder, getter) + serializerForArray(elementEncoder, containsNull, getter, lenientSerialization) - case MapEncoder(_, keyEncoder, valueEncoder) => + case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) => createSerializerForMap( input, MapElementInformation( - dataTypeFor(keyEncoder), - nullable = !keyEncoder.isPrimitive, - serializerFor(keyEncoder, _)), + ObjectType(classOf[AnyRef]), + nullable = keyEncoder.nullable, + validateAndSerializeElement(keyEncoder, keyEncoder.nullable)), MapElementInformation( - dataTypeFor(valueEncoder), - nullable = !valueEncoder.isPrimitive, - serializerFor(valueEncoder, _)) + ObjectType(classOf[AnyRef]), + nullable = valueContainsNull, + validateAndSerializeElement(valueEncoder, valueContainsNull)) ) case ProductEncoder(_, fields) => @@ -377,25 +417,94 @@ object ScalaReflection extends ScalaReflection { val getter = Invoke( KnownNotNull(input), field.name, - dataTypeFor(field.enc), - returnNullable = field.enc.nullable) + externalDataTypeFor(field.enc), + returnNullable = field.nullable) field.name -> serializerFor(field.enc, getter) } createSerializerForObject(input, serializedFields) + + case RowEncoder(fields) => + val serializedFields = fields.zipWithIndex.map { case (field, index) => + val fieldValue = serializerFor( + field.enc, + ValidateExternalType( + GetExternalRowField(input, index, field.name), + field.enc.dataType, + lenientExternalDataTypeFor(field.enc))) + + val convertedField = if (field.nullable) { + exprs.If( + Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: Nil), + // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`. + // We should use `fieldValue.dataType` here. 
+ exprs.Literal.create(null, fieldValue.dataType), + fieldValue + ) + } else { + AssertNotNull(fieldValue) + } + field.name -> convertedField + } + createSerializerForObject(input, serializedFields) } private def serializerForArray( - isArray: Boolean, elementEnc: AgnosticEncoder[_], - input: Expression): Expression = { - dataTypeFor(elementEnc) match { - case dt: ObjectType => - createSerializerForMapObjects(input, dt, serializerFor(elementEnc, _)) - case dt if isArray && elementEnc.isPrimitive => - createSerializerForPrimitiveArray(input, dt) - case dt => - createSerializerForGenericArray(input, dt, elementEnc.nullable) + elementNullable: Boolean, + input: Expression, + lenientSerialization: Boolean): Expression = { + // Default serializer for Seq and generic Arrays. This does not work for primitive arrays. + val genericSerializer = createSerializerForMapObjects( + input, + ObjectType(classOf[AnyRef]), + validateAndSerializeElement(elementEnc, elementNullable)) + + // Check if it is possible the user can pass a primitive array. This is the only case when it + // is safe to directly convert to an array (for generic arrays and Seqs the type and the + // nullability can be violated). If the user has passed a primitive array we create a special + // code path to deal with these. + val primitiveEncoderOption = elementEnc match { + case _ if !lenientSerialization => None + case enc: PrimitiveLeafEncoder[_] => Option(enc) + case enc: BoxedLeafEncoder[_, _] => Option(enc.primitive) + case _ => None } + primitiveEncoderOption match { + case Some(primitiveEncoder) => + val primitiveArrayClass = primitiveEncoder.clsTag.wrap.runtimeClass + val check = Invoke( + targetObject = exprs.Literal.fromObject(primitiveArrayClass), + functionName = "isInstance", + BooleanType, + arguments = input :: Nil, + propagateNull = false, + returnNullable = false) + exprs.If( + check, + // TODO replace this with `createSerializerForPrimitiveArray` as + // soon as Cast support ObjectType casts. 
+ StaticInvoke( + classOf[ArrayData], + ArrayType(elementEnc.dataType, containsNull = false), + "toArrayData", + input :: Nil, + propagateNull = false, + returnNullable = false), + genericSerializer) + case None => + genericSerializer + } + } + + private def validateAndSerializeElement( + enc: AgnosticEncoder[_], + nullable: Boolean): Expression => Expression = { input => + expressionWithNullSafety( + serializerFor( + enc, + ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))), + nullable, + WalkedTypePath()) } /** @@ -598,8 +707,8 @@ object ScalaReflection extends ScalaReflection { case StringType => classOf[UTF8String] case CalendarIntervalType => classOf[CalendarInterval] case _: StructType => classOf[InternalRow] - case _: ArrayType => classOf[ArrayType] - case _: MapType => classOf[MapType] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) case ObjectType(cls) => cls case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) @@ -657,7 +766,11 @@ object ScalaReflection extends ScalaReflection { case NoSymbol => fallbackClass case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - IterableEncoder(ClassTag(targetClass), encoder) + IterableEncoder( + ClassTag(targetClass), + encoder, + encoder.nullable, + lenientSerialization = false) } baseType(tpe) match { @@ -698,18 +811,18 @@ object ScalaReflection extends ScalaReflection { // Leaf encoders case t if isSubtype(t, localTypeOf[String]) => StringEncoder - case t if isSubtype(t, localTypeOf[Decimal]) => SparkDecimalEncoder - case t if isSubtype(t, localTypeOf[BigDecimal]) => ScalaDecimalEncoder - case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => JavaDecimalEncoder + case t if isSubtype(t, localTypeOf[Decimal]) => DEFAULT_SPARK_DECIMAL_ENCODER + case t if isSubtype(t, localTypeOf[BigDecimal]) => DEFAULT_SCALA_DECIMAL_ENCODER + case t if isSubtype(t, 
localTypeOf[java.math.BigDecimal]) => DEFAULT_JAVA_DECIMAL_ENCODER case t if isSubtype(t, localTypeOf[BigInt]) => ScalaBigIntEncoder case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => JavaBigIntEncoder case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalEncoder case t if isSubtype(t, localTypeOf[java.time.Duration]) => DayTimeIntervalEncoder case t if isSubtype(t, localTypeOf[java.time.Period]) => YearMonthIntervalEncoder - case t if isSubtype(t, localTypeOf[java.sql.Date]) => DateEncoder - case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => LocalDateEncoder - case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => TimestampEncoder - case t if isSubtype(t, localTypeOf[java.time.Instant]) => InstantEncoder + case t if isSubtype(t, localTypeOf[java.sql.Date]) => STRICT_DATE_ENCODER + case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => STRICT_LOCAL_DATE_ENCODER + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER + case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder // UDT encoders @@ -739,7 +852,7 @@ object ScalaReflection extends ScalaReflection { elementType, seenTypeSet, path.recordArray(getClassNameFromType(elementType))) - ArrayEncoder(encoder) + ArrayEncoder(encoder, encoder.nullable) case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) => createIterableEncoder(t, classOf[scala.collection.Seq[_]]) @@ -757,7 +870,7 @@ object ScalaReflection extends ScalaReflection { valueType, seenTypeSet, path.recordValueForMap(getClassNameFromType(valueType))) - MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder) + MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder, valueEncoder.nullable) case t if definedByConstructorParams(t) => if (seenTypeSet.contains(t)) { @@ -775,7 +888,7 @@ object ScalaReflection extends ScalaReflection { fieldType, 
seenTypeSet + t, path.recordField(getClassNameFromType(fieldType), fieldName)) - EncoderField(fieldName, encoder) + EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty) } ProductEncoder(ClassTag(getClassFromType(t)), params) case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 25f6ce520d9..33b0edb0c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -158,20 +158,29 @@ object SerializerBuildHelper { returnNullable = false) } - def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = { + def createSerializerForBigDecimal(inputObject: Expression): Expression = { + createSerializerForBigDecimal(inputObject, DecimalType.SYSTEM_DEFAULT) + } + + def createSerializerForBigDecimal(inputObject: Expression, dt: DecimalType): Expression = { CheckOverflow(StaticInvoke( Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, + dt, "apply", inputObject :: Nil, - returnNullable = false), DecimalType.SYSTEM_DEFAULT, nullOnOverflow) + returnNullable = false), dt, nullOnOverflow) } - def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = { - createSerializerForJavaBigDecimal(inputObject) + def createSerializerForAnyDecimal(inputObject: Expression, dt: DecimalType): Expression = { + CheckOverflow(StaticInvoke( + Decimal.getClass, + dt, + "fromDecimal", + inputObject :: Nil, + returnNullable = false), dt, nullOnOverflow) } - def createSerializerForJavaBigInteger(inputObject: Expression): Expression = { + def createSerializerForBigInteger(inputObject: Expression): Expression = { CheckOverflow(StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, @@ -180,10 +189,6 @@ object SerializerBuildHelper { returnNullable = false), 
DecimalType.BigIntDecimal, nullOnOverflow) } - def createSerializerForScalaBigInt(inputObject: Expression): Expression = { - createSerializerForJavaBigInteger(inputObject) - } - def createSerializerForPrimitiveArray( inputObject: Expression, dataType: DataType): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index 6081ac8dc28..cdc64f2ddb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -16,28 +16,33 @@ */ package org.apache.spark.sql.catalyst.encoders +import java.{sql => jsql} import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import scala.reflect.{classTag, ClassTag} -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval /** * A non implementation specific encoder. This encoder containers all the information needed * to generate an implementation specific encoder (e.g. InternalRow <=> Custom Object). + * + * The input of the serialization does not need to match the external type of the encoder. This is + * called lenient serialization. An example of this is lenient date serialization, in this case both + * [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never lenient; it + * will always produce instance of the external type. 
*/ trait AgnosticEncoder[T] extends Encoder[T] { def isPrimitive: Boolean def nullable: Boolean = !isPrimitive def dataType: DataType override def schema: StructType = StructType(StructField("value", dataType, nullable) :: Nil) + def lenientSerialization: Boolean = false } -// TODO check RowEncoder -// TODO check BeanEncoder object AgnosticEncoders { case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E]) extends AgnosticEncoder[Option[E]] { @@ -46,35 +51,48 @@ object AgnosticEncoders { override val clsTag: ClassTag[Option[E]] = ClassTag(classOf[Option[E]]) } - case class ArrayEncoder[E](element: AgnosticEncoder[E]) + case class ArrayEncoder[E](element: AgnosticEncoder[E], containsNull: Boolean) extends AgnosticEncoder[Array[E]] { override def isPrimitive: Boolean = false - override def dataType: DataType = ArrayType(element.dataType, element.nullable) + override def dataType: DataType = ArrayType(element.dataType, containsNull) override val clsTag: ClassTag[Array[E]] = element.clsTag.wrap } - case class IterableEncoder[C <: Iterable[E], E]( + /** + * Encoder for collections. + * + * This encoder can be lenient for [[Row]] encoders. In that case we allow [[Seq]], primitive + * array (if any), and generic arrays as input. 
+ */ + case class IterableEncoder[C, E]( override val clsTag: ClassTag[C], - element: AgnosticEncoder[E]) + element: AgnosticEncoder[E], + containsNull: Boolean, + override val lenientSerialization: Boolean) extends AgnosticEncoder[C] { override def isPrimitive: Boolean = false - override val dataType: DataType = ArrayType(element.dataType, element.nullable) + override val dataType: DataType = ArrayType(element.dataType, containsNull) } case class MapEncoder[C, K, V]( override val clsTag: ClassTag[C], keyEncoder: AgnosticEncoder[K], - valueEncoder: AgnosticEncoder[V]) + valueEncoder: AgnosticEncoder[V], + valueContainsNull: Boolean) extends AgnosticEncoder[C] { override def isPrimitive: Boolean = false override val dataType: DataType = MapType( keyEncoder.dataType, valueEncoder.dataType, - valueEncoder.nullable) + valueContainsNull) } - case class EncoderField(name: String, enc: AgnosticEncoder[_]) { - def structField: StructField = StructField(name, enc.dataType, enc.nullable) + case class EncoderField( + name: String, + enc: AgnosticEncoder[_], + nullable: Boolean, + metadata: Metadata) { + def structField: StructField = StructField(name, enc.dataType, nullable, metadata) } // This supports both Product and DefinedByConstructorParams @@ -87,6 +105,13 @@ object AgnosticEncoders { override def dataType: DataType = schema } + case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] { + override def isPrimitive: Boolean = false + override val schema: StructType = StructType(fields.map(_.structField)) + override def dataType: DataType = schema + override def clsTag: ClassTag[Row] = classTag[Row] + } + // This will only work for encoding from/to Sparks' InternalRow format. // It is here for compatibility. 
case class UDTEncoder[E >: Null]( @@ -116,39 +141,74 @@ object AgnosticEncoders { } // Primitive encoders - case object PrimitiveBooleanEncoder extends LeafEncoder[Boolean](BooleanType) - case object PrimitiveByteEncoder extends LeafEncoder[Byte](ByteType) - case object PrimitiveShortEncoder extends LeafEncoder[Short](ShortType) - case object PrimitiveIntEncoder extends LeafEncoder[Int](IntegerType) - case object PrimitiveLongEncoder extends LeafEncoder[Long](LongType) - case object PrimitiveFloatEncoder extends LeafEncoder[Float](FloatType) - case object PrimitiveDoubleEncoder extends LeafEncoder[Double](DoubleType) + abstract class PrimitiveLeafEncoder[E : ClassTag](dataType: DataType) + extends LeafEncoder[E](dataType) + case object PrimitiveBooleanEncoder extends PrimitiveLeafEncoder[Boolean](BooleanType) + case object PrimitiveByteEncoder extends PrimitiveLeafEncoder[Byte](ByteType) + case object PrimitiveShortEncoder extends PrimitiveLeafEncoder[Short](ShortType) + case object PrimitiveIntEncoder extends PrimitiveLeafEncoder[Int](IntegerType) + case object PrimitiveLongEncoder extends PrimitiveLeafEncoder[Long](LongType) + case object PrimitiveFloatEncoder extends PrimitiveLeafEncoder[Float](FloatType) + case object PrimitiveDoubleEncoder extends PrimitiveLeafEncoder[Double](DoubleType) // Primitive wrapper encoders. 
- case object NullEncoder extends LeafEncoder[java.lang.Void](NullType) - case object BoxedBooleanEncoder extends LeafEncoder[java.lang.Boolean](BooleanType) - case object BoxedByteEncoder extends LeafEncoder[java.lang.Byte](ByteType) - case object BoxedShortEncoder extends LeafEncoder[java.lang.Short](ShortType) - case object BoxedIntEncoder extends LeafEncoder[java.lang.Integer](IntegerType) - case object BoxedLongEncoder extends LeafEncoder[java.lang.Long](LongType) - case object BoxedFloatEncoder extends LeafEncoder[java.lang.Float](FloatType) - case object BoxedDoubleEncoder extends LeafEncoder[java.lang.Double](DoubleType) + abstract class BoxedLeafEncoder[E : ClassTag, P]( + dataType: DataType, + val primitive: PrimitiveLeafEncoder[P]) + extends LeafEncoder[E](dataType) + case object BoxedBooleanEncoder + extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder) + case object BoxedByteEncoder + extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder) + case object BoxedShortEncoder + extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder) + case object BoxedIntEncoder + extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder) + case object BoxedLongEncoder + extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder) + case object BoxedFloatEncoder + extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder) + case object BoxedDoubleEncoder + extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder) // Nullable leaf encoders + case object NullEncoder extends LeafEncoder[java.lang.Void](NullType) case object StringEncoder extends LeafEncoder[String](StringType) case object BinaryEncoder extends LeafEncoder[Array[Byte]](BinaryType) - case object SparkDecimalEncoder extends LeafEncoder[Decimal](DecimalType.SYSTEM_DEFAULT) - case object ScalaDecimalEncoder extends 
LeafEncoder[BigDecimal](DecimalType.SYSTEM_DEFAULT) - case object JavaDecimalEncoder extends LeafEncoder[JBigDecimal](DecimalType.SYSTEM_DEFAULT) case object ScalaBigIntEncoder extends LeafEncoder[BigInt](DecimalType.BigIntDecimal) case object JavaBigIntEncoder extends LeafEncoder[JBigInt](DecimalType.BigIntDecimal) case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType) case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType()) case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType()) - case object DateEncoder extends LeafEncoder[java.sql.Date](DateType) - case object LocalDateEncoder extends LeafEncoder[LocalDate](DateType) - case object TimestampEncoder extends LeafEncoder[java.sql.Timestamp](TimestampType) - case object InstantEncoder extends LeafEncoder[Instant](TimestampType) + case class DateEncoder(override val lenientSerialization: Boolean) + extends LeafEncoder[jsql.Date](DateType) + case class LocalDateEncoder(override val lenientSerialization: Boolean) + extends LeafEncoder[LocalDate](DateType) + case class TimestampEncoder(override val lenientSerialization: Boolean) + extends LeafEncoder[jsql.Timestamp](TimestampType) + case class InstantEncoder(override val lenientSerialization: Boolean) + extends LeafEncoder[Instant](TimestampType) case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType) + + case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt) + case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt) + case class JavaDecimalEncoder(dt: DecimalType, override val lenientSerialization: Boolean) + extends LeafEncoder[JBigDecimal](dt) + + val STRICT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = false) + val STRICT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = false) + val STRICT_TIMESTAMP_ENCODER: TimestampEncoder = 
TimestampEncoder(lenientSerialization = false) + val STRICT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = false) + val LENIENT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = true) + val LENIENT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = true) + val LENIENT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = true) + val LENIENT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = true) + + val DEFAULT_SPARK_DECIMAL_ENCODER: SparkDecimalEncoder = + SparkDecimalEncoder(DecimalType.SYSTEM_DEFAULT) + val DEFAULT_SCALA_DECIMAL_ENCODER: ScalaDecimalEncoder = + ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT) + val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder = + JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 82a6863b5ff..9ca2fc72ad9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -47,7 +47,10 @@ import org.apache.spark.util.Utils object ExpressionEncoder { def apply[T : TypeTag](): ExpressionEncoder[T] = { - val enc = ScalaReflection.encoderFor[T] + apply(ScalaReflection.encoderFor[T]) + } + + def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = { new ExpressionEncoder[T]( ScalaReflection.serializerFor(enc), ScalaReflection.deserializerFor(enc), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 8eb3475acaa..78243894544 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -17,19 +17,11 @@ package org.apache.spark.sql.catalyst.encoders -import scala.annotation.tailrec -import scala.collection.Map -import scala.reflect.ClassTag +import scala.collection.mutable +import scala.reflect.classTag import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{ScalaReflection, WalkedTypePath} -import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ -import org.apache.spark.sql.catalyst.SerializerBuildHelper._ -import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMont [...] 
import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -68,224 +60,46 @@ import org.apache.spark.sql.types._ */ object RowEncoder { def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = { - val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val serializer = serializerFor(inputObject, schema, lenient) - val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema) - new ExpressionEncoder[Row]( - serializer, - deserializer, - ClassTag(cls)) + ExpressionEncoder(encoderFor(schema, lenient)) } + def apply(schema: StructType): ExpressionEncoder[Row] = { apply(schema, lenient = false) } - private def serializerFor( - inputObject: Expression, - inputType: DataType, - lenient: Boolean): Expression = inputType match { - case dt if ScalaReflection.isNativeType(dt) => inputObject - - case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType, lenient) - - case udt: UserDefinedType[_] => - val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) - val udtClass: Class[_] = if (annotation != null) { - annotation.udt() - } else { - UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { - throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt) - } - } - val obj = NewInstance( - udtClass, - Nil, - dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) - - case TimestampType => - if (lenient) { - createSerializerForAnyTimestamp(inputObject) - } else if (SQLConf.get.datetimeJava8ApiEnabled) { - createSerializerForJavaInstant(inputObject) - } else { - createSerializerForSqlTimestamp(inputObject) - } - - case TimestampNTZType => createSerializerForLocalDateTime(inputObject) - - case DateType => - if (lenient) { - createSerializerForAnyDate(inputObject) - } else if 
(SQLConf.get.datetimeJava8ApiEnabled) { - createSerializerForJavaLocalDate(inputObject) - } else { - createSerializerForSqlDate(inputObject) - } - - case _: DayTimeIntervalType => createSerializerForJavaDuration(inputObject) - - case _: YearMonthIntervalType => createSerializerForJavaPeriod(inputObject) - - case d: DecimalType => - CheckOverflow(StaticInvoke( - Decimal.getClass, - d, - "fromDecimal", - inputObject :: Nil, - returnNullable = false), d, !SQLConf.get.ansiEnabled) - - case StringType => createSerializerForString(inputObject) - - case t @ ArrayType(et, containsNull) => - et match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => - StaticInvoke( - classOf[ArrayData], - t, - "toArrayData", - inputObject :: Nil, - returnNullable = false) - - case _ => - createSerializerForMapObjects( - inputObject, - ObjectType(classOf[Object]), - element => { - val value = serializerFor(ValidateExternalType(element, et, lenient), et, lenient) - expressionWithNullSafety(value, containsNull, WalkedTypePath()) - }) - } - - case t @ MapType(kt, vt, valueNullable) => - val keys = - Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), - returnNullable = false), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) - val convertedKeys = serializerFor(keys, ArrayType(kt, false), lenient) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), - returnNullable = false), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) - val convertedValues = serializerFor(values, ArrayType(vt, valueNullable), lenient) - - val nonNullOutput = NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = t, - propagateNull = false) - - if (inputObject.nullable) { - expressionForNullableExpr(inputObject, nonNullOutput) - } else { - nonNullOutput - } - 
- case StructType(fields) => - val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => - val fieldValue = serializerFor( - ValidateExternalType( - GetExternalRowField(inputObject, index, field.name), - field.dataType, - lenient), - field.dataType, - lenient) - val convertedField = if (field.nullable) { - If( - Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), - // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`. - // We should use `fieldValue.dataType` here. - Literal.create(null, fieldValue.dataType), - fieldValue - ) - } else { - fieldValue - } - Literal(field.name) :: convertedField :: Nil - }) - - if (inputObject.nullable) { - expressionForNullableExpr(inputObject, nonNullOutput) - } else { - nonNullOutput - } - // For other data types, return the internal catalyst value as it is. - case _ => inputObject - } - - /** - * Returns the `DataType` that can be used when generating code that converts input data - * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned - * by this function can be more permissive since multiple external types may map to a single - * internal type. For example, for an input with DecimalType in external row, its external types - * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or - * `org.apache.spark.sql.types.Decimal`. - */ - def externalDataTypeForInput(dt: DataType, lenient: Boolean): DataType = dt match { - // In order to support both Decimal and java/scala BigDecimal in external row, we make this - // as java.lang.Object. - case _: DecimalType => ObjectType(classOf[java.lang.Object]) - // In order to support both Array and Seq in external row, we make this as java.lang.Object. 
- case _: ArrayType => ObjectType(classOf[java.lang.Object]) - case _: DateType | _: TimestampType if lenient => ObjectType(classOf[java.lang.Object]) - case _ => externalDataTypeFor(dt) - } - - @tailrec - def externalDataTypeFor(dt: DataType): DataType = dt match { - case _ if ScalaReflection.isNativeType(dt) => dt - case TimestampType => - if (SQLConf.get.datetimeJava8ApiEnabled) { - ObjectType(classOf[java.time.Instant]) - } else { - ObjectType(classOf[java.sql.Timestamp]) - } - case TimestampNTZType => - ObjectType(classOf[java.time.LocalDateTime]) - case DateType => - if (SQLConf.get.datetimeJava8ApiEnabled) { - ObjectType(classOf[java.time.LocalDate]) - } else { - ObjectType(classOf[java.sql.Date]) - } - case _: DayTimeIntervalType => ObjectType(classOf[java.time.Duration]) - case _: YearMonthIntervalType => ObjectType(classOf[java.time.Period]) - case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) - case udt: UserDefinedType[_] => ObjectType(udt.userClass) - case _ => dt.physicalDataType match { - case _: PhysicalArrayType => ObjectType(classOf[scala.collection.Seq[_]]) - case _: PhysicalDecimalType => ObjectType(classOf[java.math.BigDecimal]) - case _: PhysicalMapType => ObjectType(classOf[scala.collection.Map[_, _]]) - case PhysicalStringType => ObjectType(classOf[java.lang.String]) - case _: PhysicalStructType => ObjectType(classOf[Row]) - // For other data types, return the data type as it is. 
- case _ => dt - } - } - - private def deserializerFor(input: Expression, schema: StructType): Expression = { - val fields = schema.zipWithIndex.map { case (f, i) => - deserializerFor(GetStructField(input, i)) - } - CreateExternalRow(fields, schema) + def encoderFor(schema: StructType): AgnosticEncoder[Row] = { + encoderFor(schema, lenient = false) } - private def deserializerFor(input: Expression): Expression = { - deserializerFor(input, input.dataType) + def encoderFor(schema: StructType, lenient: Boolean): AgnosticEncoder[Row] = { + encoderForDataType(schema, lenient).asInstanceOf[AgnosticEncoder[Row]] } - @tailrec - private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { - case dt if ScalaReflection.isNativeType(dt) => input - - case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) - + private[catalyst] def encoderForDataType( + dataType: DataType, + lenient: Boolean): AgnosticEncoder[_] = dataType match { + case NullType => NullEncoder + case BooleanType => BoxedBooleanEncoder + case ByteType => BoxedByteEncoder + case ShortType => BoxedShortEncoder + case IntegerType => BoxedIntEncoder + case LongType => BoxedLongEncoder + case FloatType => BoxedFloatEncoder + case DoubleType => BoxedDoubleEncoder + case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true) + case BinaryType => BinaryEncoder + case StringType => StringEncoder + case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient) + case TimestampType => TimestampEncoder(lenient) + case TimestampNTZType => LocalDateTimeEncoder + case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient) + case DateType => DateEncoder(lenient) + case CalendarIntervalType => CalendarIntervalEncoder + case _: DayTimeIntervalType => DayTimeIntervalEncoder + case _: YearMonthIntervalType => YearMonthIntervalEncoder + case p: PythonUserDefinedType => + // TODO check if this works. 
+ encoderForDataType(p.sqlType, lenient) case udt: UserDefinedType[_] => val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) val udtClass: Class[_] = if (annotation != null) { @@ -295,84 +109,26 @@ object RowEncoder { throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt) } } - val obj = NewInstance( - udtClass, - Nil, - dataType = ObjectType(udtClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) - - case TimestampType => - if (SQLConf.get.datetimeJava8ApiEnabled) { - createDeserializerForInstant(input) - } else { - createDeserializerForSqlTimestamp(input) - } - - case TimestampNTZType => - createDeserializerForLocalDateTime(input) - - case DateType => - if (SQLConf.get.datetimeJava8ApiEnabled) { - createDeserializerForLocalDate(input) - } else { - createDeserializerForSqlDate(input) - } - - case _: DayTimeIntervalType => createDeserializerForDuration(input) - - case _: YearMonthIntervalType => createDeserializerForPeriod(input) - - case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false) - - case StringType => createDeserializerForString(input, returnNullable = false) - - case ArrayType(et, nullable) => - val arrayData = - Invoke( - MapObjects(deserializerFor(_), input, et), - "array", - ObjectType(classOf[Array[_]]), returnNullable = false) - // TODO should use `scala.collection.immutable.ArrayDeq.unsafeMake` method to create - // `immutable.Seq` in Scala 2.13 when Scala version compatibility is no longer required. 
- StaticInvoke( - scala.collection.mutable.WrappedArray.getClass, - ObjectType(classOf[scala.collection.Seq[_]]), - "make", - arrayData :: Nil, - returnNullable = false) - - case MapType(kt, vt, valueNullable) => - val keyArrayType = ArrayType(kt, false) - val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) - - val valueArrayType = ArrayType(vt, valueNullable) - val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType)) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil, - returnNullable = false) - - case schema @ StructType(fields) => - val convertedFields = fields.zipWithIndex.map { case (f, i) => - If( - Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), - Literal.create(null, externalDataTypeFor(f.dataType)), - deserializerFor(GetStructField(input, i))) - } - If(IsNull(input), - Literal.create(null, externalDataTypeFor(input.dataType)), - CreateExternalRow(convertedFields, schema)) - - // For other data types, return the internal catalyst value as it is. 
- case _ => input - } - - private def expressionForNullableExpr( - expr: Expression, - newExprWhenNotNull: Expression): Expression = { - If(IsNull(expr), Literal.create(null, newExprWhenNotNull.dataType), newExprWhenNotNull) + UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]]) + case ArrayType(elementType, containsNull) => + IterableEncoder( + classTag[mutable.WrappedArray[_]], + encoderForDataType(elementType, lenient), + containsNull, + lenientSerialization = true) + case MapType(keyType, valueType, valueContainsNull) => + MapEncoder( + classTag[scala.collection.Map[_, _]], + encoderForDataType(keyType, lenient), + encoderForDataType(valueType, lenient), + valueContainsNull) + case StructType(fields) => + AgnosticRowEncoder(fields.map { field => + EncoderField( + field.name, + encoderForDataType(field.dataType, lenient), + field.nullable, + field.metadata) + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a644b90a96f..56facda2af6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.{Method, Modifier} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{Builder, WrappedArray} import scala.reflect.ClassTag -import scala.util.{Properties, Try} +import scala.util.Try import org.apache.commons.lang3.reflect.MethodUtils @@ -30,7 +31,6 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.RowEncoder import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -859,7 +859,7 @@ case class MapObjects private( case _ => inputData.dataType } - private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + private def executeFuncOnCollection(inputCollection: Iterable[_]): Iterator[_] = { val row = new GenericInternalRow(1) inputCollection.iterator.map { element => row.update(0, element) @@ -867,7 +867,7 @@ case class MapObjects private( } } - private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + private lazy val convertToSeq: Any => scala.collection.Seq[_] = inputDataType match { case ObjectType(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) => _.asInstanceOf[scala.collection.Seq[_]].toSeq case ObjectType(cls) if cls.isArray => @@ -879,17 +879,33 @@ case class MapObjects private( if (inputCollection.getClass.isArray) { inputCollection.asInstanceOf[Array[_]].toSeq } else { - inputCollection.asInstanceOf[Seq[_]] + inputCollection.asInstanceOf[scala.collection.Seq[_]] } } case ArrayType(et, _) => _.asInstanceOf[ArrayData].toSeq[Any](et) } - private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + private def elementClassTag(): ClassTag[Any] = { + val clazz = lambdaFunction.dataType match { + case ObjectType(cls) => cls + case dt if lambdaFunction.nullable => ScalaReflection.javaBoxedType(dt) + case dt => ScalaReflection.dataTypeJavaClass(dt) + } + ClassTag(clazz).asInstanceOf[ClassTag[Any]] + } + + private lazy val mapElements: scala.collection.Seq[_] => Any = customCollectionCls match { case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) => - // Scala WrappedArray - inputCollection => WrappedArray.make(executeFuncOnCollection(inputCollection).toArray) + // The implicit tag is a workaround to deal with a small change in the + // (scala) signature of ArrayBuilder.make between Scala 
2.12 and 2.13. + implicit val tag: ClassTag[Any] = elementClassTag() + input => { + val builder = mutable.ArrayBuilder.make[Any] + builder.sizeHint(input.size) + executeFuncOnCollection(input).foreach(builder += _) + mutable.WrappedArray.make(builder.result()) + } case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) => // Scala sequence executeFuncOnCollection(_).toSeq @@ -1047,44 +1063,20 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) => - def doCodeGenForScala212 = { - // WrappedArray in Scala 2.12 - val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" - val builder = ctx.freshName("collectionBuilder") - ( - s""" - ${classOf[Builder[_, _]].getName} $builder = $getBuilder; - $builder.sizeHint($dataLength); - """, - (genValue: String) => s"$builder.$$plus$$eq($genValue);", - s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." 
+ - s"MODULE$$.make(((${classOf[IndexedSeq[_]].getName})$builder" + - s".result()).toArray(scala.reflect.ClassTag$$.MODULE$$.Object()));" - ) - } - - def doCodeGenForScala213 = { - // In Scala 2.13, WrappedArray is mutable.ArraySeq and newBuilder method need - // a ClassTag type construction parameter - val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder(" + - s"scala.reflect.ClassTag$$.MODULE$$.Object())" - val builder = ctx.freshName("collectionBuilder") - ( - s""" + val tag = ctx.addReferenceObj("tag", elementClassTag()) + val builderClassName = classOf[mutable.ArrayBuilder[_]].getName + val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)" + val builder = ctx.freshName("collectionBuilder") + ( + s""" ${classOf[Builder[_, _]].getName} $builder = $getBuilder; $builder.sizeHint($dataLength); """, - (genValue: String) => s"$builder.$$plus$$eq($genValue);", - s"(${cls.getName})$builder.result();" - ) - } + (genValue: String) => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." + + s"MODULE$$.make($builder.result());" + ) - val scalaVersion = Properties.versionNumberString - if (scalaVersion.startsWith("2.12")) { - doCodeGenForScala212 - } else { - doCodeGenForScala213 - } case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) || classOf[scala.collection.Set[_]].isAssignableFrom(cls) => // Scala sequence or set @@ -1908,14 +1900,14 @@ case class GetExternalRowField( * Validates the actual data type of input expression at runtime. If it doesn't match the * expectation, throw an exception. 
*/ -case class ValidateExternalType(child: Expression, expected: DataType, lenient: Boolean) +case class ValidateExternalType(child: Expression, expected: DataType, externalDataType: DataType) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object])) override def nullable: Boolean = child.nullable - override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected, lenient) + override val dataType: DataType = externalDataType private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" @@ -1927,7 +1919,9 @@ case class ValidateExternalType(child: Expression, expected: DataType, lenient: } case _: ArrayType => (value: Any) => { - value.getClass.isArray || value.isInstanceOf[Seq[_]] + value.getClass.isArray || + value.isInstanceOf[scala.collection.Seq[_]] || + value.isInstanceOf[Set[_]] } case _: DateType => (value: Any) => { @@ -1968,7 +1962,8 @@ case class ValidateExternalType(child: Expression, expected: DataType, lenient: classOf[scala.math.BigDecimal], classOf[Decimal])) case _: ArrayType => - s"$obj.getClass().isArray() || $obj instanceof ${classOf[scala.collection.Seq[_]].getName}" + val check = genCheckTypes(Seq(classOf[scala.collection.Seq[_]], classOf[Set[_]])) + s"$obj.getClass().isArray() || $check" case _: DateType => genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate])) case _: TimestampType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 7e7ce29972b..f8ebdfe7676 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite import 
org.apache.spark.sql.catalyst.FooEnum.FooEnum import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, MapObjects, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -388,11 +388,10 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-15062: Get correct serializer for List[_]") { - val list = List(1, 2, 3) val serializer = serializerFor[List[Int]] - assert(serializer.isInstanceOf[NewInstance]) - assert(serializer.asInstanceOf[NewInstance] - .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) + assert(serializer.isInstanceOf[MapObjects]) + val mapObjects = serializer.asInstanceOf[MapObjects] + assert(mapObjects.customCollectionCls.isEmpty) } test("SPARK 16792: Get correct deserializer for List[_]") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 3a0db1ca121..c6546105231 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -480,6 +480,8 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(ScroogeLikeExample(1), "SPARK-40385 class with only a companion object constructor") + encodeDecodeTest(Array(Set(1, 2), Set(2, 3)), "array of sets") + productTest(("UDT", new ExamplePoint(0.1, 0.2))) test("AnyVal class with Any fields") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index c6bddfa5eee..b133b38a559 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.encoders +import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.{RandomDataGenerator, Row} @@ -310,6 +311,19 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } + private def roundTripArray[T](dt: DataType, nullable: Boolean, data: Array[T]): Unit = { + val schema = new StructType().add("a", ArrayType(dt, nullable)) + test(s"RowEncoder should return WrappedArray with properly typed array for $schema") { + val encoder = RowEncoder(schema).resolveAndBind() + val result = fromRow(encoder, toRow(encoder, Row(data))).getAs[mutable.WrappedArray[_]](0) + assert(result.array.getClass === data.getClass) + assert(result === data) + } + } + + roundTripArray(IntegerType, nullable = false, Array(1, 2, 3).map(Int.box)) + roundTripArray(StringType, nullable = true, Array("hello", "world", "!", null)) + test("SPARK-25791: Datatype of serializers should be accessible") { val udtSQLType = new StructType().add("a", IntegerType) val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass") @@ -458,4 +472,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { } } } + + test("Encoding an ArraySeq/WrappedArray in scala-2.13") { + val schema = new StructType() + .add("headers", ArrayType(new StructType() + .add("key", StringType) + .add("value", BinaryType))) + val encoder = RowEncoder(schema, lenient = true).resolveAndBind() + val data = 
Row(mutable.WrappedArray.make(Array(Row("key", "value".getBytes)))) + val row = encoder.createSerializer()(data) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 737fcb1bada..265b0eeb8bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -332,7 +332,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ValidateExternalType( GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), IntegerType, - lenient = false) :: Nil) + IntegerType) :: Nil) } test("SPARK-17160: field names are properly escaped by AssertTrue") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 2286b734477..05ab7a65a32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -496,10 +496,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), (Array(3, 2, 1), ArrayType(IntegerType)) ).foreach { case (input, dt) => + val enc = RowEncoder.encoderForDataType(dt, lenient = false) val validateType = ValidateExternalType( GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt, - lenient = false) + ScalaReflection.lenientExternalDataTypeFor(enc)) checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) } @@ -507,7 +508,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { ValidateExternalType( GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType, - lenient = false), + DoubleType), InternalRow.fromSeq(Seq(Row(1))), "java.lang.Integer is not a valid external type for schema of double") } @@ -559,10 +560,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ExternalMapToCatalyst( inputObject, - ScalaReflection.dataTypeFor(keyEnc), + ScalaReflection.externalDataTypeFor(keyEnc), kvSerializerFor(keyEnc), keyNullable = keyEnc.nullable, - ScalaReflection.dataTypeFor(valueEnc), + ScalaReflection.externalDataTypeFor(valueEnc), kvSerializerFor(valueEnc), valueNullable = valueEnc.nullable ) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org