This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e43b9e8 [SPARK-36733][SQL] Fix a perf issue in SchemaPruning when a struct has many fields e43b9e8 is described below commit e43b9e8520bd4ea5bc3693beb496893b17e79054 Author: Kousuke Saruta <sarutak@oss.nttdata.com> AuthorDate: Wed Sep 15 10:33:58 2021 +0900 [SPARK-36733][SQL] Fix a perf issue in SchemaPruning when a struct has many fields ### What changes were proposed in this pull request? This PR fixes a perf issue in `SchemaPruning` when a struct has many fields (e.g. >10K fields). The root cause is that `SchemaPruning.sortLeftFieldsByRight` performs an O(N * M) search. ``` val filteredRightFieldNames = rightStruct.fieldNames .filter(name => leftStruct.fieldNames.exists(resolver(_, name))) ``` To fix this issue, this PR proposes to use a `HashMap` so that each field lookup takes expected constant time. This PR also adds `case _ if left == right => left` to the method as a short-circuit. ### Why are the changes needed? To fix a perf issue. ### Does this PR introduce _any_ user-facing change? No. The logic should be identical. ### How was this patch tested? I confirmed that the following micro-benchmark finishes within a few seconds. ``` import org.apache.spark.sql.catalyst.expressions.SchemaPruning import org.apache.spark.sql.types._ var struct1 = new StructType() (1 to 50000).foreach { i => struct1 = struct1.add(new StructField(i + "", IntegerType)) } var struct2 = new StructType() (50001 to 100000).foreach { i => struct2 = struct2.add(new StructField(i + "", IntegerType)) } SchemaPruning.sortLeftFieldsByRight(struct1, struct2) SchemaPruning.sortLeftFieldsByRight(struct2, struct2) ``` The correctness should be checked by existing tests. Closes #33981 from sarutak/improve-schemapruning-performance. 
Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org> --- .../sql/catalyst/expressions/SchemaPruning.scala | 32 +++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index 9aa2766..2a182b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + +import scala.collection.immutable.HashMap + import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.types._ @@ -54,6 +58,7 @@ object SchemaPruning extends SQLConfHelper { */ private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = (left, right) match { + case _ if left == right => left case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) => ArrayType( sortLeftFieldsByRight(leftElementType, rightElementType), @@ -65,16 +70,23 @@ object SchemaPruning extends SQLConfHelper { sortLeftFieldsByRight(leftValueType, rightValueType), containsNull) case (leftStruct: StructType, rightStruct: StructType) => - val resolver = conf.resolver - val filteredRightFieldNames = rightStruct.fieldNames - .filter(name => leftStruct.fieldNames.exists(resolver(_, name))) - val sortedLeftFields = filteredRightFieldNames.map { fieldName => - val resolvedLeftStruct = leftStruct.find(p => resolver(p.name, fieldName)).get - val leftFieldType = resolvedLeftStruct.dataType - val rightFieldType = rightStruct(fieldName).dataType - val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType) - StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable, 
- metadata = resolvedLeftStruct.metadata) + val formatFieldName: String => String = + if (conf.caseSensitiveAnalysis) identity else _.toLowerCase(Locale.ROOT) + + val leftStructHashMap = + HashMap(leftStruct.map(f => formatFieldName(f.name)).zip(leftStruct): _*) + val sortedLeftFields = rightStruct.fieldNames.flatMap { fieldName => + val formattedFieldName = formatFieldName(fieldName) + if (leftStructHashMap.contains(formattedFieldName)) { + val resolvedLeftStruct = leftStructHashMap(formattedFieldName) + val leftFieldType = resolvedLeftStruct.dataType + val rightFieldType = rightStruct(fieldName).dataType + val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType) + Some(StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable, + metadata = resolvedLeftStruct.metadata)) + } else { + None + } } StructType(sortedLeftFields) case _ => left --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org