This is an automated email from the ASF dual-hosted git repository. aljoscha pushed a commit to branch release-1.8 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.8 by this push: new ecc6639 [FLINK-12301] Fix ScalaCaseClassSerializer to support value types ecc6639 is described below commit ecc6639053cb36672ce552bb7626f75ff98b8293 Author: Igal Shilman <igal.shil...@data-artisans.com> AuthorDate: Mon May 13 11:00:37 2019 +0200 [FLINK-12301] Fix ScalaCaseClassSerializer to support value types We now use Scala reflection because it correctly deals with Scala language features. --- .../scala/typeutils/ScalaCaseClassSerializer.scala | 65 +++++++--------------- .../ScalaCaseClassSerializerReflectionTest.scala | 41 ++++++++++---- 2 files changed, 50 insertions(+), 56 deletions(-) diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala index 7ff1427..fbaa2ac 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializer.scala @@ -18,15 +18,14 @@ package org.apache.flink.api.scala.typeutils +import java.io.ObjectInputStream + import org.apache.flink.api.common.typeutils.CompositeTypeSerializerUtil.delegateCompatibilityCheckToNewSnapshot import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot.SelfResolvingTypeSerializer import org.apache.flink.api.common.typeutils._ import org.apache.flink.api.java.typeutils.runtime.TupleSerializerConfigSnapshot import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializer.lookupConstructor -import java.io.ObjectInputStream -import java.lang.invoke.{MethodHandle, MethodHandles} - import scala.collection.JavaConverters._ import scala.reflect.runtime.universe @@ -38,16 +37,16 @@ import scala.reflect.runtime.universe */ @SerialVersionUID(1L) class ScalaCaseClassSerializer[T <: Product]( - clazz: Class[T], - scalaFieldSerializers: Array[TypeSerializer[_]] -) extends CaseClassSerializer[T](clazz, scalaFieldSerializers) - with SelfResolvingTypeSerializer[T] { + clazz: Class[T], + scalaFieldSerializers: Array[TypeSerializer[_]] + ) extends CaseClassSerializer[T](clazz, scalaFieldSerializers) + with SelfResolvingTypeSerializer[T] { @transient private var constructor = lookupConstructor(clazz) override def createInstance(fields: Array[AnyRef]): T = { - constructor.invoke(fields).asInstanceOf[T] + constructor(fields) } override def snapshotConfiguration(): TypeSerializerSnapshot[T] = { @@ -55,8 +54,7 @@ class ScalaCaseClassSerializer[T <: Product]( } override def resolveSchemaCompatibilityViaRedirectingToNewSnapshotClass( - s: TypeSerializerConfigSnapshot[T] - ): TypeSerializerSchemaCompatibility[T] = { + s: TypeSerializerConfigSnapshot[T]): TypeSerializerSchemaCompatibility[T] = { require(s.isInstanceOf[TupleSerializerConfigSnapshot[_]]) @@ -85,22 +83,8 @@ class ScalaCaseClassSerializer[T <: Product]( object ScalaCaseClassSerializer { - def lookupConstructor[T](clazz: Class[_]): MethodHandle = { - val types = findPrimaryConstructorParameterTypes(clazz, clazz.getClassLoader) - - val constructor = clazz.getConstructor(types: _*) - - val handle = MethodHandles - .lookup() - .unreflectConstructor(constructor) - .asSpreader(classOf[Array[AnyRef]], types.length) - - handle - } - - private def findPrimaryConstructorParameterTypes(cls: Class[_], cl: ClassLoader): - List[Class[_]] = { - val rootMirror = universe.runtimeMirror(cl) + def lookupConstructor[T](cls: Class[T]): Array[AnyRef] => T = { + val rootMirror = universe.runtimeMirror(cls.getClassLoader) val classSymbol = rootMirror.classSymbol(cls) require( @@ -113,30 +97,21 @@ object ScalaCaseClassSerializer { |""".stripMargin ) - val primaryConstructorSymbol = findPrimaryConstructorMethodSymbol(classSymbol) - val scalaTypes = getArgumentsTypes(primaryConstructorSymbol) - scalaTypes.map(tpe => scalaTypeToJavaClass(rootMirror)(tpe)) - } - - private def findPrimaryConstructorMethodSymbol(classSymbol: universe.ClassSymbol): - universe.MethodSymbol = { - classSymbol.toType + val primaryConstructorSymbol = classSymbol.toType .decl(universe.termNames.CONSTRUCTOR) .alternatives + .collectFirst({ + case constructorSymbol: universe.MethodSymbol if constructorSymbol.isPrimaryConstructor => + constructorSymbol + }) .head .asMethod - } - private def getArgumentsTypes(primaryConstructorSymbol: universe.MethodSymbol): - List[universe.Type] = { - primaryConstructorSymbol.typeSignature - .paramLists - .head - .map(symbol => symbol.typeSignature) - } + val classMirror = rootMirror.reflectClass(classSymbol) + val constructorMethodMirror = classMirror.reflectConstructor(primaryConstructorSymbol) - private def scalaTypeToJavaClass(mirror: universe.Mirror)(scalaType: universe.Type): Class[_] = { - val erasure = scalaType.erasure - mirror.runtimeClass(erasure) + arr: Array[AnyRef] => { + constructorMethodMirror.apply(arr: _*).asInstanceOf[T] + } } } diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala index 222dc57..47eab50 100644 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/typeutils/ScalaCaseClassSerializerReflectionTest.scala @@ -18,13 +18,11 @@ package org.apache.flink.api.scala.typeutils -import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializerReflectionTest.{Generic, HigherKind, SimpleCaseClass} +import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializerReflectionTest._ import org.apache.flink.util.TestLogger - import org.junit.Assert.assertEquals import org.junit.Test -import java.lang.invoke.MethodHandle /** * Test obtaining the primary constructor of a case class @@ -34,40 +32,40 @@ class ScalaCaseClassSerializerReflectionTest extends TestLogger { @Test def usageExample(): Unit = { - val constructor: MethodHandle = ScalaCaseClassSerializer + val constructor = ScalaCaseClassSerializer .lookupConstructor(classOf[SimpleCaseClass]) - val actual = constructor.invoke(Array("hi", 1.asInstanceOf[Any])) + val actual = constructor(Array("hi", 1.asInstanceOf[AnyRef])) assertEquals(SimpleCaseClass("hi", 1), actual) } @Test def genericCaseClass(): Unit = { - val constructor: MethodHandle = ScalaCaseClassSerializer + val constructor = ScalaCaseClassSerializer .lookupConstructor(classOf[Generic[_]]) - val actual = constructor.invoke(Array(1.asInstanceOf[AnyRef])) + val actual = constructor(Array(1.asInstanceOf[AnyRef])) assertEquals(Generic[Int](1), actual) } @Test def caseClassWithParameterizedList(): Unit = { - val constructor: MethodHandle = ScalaCaseClassSerializer + val constructor = ScalaCaseClassSerializer .lookupConstructor(classOf[HigherKind]) - val actual = constructor.invoke(Array(List(1, 2, 3), "hey")) + val actual = constructor(Array(List(1, 2, 3), "hey")) assertEquals(HigherKind(List(1, 2, 3), "hey"), actual) } @Test def tupleType(): Unit = { - val constructor: MethodHandle = ScalaCaseClassSerializer + val constructor = ScalaCaseClassSerializer .lookupConstructor(classOf[(String, String, Int)]) - val actual = constructor.invoke(Array("a", "b", 7)) + val actual = constructor(Array("a", "b", 7.asInstanceOf[AnyRef])) assertEquals(("a", "b", 7), actual) } @@ -80,6 +78,21 @@ class ScalaCaseClassSerializerReflectionTest extends TestLogger { ScalaCaseClassSerializer .lookupConstructor(classOf[outerInstance.InnerCaseClass]) } + + @Test + def valueClass(): Unit = { + val constructor = ScalaCaseClassSerializer + .lookupConstructor(classOf[Measurement]) + + val arguments = Array( + 1.asInstanceOf[AnyRef], + new DegreeCelsius(0.5f).asInstanceOf[AnyRef] + ) + + val actual = constructor(arguments) + + assertEquals(Measurement(1, new DegreeCelsius(0.5f)), actual) + } } object ScalaCaseClassSerializerReflectionTest { @@ -94,6 +107,12 @@ object ScalaCaseClassSerializerReflectionTest { case class Generic[T](item: T) + class DegreeCelsius(val value: Float) extends AnyVal { + override def toString: String = s"$value °C" + } + + case class Measurement(i: Int, temperature: DegreeCelsius) + } class OuterClass {