This is an automated email from the ASF dual-hosted git repository. sunchao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 24bce72c9065 [SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle 24bce72c9065 is described below commit 24bce72c9065336a962fe76feeb14fa2119ef961 Author: Szehon Ho <szehon.apa...@gmail.com> AuthorDate: Sun Jun 9 10:22:21 2024 -0400 [SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle ### Why are the changes needed? Support SPJ one-side shuffle if the other side has a partition transform expression ### How was this patch tested? New unit test in KeyGroupedPartitioningSuite ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46255 from szehon-ho/spj_auto_bucket. Authored-by: Szehon Ho <szehon.apa...@gmail.com> Signed-off-by: Chao Sun <c...@openai.com> --- .../main/scala/org/apache/spark/Partitioner.scala | 5 +- .../catalyst/expressions/TransformExpression.scala | 26 +++- .../sql/catalyst/plans/physical/partitioning.scala | 26 +++- .../connector/KeyGroupedPartitioningSuite.scala | 136 ++++++++++++++++++--- .../catalog/functions/transformFunctions.scala | 12 +- 5 files changed, 179 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index ae39e2e183e4..357e71cdf445 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import scala.collection.immutable.ArraySeq import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.log10 @@ -149,7 +150,9 @@ private[spark] class KeyGroupedPartitioner( override val numPartitions: Int) extends Partitioner { override def getPartition(key: Any): Int = { val keys = key.asInstanceOf[Seq[Any]] - valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions)) + 
val normalizedKeys = ArraySeq.from(keys) + valueMap.getOrElseUpdate(normalizedKeys, + Utils.nonNegativeMod(normalizedKeys.hashCode, numPartitions)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index d37c9d9f6452..9041ed15fc50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction} +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.DataType /** @@ -30,7 +33,7 @@ import org.apache.spark.sql.types.DataType case class TransformExpression( function: BoundFunction, children: Seq[Expression], - numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable { + numBucketsOpt: Option[Int] = None) extends Expression { override def nullable: Boolean = true @@ -113,4 +116,23 @@ case class TransformExpression( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) + + private lazy val resolvedFunction: Option[Expression] = this match { + case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => + Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, + Seq(Literal(numBuckets)) ++ arguments)) + case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + 
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)) + case _ => None + } + + override def eval(input: InternalRow): Any = { + resolvedFunction match { + case Some(fn) => fn.eval(input) + case None => throw QueryExecutionErrors.cannotEvaluateExpressionError(this) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 43aba478c37b..19595eef10b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -871,12 +871,30 @@ case class KeyGroupedShuffleSpec( if (results.forall(p => p.isEmpty)) None else Some(results) } - override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && - // Only support partition expressions are AttributeReference for now - partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + override def canCreatePartitioning: Boolean = { + // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled + // and for join keys less than partition keys only if transforms are not enabled. 
+ val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + e: Expression => e.isInstanceOf[AttributeReference] + } else { + e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] + } + SQLConf.get.v2BucketingShuffleEnabled && + !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && + partitioning.expressions.forall(checkExprType) + } + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { - KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map { + case (c, e: TransformExpression) => TransformExpression( + e.function, Seq(c), e.numBucketsOpt) + case (c, _) => c + } + KeyGroupedPartitioning(newExpressions, + partitioning.numPartitions, + partitioning.partitionValues) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 10a32441b6cd..a5de5bc1913b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1136,7 +1136,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = createJoinTestDF(Seq("arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (shuffle) { - assert(shuffles.size == 2, "partitioning with transform not work now") + assert(shuffles.size == 1, "partitioning with transform should trigger SPJ") } else { assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + " is not enabled") @@ -1991,22 +1991,19 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { "(6, 50.0, cast('2023-02-01' as timestamp))") Seq(true, 
false).foreach { pushdownValues => - Seq(true, false).foreach { partiallyClustered => - withSQLConf( - SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key - -> partiallyClustered.toString, - SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { - val df = createJoinTestDF(Seq("id" -> "item_id")) - val shuffles = collectShuffles(df.queryExecution.executedPlan) - assert(shuffles.size == 1, "SPJ should be triggered") - checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), - Row(1, "aa", 30.0, 89.0), - Row(1, "aa", 40.0, 42.0), - Row(1, "aa", 40.0, 89.0), - Row(3, "bb", 10.0, 19.5))) - } + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "SPJ should be triggered") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) } } } @@ -2052,4 +2049,109 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-48012: one-side shuffle with partition transforms") { + val items_partitions = Array(bucket(2, "id"), identity("arrive_time")) + val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id")) + + Seq(items_partitions, items_partitions2).foreach { partition => + catalog.clearTables() + + createTable(items, itemsColumns, partition) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'bb', 30.0, 
cast('2020-01-01' as timestamp)), " + + "(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " + + "(5, 'ff', 32.1, cast('2020-03-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 10.7, cast('2020-01-01' as timestamp))," + + "(3, 19.5, cast('2020-02-01' as timestamp))," + + "(4, 56.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "only shuffle side that does not report partitioning") + + checkAnswer(df, Seq( + Row(1, "bb", 30.0, 42.0), + Row(1, "aa", 40.0, 42.0), + Row(4, "ee", 15.5, 56.5))) + } + } + } + + test("SPARK-48012: one-side shuffle with partition transforms and pushdown values") { + val items_partitions = Array(bucket(2, "id"), identity("arrive_time")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " + + "(1, 'cc', 30.0, cast('2020-01-02' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 10.7, cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDown => { + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> + pushDown.toString) { + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 
1, "only shuffle side that does not report partitioning") + + checkAnswer(df, Seq( + Row(1, "bb", 30.0, 42.0), + Row(1, "aa", 40.0, 42.0))) + } + } + } + } + + test("SPARK-48012: one-side shuffle with partition transforms " + + "with fewer join keys than partition kes") { + val items_partitions = Array(bucket(2, "id"), identity("name")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" + + "less join keys than partition keys for now.") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 5cdb90090105..5364fc5d6242 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.sql.connector.catalog.functions -import java.sql.Timestamp +import java.time.{Instant, LocalDate, ZoneId} +import java.time.temporal.ChronoUnit import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -44,7 +46,13 @@ object YearsFunction extends ScalarFunction[Long] { override def name(): String = "years" override def canonicalName(): String = name() - def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900 + val UTC: ZoneId = ZoneId.of("UTC") + val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate + + def invoke(ts: Long): Long = { + val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) + } } object DaysFunction extends BoundFunction { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org