viirya commented on code in PR #39615:
URL: https://github.com/apache/spark/pull/39615#discussion_r1590400148


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -166,317 +148,58 @@ object JavaTypeInference {
       .filter(_.getReadMethod != null)
   }
 
-  private def getJavaBeanReadableAndWritableProperties(
-      beanClass: Class[_]): Array[PropertyDescriptor] = {
-    getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
-  }
-
-  private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
-    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
-    val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
-    val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
-    iteratorType.resolveType(nextReturnType)
-  }
-
-  private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
-    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
-    val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
-    val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
-    val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
-    keyType -> valueType
-  }
-
-  /**
-   * Returns the Spark SQL DataType for a given java class.  Where this is not an exact mapping
-   * to a native type, an ObjectType is returned.
-   *
-   * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
-   * system.  As a result, ObjectType will be returned for things like boxed Integers.
-   */
-  private def inferExternalType(cls: Class[_]): DataType = cls match {
-    case c if c == java.lang.Boolean.TYPE => BooleanType
-    case c if c == java.lang.Byte.TYPE => ByteType
-    case c if c == java.lang.Short.TYPE => ShortType
-    case c if c == java.lang.Integer.TYPE => IntegerType
-    case c if c == java.lang.Long.TYPE => LongType
-    case c if c == java.lang.Float.TYPE => FloatType
-    case c if c == java.lang.Double.TYPE => DoubleType
-    case c if c == classOf[Array[Byte]] => BinaryType
-    case _ => ObjectType(cls)
-  }
-
-  /**
-   * Returns an expression that can be used to deserialize a Spark SQL representation to an object
-   * of java bean `T` with a compatible schema.  The Spark SQL representation is located at ordinal
-   * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed
-   * using `UnresolvedExtractValue`.
-   */
-  def deserializerFor(beanClass: Class[_]): Expression = {
-    val typeToken = TypeToken.of(beanClass)
-    val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName)
-    val (dataType, nullable) = inferDataType(typeToken)
-
-    // Assumes we are deserializing the first column of a row.
-    deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
-      nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath))
-  }
-
-  private def deserializerFor(
-      typeToken: TypeToken[_],
-      path: Expression,
-      walkedTypePath: WalkedTypePath): Expression = {
-    typeToken.getRawType match {
-      case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
-
-      case c if c == classOf[java.lang.Short] ||
-                c == classOf[java.lang.Integer] ||
-                c == classOf[java.lang.Long] ||
-                c == classOf[java.lang.Double] ||
-                c == classOf[java.lang.Float] ||
-                c == classOf[java.lang.Byte] ||
-                c == classOf[java.lang.Boolean] =>
-        createDeserializerForTypesSupportValueOf(path, c)
-
-      case c if c == classOf[java.time.LocalDate] =>
-        createDeserializerForLocalDate(path)
-
-      case c if c == classOf[java.sql.Date] =>
-        createDeserializerForSqlDate(path)
-
-      case c if c == classOf[java.time.Instant] =>
-        createDeserializerForInstant(path)
-
-      case c if c == classOf[java.sql.Timestamp] =>
-        createDeserializerForSqlTimestamp(path)
+  private class ImplementsGenericInterface(interface: Class[_]) {
+    assert(interface.isInterface)
+    assert(interface.getTypeParameters.nonEmpty)
 
-      case c if c == classOf[java.time.LocalDateTime] =>
-        createDeserializerForLocalDateTime(path)
-
-      case c if c == classOf[java.time.Duration] =>
-        createDeserializerForDuration(path)
-
-      case c if c == classOf[java.time.Period] =>
-        createDeserializerForPeriod(path)
-
-      case c if c == classOf[java.lang.String] =>
-        createDeserializerForString(path, returnNullable = true)
-
-      case c if c == classOf[java.math.BigDecimal] =>
-        createDeserializerForJavaBigDecimal(path, returnNullable = true)
-
-      case c if c == classOf[java.math.BigInteger] =>
-        createDeserializerForJavaBigInteger(path, returnNullable = true)
-
-      case c if c.isArray =>
-        val elementType = c.getComponentType
-        val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName)
-        val (dataType, elementNullable) = inferDataType(elementType)
-        val mapFunction: Expression => Expression = element => {
-          // upcast the array element to the data type the encoder expected.
-          deserializerForWithNullSafetyAndUpcast(
-            element,
-            dataType,
-            nullable = elementNullable,
-            newTypePath,
-            deserializerFor(typeToken.getComponentType, _, newTypePath))
-        }
-
-        val arrayData = UnresolvedMapObjects(mapFunction, path)
-
-        val methodName = elementType match {
-          case c if c == java.lang.Integer.TYPE => "toIntArray"
-          case c if c == java.lang.Long.TYPE => "toLongArray"
-          case c if c == java.lang.Double.TYPE => "toDoubleArray"
-          case c if c == java.lang.Float.TYPE => "toFloatArray"
-          case c if c == java.lang.Short.TYPE => "toShortArray"
-          case c if c == java.lang.Byte.TYPE => "toByteArray"
-          case c if c == java.lang.Boolean.TYPE => "toBooleanArray"
-          // non-primitive
-          case _ => "array"
-        }
-        Invoke(arrayData, methodName, ObjectType(c))
-
-      case c if ttIsAssignableFrom(listType, typeToken) =>
-        val et = elementType(typeToken)
-        val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName)
-        val (dataType, elementNullable) = inferDataType(et)
-        val mapFunction: Expression => Expression = element => {
-          // upcast the array element to the data type the encoder expected.
-          deserializerForWithNullSafetyAndUpcast(
-            element,
-            dataType,
-            nullable = elementNullable,
-            newTypePath,
-            deserializerFor(et, _, newTypePath))
-        }
-
-        UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))
-
-      case _ if ttIsAssignableFrom(mapType, typeToken) =>
-        val (keyType, valueType) = mapKeyValueType(typeToken)
-        val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName,
-          valueType.getType.getTypeName)
-
-        val keyData =
-          Invoke(
-            UnresolvedMapObjects(
-              p => deserializerFor(keyType, p, newTypePath),
-              MapKeys(path)),
-            "array",
-            ObjectType(classOf[Array[Any]]))
-
-        val valueData =
-          Invoke(
-            UnresolvedMapObjects(
-              p => deserializerFor(valueType, p, newTypePath),
-              MapValues(path)),
-            "array",
-            ObjectType(classOf[Array[Any]]))
-
-        StaticInvoke(
-          ArrayBasedMapData.getClass,
-          ObjectType(classOf[JMap[_, _]]),
-          "toJavaMap",
-          keyData :: valueData :: Nil,
-          returnNullable = false)
-
-      case other if other.isEnum =>
-        createDeserializerForTypesSupportValueOf(
-          createDeserializerForString(path, returnNullable = false),
-          other)
-
-      case other =>
-        val properties = getJavaBeanReadableAndWritableProperties(other)
-        val setters = properties.map { p =>
-          val fieldName = p.getName
-          val fieldType = typeToken.method(p.getReadMethod).getReturnType
-          val (dataType, nullable) = inferDataType(fieldType)
-          val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
-          // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-          val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull])
-          val setter = expressionWithNullSafety(
-            deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
-              newTypePath),
-            nullable = nullable && !hasNonNull,
-            newTypePath)
-          p.getWriteMethod.getName -> setter
-        }.toMap
-
-        val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
-        val result = InitializeJavaBean(newInstance, setters)
-
-        expressions.If(
-          IsNull(path),
-          expressions.Literal.create(null, ObjectType(other)),
-          result
-        )
+    def unapply(t: Type): Option[(Class[_], Array[Type])] = implementsInterface(t).map { cls =>
+      cls -> findTypeArgumentsForInterface(t)
     }
-  }
 
-  /**
-   * Returns an expression for serializing an object of the given type to a Spark SQL
-   * representation. The input object is located at ordinal 0 of a row, i.e.,
-   * `BoundReference(0, _)`.
-   */
-  def serializerFor(beanClass: Class[_]): Expression = {
-    val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
-    serializerFor(nullSafeInput, TypeToken.of(beanClass))
-  }
-
-  private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
-
-    def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
-      val (dataType, nullable) = inferDataType(elementType)
-      if (ScalaReflection.isNativeType(dataType)) {
-        val cls = input.dataType.asInstanceOf[ObjectType].cls
-        if (cls.isArray && cls.getComponentType.isPrimitive) {
-          createSerializerForPrimitiveArray(input, dataType)
-        } else {
-          createSerializerForGenericArray(input, dataType, nullable = nullable)
-        }
-      } else {
-        createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
-          serializerFor(_, elementType))
-      }
+    @tailrec
+    private def implementsInterface(t: Type): Option[Class[_]] = t match {
+      case pt: ParameterizedType => implementsInterface(pt.getRawType)
+      case c: Class[_] if interface.isAssignableFrom(c) => Option(c)
+      case _ => None
     }
 
-    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
-      inputObject
-    } else {
-      typeToken.getRawType match {
-        case c if c == classOf[String] => createSerializerForString(inputObject)
-
-        case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
-
-        case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)
-
-        case c if c == classOf[java.time.LocalDateTime] =>
-          createSerializerForLocalDateTime(inputObject)
-
-        case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
-
-        case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
-
-        case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)
-
-        case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)
-
-        case c if c == classOf[java.math.BigInteger] =>
-          createSerializerForBigInteger(inputObject)
-
-        case c if c == classOf[java.math.BigDecimal] =>
-          createSerializerForBigDecimal(inputObject)
-
-        case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
-        case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
-        case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
-        case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
-        case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
-        case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
-        case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)
-
-        case _ if typeToken.isArray =>
-          toCatalystArray(inputObject, typeToken.getComponentType)
-
-        case _ if ttIsAssignableFrom(listType, typeToken) =>
-          toCatalystArray(inputObject, elementType(typeToken))
-
-        case _ if ttIsAssignableFrom(mapType, typeToken) =>
-          val (keyType, valueType) = mapKeyValueType(typeToken)
-
-          createSerializerForMap(
-            inputObject,
-            MapElementInformation(
-              ObjectType(keyType.getRawType),
-              nullable = true,
-              serializerFor(_, keyType)),
-            MapElementInformation(
-              ObjectType(valueType.getRawType),
-              nullable = true,
-              serializerFor(_, valueType))
-          )
-
-        case other if other.isEnum =>
-          createSerializerForString(
-            Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
-
-        case other =>
-          val properties = getJavaBeanReadableAndWritableProperties(other)

Review Comment:
   In any case, this change has been in place for a while, and it seems difficult to revert to the old behavior at this point. As discussed with @HeartSaVioR in the JIRA, I'd like to see if we can add a SQL conf to fall back to the previous behavior of `Encoder.bean()`.
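
   For reference, such a fallback switch would typically be a boolean entry in `SQLConf` that the bean encoder consults before choosing a code path. A minimal sketch follows; the conf name, doc text, version, and the `legacyDeserializerFor`/`encoderDeserializerFor` helpers are hypothetical placeholders, not names from this PR:

   ```scala
   // All names below are illustrative placeholders, not from this PR.

   // 1) A legacy flag declared in org.apache.spark.sql.internal.SQLConf:
   val LEGACY_JAVA_BEAN_ENCODER =
     buildConf("spark.sql.legacy.javaBeanEncoder")
       .internal()
       .doc("When true, Encoders.bean() falls back to the previous " +
         "JavaTypeInference behavior when deriving (de)serializers.")
       .version("3.5.0")
       .booleanConf
       .createWithDefault(false)

   // 2) The branch where the bean deserializer is built, assuming the old
   //    TypeToken-based path were kept around as legacyDeserializerFor:
   def deserializerFor(beanClass: Class[_]): Expression = {
     if (SQLConf.get.getConf(SQLConf.LEGACY_JAVA_BEAN_ENCODER)) {
       legacyDeserializerFor(beanClass)    // pre-change behavior
     } else {
       encoderDeserializerFor(beanClass)   // behavior introduced by this change
     }
   }
   ```

   Defaulting the flag to false would keep the new behavior while giving users an escape hatch, consistent with how other `spark.sql.legacy.*` flags are handled.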



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

