This is an automated email from the ASF dual-hosted git repository. zsxwing pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ef402edff91 [SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created ef402edff91 is described below commit ef402edff91377d37c0c1b8d40921ed7bd9f7160 Author: Shixiong Zhu <zsxw...@gmail.com> AuthorDate: Tue Nov 8 08:18:50 2022 -0800 [SPARK-41045][SQL] Pre-compute to eliminate ScalaReflection calls after deserializer is created ### What changes were proposed in this pull request? Currently, when `ScalaReflection` returns a deserializer for certain complex types (such as array, map, UDT, etc.), it creates functions that may still touch `ScalaReflection` after the deserializer is created. `ScalaReflection` is a performance bottleneck under multithreaded access because it holds multiple global locks. We can refactor `ScalaReflection.deserializerFor` to pre-compute everything that needs to touch `ScalaReflection` before creating the deserializer. After this change, once the deserializer is created, it can be reused by multiple threads without touching `ScalaReflection.deserializerFor` anymore, which makes it much faster. ### Why are the changes needed? Optimize `ScalaReflection.deserializerFor` to make deserializers faster when used by multiple threads. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? This PR refactors `deserializerFor` to optimize the code; existing tests should already cover correctness. Closes #38556 from zsxwing/scala-ref. 
Authored-by: Shixiong Zhu <zsxw...@gmail.com> Signed-off-by: Shixiong Zhu <zsxw...@gmail.com> --- .../sql/catalyst/DeserializerBuildHelper.scala | 5 +- .../spark/sql/catalyst/JavaTypeInference.scala | 8 +- .../spark/sql/catalyst/ScalaReflection.scala | 157 +++++++++++---------- 3 files changed, 85 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 0d3b9977e4f..7051c2d2264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -49,10 +49,9 @@ object DeserializerBuildHelper { dataType: DataType, nullable: Boolean, walkedTypePath: WalkedTypePath, - funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = { + funcForCreatingDeserializer: Expression => Expression): Expression = { val casted = upCastToExpectedType(expr, dataType, walkedTypePath) - expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath), - nullable, walkedTypePath) + expressionWithNullSafety(funcForCreatingDeserializer(casted), nullable, walkedTypePath) } def expressionWithNullSafety( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index dccaf1c4835..827807055ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -218,9 +218,7 @@ object JavaTypeInference { // Assumes we are deserializing the first column of a row. 
deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, - nullable = nullable, walkedTypePath, (casted, walkedTypePath) => { - deserializerFor(typeToken, casted, walkedTypePath) - }) + nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath)) } private def deserializerFor( @@ -280,7 +278,7 @@ object JavaTypeInference { dataType, nullable = elementNullable, newTypePath, - (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) + deserializerFor(typeToken.getComponentType, _, newTypePath)) } val arrayData = UnresolvedMapObjects(mapFunction, path) @@ -309,7 +307,7 @@ object JavaTypeInference { dataType, nullable = elementNullable, newTypePath, - (casted, typePath) => deserializerFor(et, casted, typePath)) + deserializerFor(et, _, newTypePath)) } UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 12093b9f4b2..d895a0fbe19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -172,109 +172,103 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(tpe) val walkedTypePath = new WalkedTypePath().recordRoot(clsName) val Schema(dataType, nullable) = schemaFor(tpe) - + val deserializerFunc = deserializerFor(tpe, walkedTypePath) // Assumes we are deserializing the first column of a row. 
deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, - nullable = nullable, walkedTypePath, - (casted, typePath) => deserializerFor(tpe, casted, typePath)) + nullable = nullable, walkedTypePath, deserializerFunc) } /** - * Returns an expression that can be used to deserialize an input expression to an object of type - * `T` with a compatible schema. + * Returns a function that receives an input expression and turns it to an expression that can be + * used to deserialize the input expression to an object of type `T` with a compatible schema. * * @param tpe The `Type` of deserialized object. - * @param path The expression which can be used to extract serialized value. * @param walkedTypePath The paths from top to bottom to access current field when deserializing. */ private def deserializerFor( tpe: `Type`, - path: Expression, - walkedTypePath: WalkedTypePath): Expression = cleanUpReflectionObjects { + walkedTypePath: WalkedTypePath): Expression => Expression = cleanUpReflectionObjects { baseType(tpe) match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => identity case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = walkedTypePath.recordOption(className) - WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) + val dataType = dataTypeFor(optType) + val deserializerFunc = deserializerFor(optType, newTypePath) + path => WrapOption(deserializerFunc(path), dataType) case t if isSubtype(t, localTypeOf[java.lang.Integer]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Integer]) + createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Integer]) case t if isSubtype(t, localTypeOf[java.lang.Long]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Long]) + createDeserializerForTypesSupportValueOf(_, 
classOf[java.lang.Long]) case t if isSubtype(t, localTypeOf[java.lang.Double]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Double]) + createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Double]) case t if isSubtype(t, localTypeOf[java.lang.Float]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Float]) + createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Float]) case t if isSubtype(t, localTypeOf[java.lang.Short]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Short]) + createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Short]) case t if isSubtype(t, localTypeOf[java.lang.Byte]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Byte]) + createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Byte]) case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Boolean]) + createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Boolean]) case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => - createDeserializerForLocalDate(path) + createDeserializerForLocalDate case t if isSubtype(t, localTypeOf[java.sql.Date]) => - createDeserializerForSqlDate(path) + createDeserializerForSqlDate case t if isSubtype(t, localTypeOf[java.time.Instant]) => - createDeserializerForInstant(path) + createDeserializerForInstant case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => - createDeserializerForTypesSupportValueOf( - Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false), - getClassFromType(t)) + // Code touching Scala Reflection should be called outside the returned function to allow + // caching the Scala Reflection result + val cls = getClassFromType(t) + path => createDeserializerForTypesSupportValueOf( + Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false), cls) case t if isSubtype(t, 
localTypeOf[java.sql.Timestamp]) => - createDeserializerForSqlTimestamp(path) + createDeserializerForSqlTimestamp case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => - createDeserializerForLocalDateTime(path) + createDeserializerForLocalDateTime case t if isSubtype(t, localTypeOf[java.time.Duration]) => - createDeserializerForDuration(path) + createDeserializerForDuration case t if isSubtype(t, localTypeOf[java.time.Period]) => - createDeserializerForPeriod(path) + createDeserializerForPeriod case t if isSubtype(t, localTypeOf[java.lang.String]) => - createDeserializerForString(path, returnNullable = false) + createDeserializerForString(_, returnNullable = false) case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => - createDeserializerForJavaBigDecimal(path, returnNullable = false) + createDeserializerForJavaBigDecimal(_, returnNullable = false) case t if isSubtype(t, localTypeOf[BigDecimal]) => - createDeserializerForScalaBigDecimal(path, returnNullable = false) + createDeserializerForScalaBigDecimal(_, returnNullable = false) case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => - createDeserializerForJavaBigInteger(path, returnNullable = false) + createDeserializerForJavaBigInteger(_, returnNullable = false) case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => - createDeserializerForScalaBigInt(path) + createDeserializerForScalaBigInt case t if isSubtype(t, localTypeOf[Array[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = walkedTypePath.recordArray(className) - + val deserializerFunc = deserializerFor(elementType, newTypePath) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
deserializerForWithNullSafetyAndUpcast( @@ -282,10 +276,9 @@ object ScalaReflection extends ScalaReflection { dataType, nullable = elementNullable, newTypePath, - (casted, typePath) => deserializerFor(elementType, casted, typePath)) + deserializerFunc) } - val arrayData = UnresolvedMapObjects(mapFunction, path) val arrayCls = arrayClassFor(elementType) val methodName = elementType match { @@ -299,7 +292,10 @@ object ScalaReflection extends ScalaReflection { // non-primitive case _ => "array" } - Invoke(arrayData, methodName, arrayCls, returnNullable = false) + path => { + val arrayData = UnresolvedMapObjects(mapFunction, path) + Invoke(arrayData, methodName, arrayCls, returnNullable = false) + } // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. @@ -309,14 +305,14 @@ object ScalaReflection extends ScalaReflection { val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = walkedTypePath.recordArray(className) - + val deserializerFunc = deserializerFor(elementType, newTypePath) val mapFunction: Expression => Expression = element => { deserializerForWithNullSafetyAndUpcast( element, dataType, nullable = elementNullable, newTypePath, - (casted, typePath) => deserializerFor(elementType, casted, typePath)) + deserializerFunc) } val companion = t.dealias.typeSymbol.companion.typeSignature @@ -326,7 +322,7 @@ object ScalaReflection extends ScalaReflection { classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - UnresolvedMapObjects(mapFunction, path, Some(cls)) + UnresolvedMapObjects(mapFunction, _, Some(cls)) case t if isSubtype(t, localTypeOf[Map[_, _]]) => val TypeRef(_, _, Seq(keyType, valueType)) = t @@ -336,12 +332,12 @@ object ScalaReflection extends ScalaReflection { val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue) - 
UnresolvedCatalystToExternalMap( - path, - p => deserializerFor(keyType, p, newTypePath), - p => deserializerFor(valueType, p, newTypePath), - mirror.runtimeClass(t.typeSymbol.asClass) - ) + // Code touching Scala Reflection should be called outside the returned function to allow + // caching the Scala Reflection result + val keyDeserializerFunc = deserializerFor(keyType, newTypePath) + val valueDeserializerFunc = deserializerFor(valueType, newTypePath) + val cls = mirror.runtimeClass(t.typeSymbol.asClass) + UnresolvedCatalystToExternalMap(_, keyDeserializerFunc, valueDeserializerFunc, cls) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). @@ -350,7 +346,10 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) + // Code touching Scala Reflection should be called outside the returned function to allow + // caching the Scala Reflection result + val cls = udt.userClass + path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path)) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). 
@@ -359,43 +358,44 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) + // Code touching Scala Reflection should be called outside the returned function to allow + // caching the Scala Reflection result + val cls = udt.userClass + path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path)) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) val cls = getClassFromType(tpe) - val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => + val arguDeserializerFuncs = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) val newTypePath = walkedTypePath.recordField(clsName, fieldName) // For tuples, we based grab the inner fields by ordinal instead of name. - val newPath = if (cls.getName startsWith "scala.Tuple") { - deserializerFor( - fieldType, - addToPathOrdinal(path, i, dataType, newTypePath), - newTypePath) + val newPathFunc = if (cls.getName startsWith "scala.Tuple") { + addToPathOrdinal(_, i, dataType, newTypePath) } else { - deserializerFor( - fieldType, - addToPath(path, fieldName, dataType, newTypePath), - newTypePath) + addToPath(_, fieldName, dataType, newTypePath) } - expressionWithNullSafety( - newPath, + val deserializerFunc = deserializerFor(fieldType, newTypePath) + (path: Expression) => expressionWithNullSafety( + deserializerFunc(newPathFunc(path)), nullable = nullable, newTypePath) } - val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) - - expressions.If( - IsNull(path), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) + val nullLit = expressions.Literal.create(null, ObjectType(cls)) + path => { + val arguments = arguDeserializerFuncs.map(_(path)) + val newInstance = NewInstance(cls, arguments, ObjectType(cls), 
propagateNull = false) + expressions.If( + IsNull(path), + nullLit, + newInstance + ) + } case t if isSubtype(t, localTypeOf[Enumeration#Value]) => // package example @@ -406,10 +406,13 @@ object ScalaReflection extends ScalaReflection { // the fullName of tpe is example.Foo.Foo, but we need example.Foo so that // we can call example.Foo.withName to deserialize string to enumeration. val parent = t.asInstanceOf[TypeRef].pre.typeSymbol.asClass - val cls = mirror.runtimeClass(parent) - StaticInvoke( - cls, - ObjectType(getClassFromType(t)), + // Code touching Scala Reflection should be called outside the returned function to allow + // caching the Scala Reflection result + val parentCls = mirror.runtimeClass(parent) + val cls = getClassFromType(t) + path => StaticInvoke( + parentCls, + ObjectType(cls), "withName", createDeserializerForString(path, false) :: Nil, returnNullable = false) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org