This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 18672003513  [SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders
18672003513 is described below

commit 18672003513d5a4aa610b6b94dbbc15c33185d3a
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Thu Feb 2 10:53:11 2023 +0800

    [SPARK-42093][SQL] Move JavaTypeInference to AgnosticEncoders

    ### What changes were proposed in this pull request?
    This PR makes `JavaTypeInference` produce an `AgnosticEncoder`. The expression generation for these encoders is moved to `ScalaReflection`.

    ### Why are the changes needed?
    For the Spark Connect Scala Client we also want to be able to use Java Bean based results.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    I have added a lot of tests to `JavaTypeInferenceSuite`.

    Closes #39615 from hvanhovell/SPARK-42093.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 0d93bb2c0a47f652727accfc36b652bdac33f894)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/JavaTypeInference.scala     | 565 ++++++---------------
 .../spark/sql/catalyst/ScalaReflection.scala       |  64 ++-
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  13 +-
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |  11 +-
 .../sql/catalyst/expressions/objects/objects.scala |   8 +-
 .../sql/catalyst/JavaTypeInferenceSuite.scala      | 203 +++++++-
 6 files changed, 418 insertions(+), 446 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 81f363dda36..105bed38704 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -14,25 +14,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.spark.sql.catalyst

 import java.beans.{Introspector, PropertyDescriptor}
-import java.lang.{Iterable => JIterable}
-import java.lang.reflect.Type
-import java.util.{Iterator => JIterator, List => JList, Map => JMap}
+import java.lang.reflect.{ParameterizedType, Type, TypeVariable}
+import java.util.{ArrayDeque, List => JList, Map => JMap}
 import javax.annotation.Nonnull

-import scala.language.existentials
-
-import com.google.common.reflect.TypeToken
+import scala.annotation.tailrec
+import scala.reflect.ClassTag

-import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
-import org.apache.spark.sql.catalyst.SerializerBuildHelper._
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, P [...]
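For readers skimming the patch, the following is a rough sketch of how the new entry points introduced by this commit are used. It is an illustrative example only, with a made-up bean class (`Person`), and is not code from the patch itself:

    import scala.beans.BeanProperty

    import org.apache.spark.sql.catalyst.JavaTypeInference
    import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}

    // Hypothetical bean, written Scala-style with @BeanProperty so it exposes
    // getters/setters the way the beans in JavaTypeInferenceSuite do.
    class Person {
      @BeanProperty var name: String = _
      @BeanProperty var age: Int = 0
    }

    // New in this patch: infer an AgnosticEncoder straight from the bean class.
    val encoder: AgnosticEncoder[Person] = JavaTypeInference.encoderFor(classOf[Person])

    // The (DataType, nullable) pair previously returned by inferDataType is now
    // derived from the encoder itself.
    val dataType = encoder.dataType
    val nullable = encoder.nullable

    // ExpressionEncoder.javaBean becomes a thin wrapper; the serializer and
    // deserializer expressions are generated by ScalaReflection.
    val beanEncoder: ExpressionEncoder[Person] = ExpressionEncoder.javaBean(classOf[Person])

The rest of the diff follows below.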
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._

@@ -40,123 +33,112 @@ import org.apache.spark.sql.types._
  * Type-inference utilities for POJOs and Java collections.
  */
 object JavaTypeInference {
-
-  private val iterableType = TypeToken.of(classOf[JIterable[_]])
-  private val mapType = TypeToken.of(classOf[JMap[_, _]])
-  private val listType = TypeToken.of(classOf[JList[_]])
-  private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
-  private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
-  private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
-  private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
-
-  // Guava changed the name of this method; this tries to stay compatible with both
-  // TODO replace with isSupertypeOf when Guava 14 support no longer needed for Hadoop
-  private val ttIsAssignableFrom: (TypeToken[_], TypeToken[_]) => Boolean = {
-    val ttMethods = classOf[TypeToken[_]].getMethods.
-      filter(_.getParameterCount == 1).
-      filter(_.getParameterTypes.head == classOf[TypeToken[_]])
-    val isAssignableFromMethod = ttMethods.find(_.getName == "isSupertypeOf").getOrElse(
-      ttMethods.find(_.getName == "isAssignableFrom").get)
-    (a: TypeToken[_], b: TypeToken[_]) => isAssignableFromMethod.invoke(a, b).asInstanceOf[Boolean]
-  }
-
   /**
-   * Infers the corresponding SQL data type of a JavaBean class.
-   * @param beanClass Java type
+   * Infers the corresponding SQL data type of a Java type.
+   * @param beanType Java type
    * @return (SQL data type, nullable)
    */
-  def inferDataType(beanClass: Class[_]): (DataType, Boolean) = {
-    inferDataType(TypeToken.of(beanClass))
+  def inferDataType(beanType: Type): (DataType, Boolean) = {
+    val encoder = encoderFor(beanType)
+    (encoder.dataType, encoder.nullable)
   }

   /**
-   * Infers the corresponding SQL data type of a Java type.
-   * @param beanType Java type
-   * @return (SQL data type, nullable)
+   * Infer an [[AgnosticEncoder]] for the [[Class]] `cls`.
    */
-  private[sql] def inferDataType(beanType: Type): (DataType, Boolean) = {
-    inferDataType(TypeToken.of(beanType))
+  def encoderFor[T](cls: Class[T]): AgnosticEncoder[T] = {
+    encoderFor(cls.asInstanceOf[Type])
   }

   /**
-   * Infers the corresponding SQL data type of a Java type.
-   * @param typeToken Java type
-   * @return (SQL data type, nullable)
+   * Infer an [[AgnosticEncoder]] for the `beanType`.
*/ - private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty) - : (DataType, Boolean) = { - typeToken.getRawType match { - case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true) - - case c: Class[_] if UDTRegistration.exists(c.getName) => - val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().newInstance() - .asInstanceOf[UserDefinedType[_ >: Null]] - (udt, true) - - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) - - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) - case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true) - case c: Class[_] if c == classOf[java.time.LocalDate] => (DateType, true) - case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) - case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) - case c: Class[_] if c == classOf[java.time.LocalDateTime] => (TimestampNTZType, true) - case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType(), true) - case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType(), true) - - case _ if typeToken.isArray => - val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet) - (ArrayType(dataType, nullable), true) - - case _ if ttIsAssignableFrom(iterableType, typeToken) => - val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet) - (ArrayType(dataType, nullable), true) - - case _ if ttIsAssignableFrom(mapType, typeToken) => - val (keyType, valueType) = mapKeyValueType(typeToken) - val (keyDataType, _) = inferDataType(keyType, seenTypeSet) - val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet) - (MapType(keyDataType, valueDataType, nullable), true) + def encoderFor[T](beanType: Type): AgnosticEncoder[T] = { + encoderFor(beanType, Set.empty).asInstanceOf[AgnosticEncoder[T]] + } - case other if other.isEnum => - (StringType, true) + private def encoderFor(t: Type, seenTypeSet: Set[Class[_]]): AgnosticEncoder[_] = t match { + + case c: Class[_] if c == java.lang.Boolean.TYPE => PrimitiveBooleanEncoder + case c: Class[_] if c == java.lang.Byte.TYPE => PrimitiveByteEncoder + case c: Class[_] if c == 
java.lang.Short.TYPE => PrimitiveShortEncoder + case c: Class[_] if c == java.lang.Integer.TYPE => PrimitiveIntEncoder + case c: Class[_] if c == java.lang.Long.TYPE => PrimitiveLongEncoder + case c: Class[_] if c == java.lang.Float.TYPE => PrimitiveFloatEncoder + case c: Class[_] if c == java.lang.Double.TYPE => PrimitiveDoubleEncoder + + case c: Class[_] if c == classOf[java.lang.Boolean] => BoxedBooleanEncoder + case c: Class[_] if c == classOf[java.lang.Byte] => BoxedByteEncoder + case c: Class[_] if c == classOf[java.lang.Short] => BoxedShortEncoder + case c: Class[_] if c == classOf[java.lang.Integer] => BoxedIntEncoder + case c: Class[_] if c == classOf[java.lang.Long] => BoxedLongEncoder + case c: Class[_] if c == classOf[java.lang.Float] => BoxedFloatEncoder + case c: Class[_] if c == classOf[java.lang.Double] => BoxedDoubleEncoder + + case c: Class[_] if c == classOf[java.lang.String] => StringEncoder + case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder + case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER + case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder + case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER + case c: Class[_] if c == classOf[java.sql.Date] => STRICT_DATE_ENCODER + case c: Class[_] if c == classOf[java.time.Instant] => STRICT_INSTANT_ENCODER + case c: Class[_] if c == classOf[java.sql.Timestamp] => STRICT_TIMESTAMP_ENCODER + case c: Class[_] if c == classOf[java.time.LocalDateTime] => LocalDateTimeEncoder + case c: Class[_] if c == classOf[java.time.Duration] => DayTimeIntervalEncoder + case c: Class[_] if c == classOf[java.time.Period] => YearMonthIntervalEncoder + + case c: Class[_] if c.isEnum => JavaEnumEncoder(ClassTag(c)) + + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = c.getAnnotation(classOf[SQLUserDefinedType]).udt() + .getConstructor().newInstance().asInstanceOf[UserDefinedType[Any]] + val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt() + UDTEncoder(udt, udtClass) + + case c: Class[_] if UDTRegistration.exists(c.getName) => + val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[Any]] + UDTEncoder(udt, udt.getClass) + + case c: Class[_] if c.isArray => + val elementEncoder = encoderFor(c.getComponentType, seenTypeSet) + ArrayEncoder(elementEncoder, elementEncoder.nullable) + + case ImplementsList(c, Array(elementCls)) => + val element = encoderFor(elementCls, seenTypeSet) + IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false) + + case ImplementsMap(c, Array(keyCls, valueCls)) => + val keyEncoder = encoderFor(keyCls, seenTypeSet) + val valueEncoder = encoderFor(valueCls, seenTypeSet) + MapEncoder(ClassTag(c), keyEncoder, valueEncoder, valueEncoder.nullable) + + case c: Class[_] => + if (seenTypeSet.contains(c)) { + throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(c) + } - case other => - if (seenTypeSet.contains(other)) { - throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(other) - } + // TODO: we should only collect properties that have getter and setter. However, some tests + // pass in scala case class as java bean class which doesn't have getter and setter. + val properties = getJavaBeanReadableProperties(c) + // Note that the fields are ordered by name. 
+ val fields = properties.map { property => + val readMethod = property.getReadMethod + val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c) + // The existence of `javax.annotation.Nonnull`, means this field is not nullable. + val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) + EncoderField( + property.getName, + encoder, + encoder.nullable && !hasNonNull, + Metadata.empty, + Option(readMethod.getName), + Option(property.getWriteMethod).map(_.getName)) + } + JavaBeanEncoder(ClassTag(c), fields) - // TODO: we should only collect properties that have getter and setter. However, some tests - // pass in scala case class as java bean class which doesn't have getter and setter. - val properties = getJavaBeanReadableProperties(other) - val fields = properties.map { property => - val returnType = typeToken.method(property.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other) - // The existence of `javax.annotation.Nonnull`, means this field is not nullable. - val hasNonNull = property.getReadMethod.isAnnotationPresent(classOf[Nonnull]) - new StructField(property.getName, dataType, nullable && !hasNonNull) - } - (new StructType(fields), true) - } + case _ => + throw QueryExecutionErrors.cannotFindEncoderForTypeError(t.toString) } def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { @@ -166,317 +148,58 @@ object JavaTypeInference { .filter(_.getReadMethod != null) } - private def getJavaBeanReadableAndWritableProperties( - beanClass: Class[_]): Array[PropertyDescriptor] = { - getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null) - } - - private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSuperType.resolveType(iteratorReturnType) - iteratorType.resolveType(nextReturnType) - } - - private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) - val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) - keyType -> valueType - } - - /** - * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping - * to a native type, an ObjectType is returned. - * - * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type - * system. As a result, ObjectType will be returned for things like boxed Integers. - */ - private def inferExternalType(cls: Class[_]): DataType = cls match { - case c if c == java.lang.Boolean.TYPE => BooleanType - case c if c == java.lang.Byte.TYPE => ByteType - case c if c == java.lang.Short.TYPE => ShortType - case c if c == java.lang.Integer.TYPE => IntegerType - case c if c == java.lang.Long.TYPE => LongType - case c if c == java.lang.Float.TYPE => FloatType - case c if c == java.lang.Double.TYPE => DoubleType - case c if c == classOf[Array[Byte]] => BinaryType - case _ => ObjectType(cls) - } - - /** - * Returns an expression that can be used to deserialize a Spark SQL representation to an object - * of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal - * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. 
Nested classes will have their fields accessed - * using `UnresolvedExtractValue`. - */ - def deserializerFor(beanClass: Class[_]): Expression = { - val typeToken = TypeToken.of(beanClass) - val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName) - val (dataType, nullable) = inferDataType(typeToken) - - // Assumes we are deserializing the first column of a row. - deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, - nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath)) - } - - private def deserializerFor( - typeToken: TypeToken[_], - path: Expression, - walkedTypePath: WalkedTypePath): Expression = { - typeToken.getRawType match { - case c if !inferExternalType(c).isInstanceOf[ObjectType] => path - - case c if c == classOf[java.lang.Short] || - c == classOf[java.lang.Integer] || - c == classOf[java.lang.Long] || - c == classOf[java.lang.Double] || - c == classOf[java.lang.Float] || - c == classOf[java.lang.Byte] || - c == classOf[java.lang.Boolean] => - createDeserializerForTypesSupportValueOf(path, c) - - case c if c == classOf[java.time.LocalDate] => - createDeserializerForLocalDate(path) - - case c if c == classOf[java.sql.Date] => - createDeserializerForSqlDate(path) - - case c if c == classOf[java.time.Instant] => - createDeserializerForInstant(path) - - case c if c == classOf[java.sql.Timestamp] => - createDeserializerForSqlTimestamp(path) + private class ImplementsGenericInterface(interface: Class[_]) { + assert(interface.isInterface) + assert(interface.getTypeParameters.nonEmpty) - case c if c == classOf[java.time.LocalDateTime] => - createDeserializerForLocalDateTime(path) - - case c if c == classOf[java.time.Duration] => - createDeserializerForDuration(path) - - case c if c == classOf[java.time.Period] => - createDeserializerForPeriod(path) - - case c if c == classOf[java.lang.String] => - createDeserializerForString(path, returnNullable = true) - - case c if c == classOf[java.math.BigDecimal] => - createDeserializerForJavaBigDecimal(path, returnNullable = true) - - case c if c == classOf[java.math.BigInteger] => - createDeserializerForJavaBigInteger(path, returnNullable = true) - - case c if c.isArray => - val elementType = c.getComponentType - val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName) - val (dataType, elementNullable) = inferDataType(elementType) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. 
- deserializerForWithNullSafetyAndUpcast( - element, - dataType, - nullable = elementNullable, - newTypePath, - deserializerFor(typeToken.getComponentType, _, newTypePath)) - } - - val arrayData = UnresolvedMapObjects(mapFunction, path) - - val methodName = elementType match { - case c if c == java.lang.Integer.TYPE => "toIntArray" - case c if c == java.lang.Long.TYPE => "toLongArray" - case c if c == java.lang.Double.TYPE => "toDoubleArray" - case c if c == java.lang.Float.TYPE => "toFloatArray" - case c if c == java.lang.Short.TYPE => "toShortArray" - case c if c == java.lang.Byte.TYPE => "toByteArray" - case c if c == java.lang.Boolean.TYPE => "toBooleanArray" - // non-primitive - case _ => "array" - } - Invoke(arrayData, methodName, ObjectType(c)) - - case c if ttIsAssignableFrom(listType, typeToken) => - val et = elementType(typeToken) - val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName) - val (dataType, elementNullable) = inferDataType(et) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. - deserializerForWithNullSafetyAndUpcast( - element, - dataType, - nullable = elementNullable, - newTypePath, - deserializerFor(et, _, newTypePath)) - } - - UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c)) - - case _ if ttIsAssignableFrom(mapType, typeToken) => - val (keyType, valueType) = mapKeyValueType(typeToken) - val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName, - valueType.getType.getTypeName) - - val keyData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(keyType, p, newTypePath), - MapKeys(path)), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(valueType, p, newTypePath), - MapValues(path)), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[JMap[_, _]]), - "toJavaMap", - keyData :: valueData :: Nil, - returnNullable = false) - - case other if other.isEnum => - createDeserializerForTypesSupportValueOf( - createDeserializerForString(path, returnNullable = false), - other) - - case other => - val properties = getJavaBeanReadableAndWritableProperties(other) - val setters = properties.map { p => - val fieldName = p.getName - val fieldType = typeToken.method(p.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(fieldType) - val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName) - // The existence of `javax.annotation.Nonnull`, means this field is not nullable. - val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull]) - val setter = expressionWithNullSafety( - deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath), - newTypePath), - nullable = nullable && !hasNonNull, - newTypePath) - p.getWriteMethod.getName -> setter - }.toMap - - val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) - val result = InitializeJavaBean(newInstance, setters) - - expressions.If( - IsNull(path), - expressions.Literal.create(null, ObjectType(other)), - result - ) + def unapply(t: Type): Option[(Class[_], Array[Type])] = implementsInterface(t).map { cls => + cls -> findTypeArgumentsForInterface(t) } - } - /** - * Returns an expression for serializing an object of the given type to a Spark SQL - * representation. The input object is located at ordinal 0 of a row, i.e., - * `BoundReference(0, _)`. 
- */ - def serializerFor(beanClass: Class[_]): Expression = { - val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) - serializerFor(nullSafeInput, TypeToken.of(beanClass)) - } - - private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { - - def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { - val (dataType, nullable) = inferDataType(elementType) - if (ScalaReflection.isNativeType(dataType)) { - val cls = input.dataType.asInstanceOf[ObjectType].cls - if (cls.isArray && cls.getComponentType.isPrimitive) { - createSerializerForPrimitiveArray(input, dataType) - } else { - createSerializerForGenericArray(input, dataType, nullable = nullable) - } - } else { - createSerializerForMapObjects(input, ObjectType(elementType.getRawType), - serializerFor(_, elementType)) - } + @tailrec + private def implementsInterface(t: Type): Option[Class[_]] = t match { + case pt: ParameterizedType => implementsInterface(pt.getRawType) + case c: Class[_] if interface.isAssignableFrom(c) => Option(c) + case _ => None } - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - typeToken.getRawType match { - case c if c == classOf[String] => createSerializerForString(inputObject) - - case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject) - - case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject) - - case c if c == classOf[java.time.LocalDateTime] => - createSerializerForLocalDateTime(inputObject) - - case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject) - - case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject) - - case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject) - - case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject) - - case c if c == classOf[java.math.BigInteger] => - createSerializerForBigInteger(inputObject) - - case c if c == classOf[java.math.BigDecimal] => - createSerializerForBigDecimal(inputObject) - - case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject) - case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject) - case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject) - case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject) - case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject) - case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject) - case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject) - - case _ if typeToken.isArray => - toCatalystArray(inputObject, typeToken.getComponentType) - - case _ if ttIsAssignableFrom(listType, typeToken) => - toCatalystArray(inputObject, elementType(typeToken)) - - case _ if ttIsAssignableFrom(mapType, typeToken) => - val (keyType, valueType) = mapKeyValueType(typeToken) - - createSerializerForMap( - inputObject, - MapElementInformation( - ObjectType(keyType.getRawType), - nullable = true, - serializerFor(_, keyType)), - MapElementInformation( - ObjectType(valueType.getRawType), - nullable = true, - serializerFor(_, valueType)) - ) - - case other if other.isEnum => - createSerializerForString( - Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false)) - - case other => - 
val properties = getJavaBeanReadableAndWritableProperties(other) - val fields = properties.map { p => - val fieldName = p.getName - val fieldType = typeToken.method(p.getReadMethod).getReturnType - val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull]) - val fieldValue = Invoke( - inputObject, - p.getReadMethod.getName, - inferExternalType(fieldType.getRawType), - propagateNull = !hasNonNull, - returnNullable = !hasNonNull) - (fieldName, serializerFor(fieldValue, fieldType)) - } - createSerializerForObject(inputObject, fields) + private def findTypeArgumentsForInterface(t: Type): Array[Type] = { + val queue = new ArrayDeque[(Type, Map[Any, Type])] + queue.add(t -> Map.empty) + while (!queue.isEmpty) { + queue.poll() match { + case (pt: ParameterizedType, bindings) => + // translate mappings... + val mappedTypeArguments = pt.getActualTypeArguments.map { + case v: TypeVariable[_] => bindings(v.getName) + case v => v + } + if (pt.getRawType == interface) { + return mappedTypeArguments + } else { + val mappedTypeArgumentMap = mappedTypeArguments + .zipWithIndex.map(_.swap) + .toMap[Any, Type] + queue.add(pt.getRawType -> mappedTypeArgumentMap) + } + case (c: Class[_], indexedBindings) => + val namedBindings = c.getTypeParameters.zipWithIndex.map { + case (parameter, index) => + parameter.getName -> indexedBindings(index) + }.toMap[Any, Type] + val superClass = c.getGenericSuperclass + if (superClass != null) { + queue.add(superClass -> namedBindings) + } + c.getGenericInterfaces.foreach { iface => + queue.add(iface -> namedBindings) + } + } } + throw QueryExecutionErrors.unreachableError() } } + + private object ImplementsList extends ImplementsGenericInterface(classOf[JList[_]]) + private object ImplementsMap extends ImplementsGenericInterface(classOf[JMap[_, _]]) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 42208cd1098..4680a2aec2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -264,6 +264,36 @@ object ScalaReflection extends ScalaReflection { Option(clsTag.runtimeClass), walkedTypePath) + case MapEncoder(tag, keyEncoder, valueEncoder, _) + if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) => + // TODO (hvanhovell) this is can be improved. 
+ val newTypePath = walkedTypePath.recordMap( + keyEncoder.clsTag.runtimeClass.getName, + valueEncoder.clsTag.runtimeClass.getName) + + val keyData = + Invoke( + UnresolvedMapObjects( + p => deserializerFor(keyEncoder, p, newTypePath), + MapKeys(path)), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + UnresolvedMapObjects( + p => deserializerFor(valueEncoder, p, newTypePath), + MapValues(path)), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[java.util.Map[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil, + returnNullable = false) + case MapEncoder(tag, keyEncoder, valueEncoder, _) => val newTypePath = walkedTypePath.recordMap( keyEncoder.clsTag.runtimeClass.getName, @@ -312,6 +342,26 @@ object ScalaReflection extends ScalaReflection { exprs.If(IsNull(path), exprs.Literal.create(null, externalDataTypeFor(enc)), CreateExternalRow(convertedFields, enc.schema)) + + case JavaBeanEncoder(tag, fields) => + val setters = fields.map { f => + val newTypePath = walkedTypePath.recordField( + f.enc.clsTag.runtimeClass.getName, + f.name) + val setter = expressionWithNullSafety( + deserializerFor( + f.enc, + addToPath(path, f.name, f.enc.dataType, newTypePath), + newTypePath), + nullable = f.nullable, + newTypePath) + f.writeMethod.get -> setter + } + + val cls = tag.runtimeClass + val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false) + val result = InitializeJavaBean(newInstance, setters.toMap) + exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result) } private def deserializeArray( @@ -446,6 +496,18 @@ object ScalaReflection extends ScalaReflection { field.name -> convertedField } createSerializerForObject(input, serializedFields) + + case JavaBeanEncoder(_, fields) => + val serializedFields = fields.map { f => + val fieldValue = Invoke( + KnownNotNull(input), + f.readMethod.get, + externalDataTypeFor(f.enc), + propagateNull = f.nullable, + returnNullable = f.nullable) + f.name -> serializerFor(f.enc, fieldValue) + } + createSerializerForObject(input, serializedFields) } private def serializerForArray( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index cdc64f2ddb5..1a3c1089649 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -91,7 +91,9 @@ object AgnosticEncoders { name: String, enc: AgnosticEncoder[_], nullable: Boolean, - metadata: Metadata) { + metadata: Metadata, + readMethod: Option[String] = None, + writeMethod: Option[String] = None) { def structField: StructField = StructField(name, enc.dataType, nullable, metadata) } @@ -112,6 +114,15 @@ object AgnosticEncoders { override def clsTag: ClassTag[Row] = classTag[Row] } + case class JavaBeanEncoder[K]( + override val clsTag: ClassTag[K], + fields: Seq[EncoderField]) + extends AgnosticEncoder[K] { + override def isPrimitive: Boolean = false + override val schema: StructType = StructType(fields.map(_.structField)) + override def dataType: DataType = schema + } + // This will only work for encoding from/to Sparks' InternalRow format. // It is here for compatibility. 
case class UDTEncoder[E >: Null]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 9ca2fc72ad9..faa165c298d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -59,16 +59,7 @@ object ExpressionEncoder { // TODO: improve error message for java bean encoder. def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { - val schema = JavaTypeInference.inferDataType(beanClass)._1 - assert(schema.isInstanceOf[StructType]) - - val objSerializer = JavaTypeInference.serializerFor(beanClass) - val objDeserializer = JavaTypeInference.deserializerFor(beanClass) - - new ExpressionEncoder[T]( - objSerializer, - objDeserializer, - ClassTag[T](beanClass)) + apply(JavaTypeInference.encoderFor(beanClass)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 299a928f267..929beb660ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1927,7 +1927,8 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD (value: Any) => { value.getClass.isArray || value.isInstanceOf[scala.collection.Seq[_]] || - value.isInstanceOf[Set[_]] + value.isInstanceOf[Set[_]] || + value.isInstanceOf[java.util.List[_]] } case _: DateType => (value: Any) => { @@ -1968,7 +1969,10 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD classOf[scala.math.BigDecimal], classOf[Decimal])) case _: ArrayType => - val check = genCheckTypes(Seq(classOf[scala.collection.Seq[_]], classOf[Set[_]])) + val check = genCheckTypes(Seq( + classOf[scala.collection.Seq[_]], + classOf[Set[_]], + classOf[java.util.List[_]])) s"$obj.getClass().isArray() || $check" case _: DateType => genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate])) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala index 9c1d0c17777..35f5bf739bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala @@ -18,25 +18,206 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger +import java.util.{LinkedList, List => JList, Map => JMap} -import scala.beans.BeanProperty +import scala.beans.{BeanProperty, BooleanBeanProperty} +import scala.reflect.{classTag, ClassTag} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, Expression, Literal} -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, UDTCaseClass, UDTForCaseClass} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.types.{DecimalType, MapType, Metadata, StringType, StructField, StructType} -class DummyBean() { - @BeanProperty var bigInteger = null: BigInteger +class DummyBean { + @BeanProperty var bigInteger: BigInteger 
= _ } +class GenericCollectionBean { + @BeanProperty var listOfListOfStrings: JList[JList[String]] = _ + @BeanProperty var mapOfDummyBeans: JMap[String, DummyBean] = _ + @BeanProperty var linkedListOfStrings: LinkedList[String] = _ +} + +class LeafBean { + @BooleanBeanProperty var primitiveBoolean: Boolean = false + @BeanProperty var primitiveByte: Byte = 0 + @BeanProperty var primitiveShort: Short = 0 + @BeanProperty var primitiveInt: Int = 0 + @BeanProperty var primitiveLong: Long = 0 + @BeanProperty var primitiveFloat: Float = 0 + @BeanProperty var primitiveDouble: Double = 0 + @BeanProperty var boxedBoolean: java.lang.Boolean = false + @BeanProperty var boxedByte: java.lang.Byte = 0.toByte + @BeanProperty var boxedShort: java.lang.Short = 0.toShort + @BeanProperty var boxedInt: java.lang.Integer = 0 + @BeanProperty var boxedLong: java.lang.Long = 0 + @BeanProperty var boxedFloat: java.lang.Float = 0 + @BeanProperty var boxedDouble: java.lang.Double = 0 + @BeanProperty var string: String = _ + @BeanProperty var binary: Array[Byte] = _ + @BeanProperty var bigDecimal: java.math.BigDecimal = _ + @BeanProperty var bigInteger: java.math.BigInteger = _ + @BeanProperty var localDate: java.time.LocalDate = _ + @BeanProperty var date: java.sql.Date = _ + @BeanProperty var instant: java.time.Instant = _ + @BeanProperty var timestamp: java.sql.Timestamp = _ + @BeanProperty var localDateTime: java.time.LocalDateTime = _ + @BeanProperty var duration: java.time.Duration = _ + @BeanProperty var period: java.time.Period = _ + @BeanProperty var enum: java.time.Month = _ + @BeanProperty val readOnlyString = "read-only" + + var nonNullString: String = "value" + @javax.annotation.Nonnull + def getNonNullString: String = nonNullString + def setNonNullString(v: String): Unit = nonNullString = { + java.util.Objects.nonNull(v) + v + } +} + +class ArrayBean { + @BeanProperty var dummyBeanArray: Array[DummyBean] = _ + @BeanProperty var primitiveIntArray: Array[Int] = _ + @BeanProperty var stringArray: Array[String] = _ +} + +class UDTBean { + @BeanProperty var udt: UDTCaseClass = _ +} + +/** + * Test suite for Encoders produced by [[JavaTypeInference]]. 
+ */ class JavaTypeInferenceSuite extends SparkFunSuite { + private def encoderField( + name: String, + encoder: AgnosticEncoder[_], + overrideNullable: Option[Boolean] = None, + readOnly: Boolean = false): EncoderField = { + val readPrefix = if (encoder == PrimitiveBooleanEncoder) "is" else "get" + EncoderField( + name, + encoder, + overrideNullable.getOrElse(encoder.nullable), + Metadata.empty, + Option(readPrefix + name.capitalize), + Option("set" + name.capitalize).filterNot(_ => readOnly)) + } + + private val expectedDummyBeanEncoder = + JavaBeanEncoder[DummyBean]( + ClassTag(classOf[DummyBean]), + Seq(encoderField("bigInteger", JavaBigIntEncoder))) + + private val expectedDummyBeanSchema = + StructType(StructField("bigInteger", DecimalType(38, 0)) :: Nil) + test("SPARK-41007: JavaTypeInference returns the correct serializer for BigInteger") { - var serializer = JavaTypeInference.serializerFor(classOf[DummyBean]) - var bigIntegerFieldName: Expression = serializer.children(0) - assert(bigIntegerFieldName.asInstanceOf[Literal].value.toString == "bigInteger") - var bigIntegerFieldExpression: Expression = serializer.children(1) - assert(bigIntegerFieldExpression.asInstanceOf[CheckOverflow].dataType == - DecimalType.BigIntDecimal) + val encoder = JavaTypeInference.encoderFor(classOf[DummyBean]) + assert(encoder === expectedDummyBeanEncoder) + assert(encoder.schema === expectedDummyBeanSchema) + } + + test("resolve schema for class") { + val (schema, nullable) = JavaTypeInference.inferDataType(classOf[DummyBean]) + assert(nullable) + assert(schema === expectedDummyBeanSchema) + } + + test("resolve schema for type") { + val getter = classOf[GenericCollectionBean].getDeclaredMethods + .find(_.getName == "getMapOfDummyBeans") + .get + val (schema, nullable) = JavaTypeInference.inferDataType(getter.getGenericReturnType) + val expected = MapType(StringType, expectedDummyBeanSchema, valueContainsNull = true) + assert(nullable) + assert(schema === expected) + } + + test("resolve type parameters for map and list") { + val encoder = JavaTypeInference.encoderFor(classOf[GenericCollectionBean]) + val expected = JavaBeanEncoder(ClassTag(classOf[GenericCollectionBean]), Seq( + encoderField( + "linkedListOfStrings", + IterableEncoder( + ClassTag(classOf[LinkedList[_]]), + StringEncoder, + containsNull = true, + lenientSerialization = false)), + encoderField( + "listOfListOfStrings", + IterableEncoder( + ClassTag(classOf[JList[_]]), + IterableEncoder( + ClassTag(classOf[JList[_]]), + StringEncoder, + containsNull = true, + lenientSerialization = false), + containsNull = true, + lenientSerialization = false)), + encoderField( + "mapOfDummyBeans", + MapEncoder( + ClassTag(classOf[JMap[_, _]]), + StringEncoder, + expectedDummyBeanEncoder, + valueContainsNull = true)))) + assert(encoder === expected) + } + + test("resolve leaf encoders") { + val encoder = JavaTypeInference.encoderFor(classOf[LeafBean]) + val expected = JavaBeanEncoder(ClassTag(classOf[LeafBean]), Seq( + // The order is different from the definition because fields are ordered by name. 
+ encoderField("bigDecimal", DEFAULT_JAVA_DECIMAL_ENCODER), + encoderField("bigInteger", JavaBigIntEncoder), + encoderField("binary", BinaryEncoder), + encoderField("boxedBoolean", BoxedBooleanEncoder), + encoderField("boxedByte", BoxedByteEncoder), + encoderField("boxedDouble", BoxedDoubleEncoder), + encoderField("boxedFloat", BoxedFloatEncoder), + encoderField("boxedInt", BoxedIntEncoder), + encoderField("boxedLong", BoxedLongEncoder), + encoderField("boxedShort", BoxedShortEncoder), + encoderField("date", STRICT_DATE_ENCODER), + encoderField("duration", DayTimeIntervalEncoder), + encoderField("enum", JavaEnumEncoder(classTag[java.time.Month])), + encoderField("instant", STRICT_INSTANT_ENCODER), + encoderField("localDate", STRICT_LOCAL_DATE_ENCODER), + encoderField("localDateTime", LocalDateTimeEncoder), + encoderField("nonNullString", StringEncoder, overrideNullable = Option(false)), + encoderField("period", YearMonthIntervalEncoder), + encoderField("primitiveBoolean", PrimitiveBooleanEncoder), + encoderField("primitiveByte", PrimitiveByteEncoder), + encoderField("primitiveDouble", PrimitiveDoubleEncoder), + encoderField("primitiveFloat", PrimitiveFloatEncoder), + encoderField("primitiveInt", PrimitiveIntEncoder), + encoderField("primitiveLong", PrimitiveLongEncoder), + encoderField("primitiveShort", PrimitiveShortEncoder), + encoderField("readOnlyString", StringEncoder, readOnly = true), + encoderField("string", StringEncoder), + encoderField("timestamp", STRICT_TIMESTAMP_ENCODER) + )) + assert(encoder === expected) + } + + test("resolve array encoders") { + val encoder = JavaTypeInference.encoderFor(classOf[ArrayBean]) + val expected = JavaBeanEncoder(ClassTag(classOf[ArrayBean]), Seq( + encoderField("dummyBeanArray", ArrayEncoder(expectedDummyBeanEncoder, containsNull = true)), + encoderField("primitiveIntArray", ArrayEncoder(PrimitiveIntEncoder, containsNull = false)), + encoderField("stringArray", ArrayEncoder(StringEncoder, containsNull = true)) + )) + assert(encoder === expected) + } + + test("resolve UDT encoders") { + val encoder = JavaTypeInference.encoderFor(classOf[UDTBean]) + val expected = JavaBeanEncoder(ClassTag(classOf[UDTBean]), Seq( + encoderField("udt", UDTEncoder(new UDTForCaseClass, classOf[UDTForCaseClass])) + )) + assert(encoder === expected) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org