This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 29eca8c [SPARK-38325][SQL] ANSI mode: avoid potential runtime error in HashJoin.extractKeyExprAt() 29eca8c is described below commit 29eca8c87f4e8c19c0380f7c30668fd88edee573 Author: Gengliang Wang <gengli...@apache.org> AuthorDate: Fri Feb 25 17:11:15 2022 +0800 [SPARK-38325][SQL] ANSI mode: avoid potential runtime error in HashJoin.extractKeyExprAt() ### What changes were proposed in this pull request? SubqueryBroadcastExec retrieves the partition key from the broadcast results based on the type of HashedRelation returned. If the key is packed inside a Long, we extract it through bitwise operations and cast it as Byte/Short/Int if necessary. The casting here can cause a potential runtime error. This PR is to fix it. ### Why are the changes needed? Bug fix ### Does this PR introduce _any_ user-facing change? Yes, avoid potential runtime error in dynamic pruning under ANSI mode ### How was this patch tested? UT Closes #35659 from gengliangwang/fixHashJoin. 
Authored-by: Gengliang Wang <gengli...@apache.org> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../spark/sql/execution/joins/HashJoin.scala | 27 +++++++++++++++++----- .../sql/execution/joins/HashedRelationSuite.scala | 22 +++++++++++------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0e8bb84..4595ea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -705,6 +705,13 @@ trait HashJoin extends JoinCodegenSupport { } object HashJoin extends CastSupport with SQLConfHelper { + + private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = { + // TODO: support BooleanType, DateType and TimestampType + keys.forall(_.dataType.isInstanceOf[IntegralType]) && + keys.map(_.dataType.defaultSize).sum <= 8 + } + /** * Try to rewrite the key as LongType so we can use getLong(), if the key can fit with a long. 
* @@ -712,9 +719,7 @@ object HashJoin extends CastSupport with SQLConfHelper { */ def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { assert(keys.nonEmpty) - // TODO: support BooleanType, DateType and TimestampType - if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) - || keys.map(_.dataType.defaultSize).sum > 8) { + if (!canRewriteAsLongType(keys)) { return keys } @@ -736,18 +741,28 @@ object HashJoin extends CastSupport with SQLConfHelper { * determine the number of bits to shift */ def extractKeyExprAt(keys: Seq[Expression], index: Int): Expression = { + assert(canRewriteAsLongType(keys)) // jump over keys that have a higher index value than the required key if (keys.size == 1) { assert(index == 0) - cast(BoundReference(0, LongType, nullable = false), keys(index).dataType) + Cast( + child = BoundReference(0, LongType, nullable = false), + dataType = keys(index).dataType, + timeZoneId = Option(conf.sessionLocalTimeZone), + ansiEnabled = false) } else { val shiftedBits = keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1 // build the schema for unpacking the required key - cast(BitwiseAnd( + val castChild = BitwiseAnd( ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)), - Literal(mask)), keys(index).dataType) + Literal(mask)) + Cast( + child = castChild, + dataType = keys(index).dataType, + timeZoneId = Option(conf.sessionLocalTimeZone), + ansiEnabled = false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index b8ffc47..d5b7ed6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.memory.{TaskMemoryManager, 
UnifiedMemoryManager} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap @@ -610,14 +611,19 @@ class HashedRelationSuite extends SharedSparkSession { val keys = Seq(BoundReference(0, ByteType, false), BoundReference(1, IntegerType, false), BoundReference(2, ShortType, false)) - val packed = HashJoin.rewriteKeyExpr(keys) - val unsafeProj = UnsafeProjection.create(packed) - val packedKeys = unsafeProj(row) - - Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) => - val key = HashJoin.extractKeyExprAt(keys, i) - val proj = UnsafeProjection.create(key) - assert(proj(packedKeys).get(0, dt) == -i - 1) + // Rewriting and extracting key expressions should not cause an exception when ANSI mode is on. + Seq("false", "true").foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) { + val packed = HashJoin.rewriteKeyExpr(keys) + val unsafeProj = UnsafeProjection.create(packed) + val packedKeys = unsafeProj(row) + + Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) => + val key = HashJoin.extractKeyExprAt(keys, i) + val proj = UnsafeProjection.create(key) + assert(proj(packedKeys).get(0, dt) == -i - 1) + } + } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org