This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 24bce72c9065 [SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle
24bce72c9065 is described below

commit 24bce72c9065336a962fe76feeb14fa2119ef961
Author: Szehon Ho <szehon.apa...@gmail.com>
AuthorDate: Sun Jun 9 10:22:21 2024 -0400

    [SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle
    
    ### Why are the changes needed?
    
    Support one-side shuffle for SPJ when the side that reports a partitioning uses a partition transform expression.
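    
    For example (a minimal sketch, condensed from the new test added in this patch; `testcat`, the `bucket`/`identity` transforms, `createJoinTestDF` and `collectShuffles` come from KeyGroupedPartitioningSuite):
    
    ```scala
    // `items` reports the partitioning bucket(2, id), identity(arrive_time);
    // `purchases` reports no partitioning at all.
    withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
      val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
      val shuffles = collectShuffles(df.queryExecution.executedPlan)
      // Previously the TransformExpression on the items side blocked
      // KeyGroupedShuffleSpec.canCreatePartitioning, so both sides were shuffled;
      // now only the purchases side (the one without a reported partitioning) is.
      assert(shuffles.size == 1)
    }
    ```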
    
    ### How was this patch tested?
    
    New unit test in KeyGroupedPartitioningSuite
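    
    The new cases also pin down the current limitation (again a condensed sketch of the added test code): when the join keys are only a subset of the partition keys and a transform is involved, SPJ is intentionally not triggered yet and both sides are still shuffled:
    
    ```scala
    // items is partitioned by bucket(2, id), identity(name); the join is on id only.
    withSQLConf(
        SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
        SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
      val df = createJoinTestDF(Seq("id" -> "item_id"))
      val shuffles = collectShuffles(df.queryExecution.executedPlan)
      assert(shuffles.size == 2) // transform + subset join keys: no SPJ for now
    }
    ```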
    
    ### Was this patch authored or co-authored using generative AI tooling?

    No.
    
    Closes #46255 from szehon-ho/spj_auto_bucket.
    
    Authored-by: Szehon Ho <szehon.apa...@gmail.com>
    Signed-off-by: Chao Sun <c...@openai.com>
---
 .../main/scala/org/apache/spark/Partitioner.scala  |   5 +-
 .../catalyst/expressions/TransformExpression.scala |  26 +++-
 .../sql/catalyst/plans/physical/partitioning.scala |  26 +++-
 .../connector/KeyGroupedPartitioningSuite.scala    | 136 ++++++++++++++++++---
 .../catalog/functions/transformFunctions.scala     |  12 +-
 5 files changed, 179 insertions(+), 26 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index ae39e2e183e4..357e71cdf445 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
+import scala.collection.immutable.ArraySeq
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.math.log10
@@ -149,7 +150,9 @@ private[spark] class KeyGroupedPartitioner(
     override val numPartitions: Int) extends Partitioner {
   override def getPartition(key: Any): Int = {
     val keys = key.asInstanceOf[Seq[Any]]
-    valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
+    val normalizedKeys = ArraySeq.from(keys)
+    valueMap.getOrElseUpdate(normalizedKeys,
+      Utils.nonNegativeMod(normalizedKeys.hashCode, numPartitions))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
index d37c9d9f6452..9041ed15fc50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -30,7 +33,7 @@ import org.apache.spark.sql.types.DataType
 case class TransformExpression(
     function: BoundFunction,
     children: Seq[Expression],
-    numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable {
+    numBucketsOpt: Option[Int] = None) extends Expression {
 
   override def nullable: Boolean = true
 
@@ -113,4 +116,23 @@ case class TransformExpression(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
+
+  private lazy val resolvedFunction: Option[Expression] = this match {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+      Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
+        Seq(Literal(numBuckets)) ++ arguments))
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
+    case _ => None
+  }
+
+  override def eval(input: InternalRow): Any = {
+    resolvedFunction match {
+      case Some(fn) => fn.eval(input)
+      case None => throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 43aba478c37b..19595eef10b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -871,12 +871,30 @@ case class KeyGroupedShuffleSpec(
     if (results.forall(p => p.isEmpty)) None else Some(results)
   }
 
-  override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
-    // Only support partition expressions are AttributeReference for now
-    partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
+  override def canCreatePartitioning: Boolean = {
+    // Allow one-side shuffle for SPJ for now only if partially-clustered distribution is not
+    // enabled and, when join keys are fewer than partition keys, only if no transforms are used.
+    val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+      e: Expression => e.isInstanceOf[AttributeReference]
+    } else {
+      e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
+    }
+    SQLConf.get.v2BucketingShuffleEnabled &&
+      !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+      partitioning.expressions.forall(checkExprType)
+  }
+
+
 
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
-    KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
+    val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
+      case (c, e: TransformExpression) => TransformExpression(
+        e.function, Seq(c), e.numBucketsOpt)
+      case (c, _) => c
+    }
+    KeyGroupedPartitioning(newExpressions,
+      partitioning.numPartitions,
+      partitioning.partitionValues)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 10a32441b6cd..a5de5bc1913b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1136,7 +1136,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
         val df = createJoinTestDF(Seq("arrive_time" -> "time"))
         val shuffles = collectShuffles(df.queryExecution.executedPlan)
         if (shuffle) {
-          assert(shuffles.size == 2, "partitioning with transform not work now")
+          assert(shuffles.size == 1, "partitioning with transform should trigger SPJ")
         } else {
           assert(shuffles.size == 2, "should add two side shuffle when 
bucketing shuffle one side" +
             " is not enabled")
@@ -1991,22 +1991,19 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       "(6, 50.0, cast('2023-02-01' as timestamp))")
 
     Seq(true, false).foreach { pushdownValues =>
-      Seq(true, false).foreach { partiallyClustered =>
-        withSQLConf(
-          SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
-          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
-          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key
-            -> partiallyClustered.toString,
-          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
-          val df = createJoinTestDF(Seq("id" -> "item_id"))
-          val shuffles = collectShuffles(df.queryExecution.executedPlan)
-          assert(shuffles.size == 1, "SPJ should be triggered")
-          checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
-            Row(1, "aa", 30.0, 89.0),
-            Row(1, "aa", 40.0, 42.0),
-            Row(1, "aa", 40.0, 89.0),
-            Row(3, "bb", 10.0, 19.5)))
-        }
+      withSQLConf(
+        SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
+        SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+        SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+        val df = createJoinTestDF(Seq("id" -> "item_id"))
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        assert(shuffles.size == 1, "SPJ should be triggered")
+        checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
+          Row(1, "aa", 30.0, 89.0),
+          Row(1, "aa", 40.0, 42.0),
+          Row(1, "aa", 40.0, 89.0),
+          Row(3, "bb", 10.0, 19.5)))
       }
     }
   }
@@ -2052,4 +2049,109 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       }
     }
   }
+
+  test("SPARK-48012: one-side shuffle with partition transforms") {
+    val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
+    val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id"))
+
+    Seq(items_partitions, items_partitions2).foreach { partition =>
+      catalog.clearTables()
+
+      createTable(items, itemsColumns, partition)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " +
+        "(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " +
+        "(5, 'ff', 32.1, cast('2020-03-01' as timestamp))")
+
+      createTable(purchases, purchasesColumns, Array.empty)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(2, 10.7, cast('2020-01-01' as timestamp))," +
+        "(3, 19.5, cast('2020-02-01' as timestamp))," +
+        "(4, 56.5, cast('2020-02-01' as timestamp))")
+
+      withSQLConf(
+        SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+        val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        assert(shuffles.size == 1, "only shuffle side that does not report partitioning")
+
+        checkAnswer(df, Seq(
+          Row(1, "bb", 30.0, 42.0),
+          Row(1, "aa", 40.0, 42.0),
+          Row(4, "ee", 15.5, 56.5)))
+      }
+    }
+  }
+
+  test("SPARK-48012: one-side shuffle with partition transforms and pushdown 
values") {
+    val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'cc', 30.0, cast('2020-01-02' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(2, 10.7, cast('2020-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDown => {
+        withSQLConf(
+          SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
+            pushDown.toString) {
+          val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.size == 1, "only shuffle side that does not report partitioning")
+
+          checkAnswer(df, Seq(
+            Row(1, "bb", 30.0, 42.0),
+            Row(1, "aa", 40.0, 42.0)))
+        }
+      }
+    }
+  }
+
+  test("SPARK-48012: one-side shuffle with partition transforms " +
+    "with fewer join keys than partition kes") {
+    val items_partitions = Array(bucket(2, "id"), identity("name"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchasesColumns, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 89.0, cast('2020-01-03' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+      "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+      "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+   withSQLConf(
+     SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+     SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+     SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+     SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+     SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+     val df = createJoinTestDF(Seq("id" -> "item_id"))
+     val shuffles = collectShuffles(df.queryExecution.executedPlan)
+     assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" +
+       " fewer join keys than partition keys for now.")
+     checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
+       Row(1, "aa", 30.0, 89.0),
+       Row(1, "aa", 40.0, 42.0),
+       Row(1, "aa", 40.0, 89.0),
+       Row(3, "bb", 10.0, 19.5)))
+   }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 5cdb90090105..5364fc5d6242 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -16,9 +16,11 @@
  */
 package org.apache.spark.sql.connector.catalog.functions
 
-import java.sql.Timestamp
+import java.time.{Instant, LocalDate, ZoneId}
+import java.time.temporal.ChronoUnit
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -44,7 +46,13 @@ object YearsFunction extends ScalarFunction[Long] {
   override def name(): String = "years"
   override def canonicalName(): String = name()
 
-  def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900
+  val UTC: ZoneId = ZoneId.of("UTC")
+  val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate
+
+  def invoke(ts: Long): Long = {
+    val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate
+    ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+  }
 }
 
 object DaysFunction extends BoundFunction {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
