This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 127ccc208aa [SPARK-40295][SQL] Allow v2 functions with literal args in write distribution/ordering
127ccc208aa is described below

commit 127ccc208aa8fd03f53dcb926087f1e72531bdbf
Author: aokolnychyi <aokolnyc...@apple.com>
AuthorDate: Wed Sep 7 09:15:56 2022 -0700

    [SPARK-40295][SQL] Allow v2 functions with literal args in write distribution/ordering
    
    ### What changes were proposed in this pull request?
    
    This PR adapts `V2ExpressionUtils` to support arbitrary transforms with multiple args that are either references or literals.
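
    For illustration, a minimal sketch of the kind of transform this now covers, built with the public `Expressions` factory methods (the `truncate` name mirrors the test function added in this patch; any catalog function taking a literal arg would do):

    ```scala
    import org.apache.spark.sql.connector.expressions.{Expressions, Transform}

    // A transform with two args: a column reference and an integer literal,
    // i.e. truncate(data, 2). Previously only single-arg transforms whose
    // argument was a plain reference could be converted to catalyst expressions.
    val truncateTransform: Transform =
      Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
    ```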
    
    ### Why are the changes needed?
    
    After PR #36995, data sources can request distribution and ordering that reference v2 functions. If a data source needs a transform with multiple input args or a transform where not all args are references, Spark will throw an exception.
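
    As a rough sketch of the connector side (the `MyWrite` class is hypothetical and the `data` column mirrors the test tables; `RequiresDistributionAndOrdering` is the existing DSv2 write interface), such a request could look like this:

    ```scala
    import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
    import org.apache.spark.sql.connector.expressions.{Expressions, SortOrder}
    import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}

    // Hypothetical write that asks Spark to cluster incoming rows by
    // truncate(data, 2), a v2 function call with a literal argument.
    class MyWrite extends Write with RequiresDistributionAndOrdering {
      override def requiredDistribution(): Distribution =
        Distributions.clustered(Array(
          Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))))

      override def requiredOrdering(): Array[SortOrder] = Array.empty[SortOrder]
    }
    ```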
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This PR adapts the test added recently in PR #36995.
    
    Closes #37749 from aokolnychyi/spark-40295.
    
    Lead-authored-by: aokolnychyi <aokolnyc...@apple.com>
    Co-authored-by: Anton Okolnychyi <aokolnyc...@apple.com>
    Signed-off-by: Chao Sun <sunc...@apple.com>
---
 .../catalyst/expressions/V2ExpressionUtils.scala   | 17 +++++-------
 .../sql/catalyst/plans/physical/partitioning.scala | 20 ++++++++++++++
 .../sql/connector/catalog/InMemoryBaseTable.scala  |  8 ++++++
 .../datasources/v2/DataSourceV2ScanExecBase.scala  | 17 ++++++++----
 .../connector/KeyGroupedPartitioningSuite.scala    | 29 ++++++++++++++++++--
 .../WriteDistributionAndOrderingSuite.scala        | 32 ++++++++++++++++------
 .../catalog/functions/transformFunctions.scala     | 19 +++++++++++++
 7 files changed, 117 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index 64eb307bb9f..06ecf79c58c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
-import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
 
@@ -75,6 +75,8 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
       query: LogicalPlan,
       funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = {
     expr match {
+      case l: V2Literal[_] =>
+        Some(Literal.create(l.value, l.dataType))
       case t: Transform =>
         toCatalystTransformOpt(t, query, funCatalogOpt)
       case SortValue(child, direction, nullOrdering) =>
@@ -105,18 +107,13 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
           TransformExpression(bound, resolvedRefs, Some(numBuckets))
         }
       }
-    case NamedTransform(name, refs)
-        if refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) =>
-      val resolvedRefs = refs.map(_.asInstanceOf[NamedReference]).map { r =>
-        resolveRef[NamedExpression](r, query)
-      }
+    case NamedTransform(name, args) =>
+      val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt))
       funCatalogOpt.flatMap { catalog =>
-        loadV2FunctionOpt(catalog, name, resolvedRefs).map { bound =>
-          TransformExpression(bound, resolvedRefs)
+        loadV2FunctionOpt(catalog, name, catalystArgs).map { bound =>
+          TransformExpression(bound, catalystArgs)
         }
       }
-    case _ =>
-      throw new AnalysisException(s"Transform $trans is not currently supported")
   }
 
   private def loadV2FunctionOpt(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 69eeab426ed..41de502e021 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.physical
 
+import scala.annotation.tailrec
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -361,6 +362,25 @@ object KeyGroupedPartitioning {
       partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
     KeyGroupedPartitioning(expressions, partitionValues.size, Some(partitionValues))
   }
+
+  def supportsExpressions(expressions: Seq[Expression]): Boolean = {
+    def isSupportedTransform(transform: TransformExpression): Boolean = {
+      transform.children.size == 1 && isReference(transform.children.head)
+    }
+
+    @tailrec
+    def isReference(e: Expression): Boolean = e match {
+      case _: Attribute => true
+      case g: GetStructField => isReference(g.child)
+      case _ => false
+    }
+
+    expressions.forall {
+      case t: TransformExpression if isSupportedTransform(t) => true
+      case e: Expression if isReference(e) => true
+      case _ => false
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index f139399ed76..7da6c1480e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -83,6 +83,7 @@ abstract class InMemoryBaseTable(
     case _: HoursTransform =>
     case _: BucketTransform =>
     case _: SortedBucketTransform =>
+    case NamedTransform("truncate", Seq(_: NamedReference, _: Literal[_])) =>
     case t if !allowUnsupportedTransforms =>
       throw new IllegalArgumentException(s"Transform $t is not a supported 
transform")
   }
@@ -177,6 +178,13 @@ abstract class InMemoryBaseTable(
         var dataTypeHashCode = 0
         valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
         ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
+      case NamedTransform("truncate", Seq(ref: NamedReference, length: Literal[_])) =>
+        extractor(ref.fieldNames, cleanedSchema, row) match {
+          case (str: UTF8String, StringType) =>
+            str.substring(0, length.value.asInstanceOf[Int])
+          case (v, t) =>
+            throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+        }
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index e6d7cddc71b..fa4ae171df5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -91,11 +91,18 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
   }
 
   override def outputPartitioning: physical.Partitioning = {
-    if (partitions.length == 1) SinglePartition
-    else groupedPartitions.map { partitionValues =>
-      KeyGroupedPartitioning(keyGroupedPartitioning.get,
-        partitionValues.size, Some(partitionValues.map(_._1)))
-    }.getOrElse(super.outputPartitioning)
+    if (partitions.length == 1) {
+      SinglePartition
+    } else {
+      keyGroupedPartitioning match {
+        case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
+          groupedPartitions.map { partitionValues =>
+            KeyGroupedPartitioning(exprs, partitionValues.size, Some(partitionValues.map(_._1)))
+          }.getOrElse(super.outputPartitioning)
+        case _ =>
+          super.outputPartitioning
+      }
+    }
   }
 
   @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index bdbf309214f..c0dc3263616 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -20,7 +20,7 @@ import java.util.Collections
 
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.TransformExpression
+import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
@@ -38,6 +38,12 @@ import org.apache.spark.sql.internal.SQLConf._
 import org.apache.spark.sql.types._
 
 class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
+  private val functions = Seq(
+    UnboundYearsFunction,
+    UnboundDaysFunction,
+    UnboundBucketFunction,
+    UnboundTruncateFunction)
+
   private var originalV2BucketingEnabled: Boolean = false
   private var originalAutoBroadcastJoinThreshold: Long = -1
 
@@ -59,7 +65,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
   }
 
   before {
-    Seq(UnboundYearsFunction, UnboundDaysFunction, UnboundBucketFunction).foreach { f =>
+    functions.foreach { f =>
       catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
     }
   }
@@ -179,6 +185,25 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     }
   }
 
+  test("non-clustered distribution: V2 function with multiple args") {
+    val partitions: Array[Transform] = Array(
+      Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
+    )
+
+    // create a table with 3 partitions, partitioned by `truncate` transform
+    createTable(table, schema, partitions)
+    sql(s"INSERT INTO testcat.ns.$table VALUES " +
+      s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
+      s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
+      s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")
+
+    val df = sql(s"SELECT * FROM testcat.ns.$table")
+    val distribution = physical.ClusteredDistribution(
+      Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))))
+
+    checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
+  }
+
   /**
    * Check whether the query plan from `df` has the expected `distribution`, `ordering` and
    * `partitioning`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 7966add7738..b262e405d4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast,
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
-import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, UnboundBucketFunction, UnboundStringSelfFunction}
+import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, TruncateFunction, UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction}
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
 import org.apache.spark.sql.connector.expressions.LogicalExpressions._
@@ -45,7 +45,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
   import testImplicits._
 
   before {
-    Seq(UnboundBucketFunction, UnboundStringSelfFunction).foreach { f =>
+    Seq(UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction).foreach { f =>
       catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
     }
   }
@@ -1041,19 +1041,36 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       distributionStrictlyRequired: Boolean = true,
       dataSkewed: Boolean = false,
       coalesce: Boolean = false): Unit = {
+
+    val stringSelfTransform = ApplyTransform(
+      "string_self",
+      Seq(FieldReference("data")))
+    val truncateTransform = ApplyTransform(
+      "truncate",
+      Seq(stringSelfTransform, LiteralValue(2, IntegerType)))
+
     val tableOrdering = Array[SortOrder](
-      sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST),
+      sort(
+        stringSelfTransform,
+        SortDirection.DESCENDING,
+        NullOrdering.NULLS_FIRST),
       sort(
         BucketTransform(LiteralValue(10, IntegerType), Seq(FieldReference("id"))),
         SortDirection.DESCENDING,
         NullOrdering.NULLS_FIRST)
     )
-    val tableDistribution = Distributions.clustered(Array(
-      ApplyTransform("string_self", Seq(FieldReference("data")))))
+    val tableDistribution = Distributions.clustered(Array(truncateTransform))
+
+    val stringSelfExpr = ApplyFunctionExpression(
+      StringSelfFunction,
+      Seq(attr("data")))
+    val truncateExpr = ApplyFunctionExpression(
+      TruncateFunction,
+      Seq(stringSelfExpr, Literal(2)))
 
     val writeOrdering = Seq(
       catalyst.expressions.SortOrder(
-        attr("data"),
+        stringSelfExpr,
         catalyst.expressions.Descending,
         catalyst.expressions.NullsFirst,
         Seq.empty
@@ -1066,8 +1083,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       )
     )
 
-    val writePartitioningExprs = Seq(
-      ApplyFunctionExpression(StringSelfFunction, Seq(attr("data"))))
+    val writePartitioningExprs = Seq(truncateExpr)
     val writePartitioning = if (!coalesce) {
       clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 9277e8d059f..6ea48aff2a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -99,3 +99,22 @@ object StringSelfFunction extends ScalarFunction[UTF8String] {
     input.getUTF8String(0)
   }
 }
+
+object UnboundTruncateFunction extends UnboundFunction {
+  override def bind(inputType: StructType): BoundFunction = TruncateFunction
+  override def description(): String = name()
+  override def name(): String = "truncate"
+}
+
+object TruncateFunction extends ScalarFunction[UTF8String] {
+  override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
+  override def resultType(): DataType = StringType
+  override def name(): String = "truncate"
+  override def canonicalName(): String = name()
+  override def toString: String = name()
+  override def produceResult(input: InternalRow): UTF8String = {
+    val str = input.getUTF8String(0)
+    val length = input.getInt(1)
+    str.substring(0, length)
+  }
+}

