This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 127ccc208aa [SPARK-40295][SQL] Allow v2 functions with literal args in write distribution/ordering
127ccc208aa is described below

commit 127ccc208aa8fd03f53dcb926087f1e72531bdbf
Author: aokolnychyi <aokolnyc...@apple.com>
AuthorDate: Wed Sep 7 09:15:56 2022 -0700

    [SPARK-40295][SQL] Allow v2 functions with literal args in write distribution/ordering

    ### What changes were proposed in this pull request?

    This PR adapts `V2ExpressionUtils` to support arbitrary transforms with multiple args that are either references or literals.

    ### Why are the changes needed?

    After PR #36995, data sources can request a distribution and ordering that reference v2 functions. However, if a data source needs a transform with multiple input args, or a transform where not all args are references, Spark will throw an exception.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    This PR adapts the test added recently in PR #36995.

    Closes #37749 from aokolnychyi/spark-40295.

    Lead-authored-by: aokolnychyi <aokolnyc...@apple.com>
    Co-authored-by: Anton Okolnychyi <aokolnyc...@apple.com>
    Signed-off-by: Chao Sun <sunc...@apple.com>
---
 .../catalyst/expressions/V2ExpressionUtils.scala    | 17 +++++-------
 .../sql/catalyst/plans/physical/partitioning.scala  | 20 ++++++++++++++
 .../sql/connector/catalog/InMemoryBaseTable.scala   |  8 ++++++
 .../datasources/v2/DataSourceV2ScanExecBase.scala   | 17 ++++++++----
 .../connector/KeyGroupedPartitioningSuite.scala     | 29 ++++++++++++++++++--
 .../WriteDistributionAndOrderingSuite.scala         | 32 ++++++++++++++++------
 .../catalog/functions/transformFunctions.scala      | 19 +++++++++++++
 7 files changed, 117 insertions(+), 25 deletions(-)
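For context, this is the kind of request that previously failed and now resolves: a connector's required clustering may name a v2 function over a column reference and a literal argument. A minimal sketch using the public `Expressions`/`Distributions` factories; note that `truncate(data, 2)` here refers to the test function added by this patch, not a built-in:

```scala
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, Expressions}

// Cluster rows by truncate(data, 2): a transform whose args are a column
// reference and an integer literal. Before this patch, V2ExpressionUtils
// only accepted transforms over a single named reference.
val clustering: Array[Expression] = Array(
  Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2)))
val requiredDistribution: Distribution = Distributions.clustered(clustering)
```

During planning, `V2ExpressionUtils.toCatalyst` now converts each argument (resolving references and turning v2 literals into catalyst literals) and binds `truncate` through the table's `FunctionCatalog`.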
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index 64eb307bb9f..06ecf79c58c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
-import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
 
@@ -75,6 +75,8 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
       query: LogicalPlan,
       funCatalogOpt: Option[FunctionCatalog] = None): Option[Expression] = {
     expr match {
+      case l: V2Literal[_] =>
+        Some(Literal.create(l.value, l.dataType))
       case t: Transform =>
         toCatalystTransformOpt(t, query, funCatalogOpt)
       case SortValue(child, direction, nullOrdering) =>
@@ -105,18 +107,13 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
           TransformExpression(bound, resolvedRefs, Some(numBuckets))
         }
       }
-    case NamedTransform(name, refs)
-        if refs.length == 1 && refs.forall(_.isInstanceOf[NamedReference]) =>
-      val resolvedRefs = refs.map(_.asInstanceOf[NamedReference]).map { r =>
-        resolveRef[NamedExpression](r, query)
-      }
+    case NamedTransform(name, args) =>
+      val catalystArgs = args.map(toCatalyst(_, query, funCatalogOpt))
       funCatalogOpt.flatMap { catalog =>
-        loadV2FunctionOpt(catalog, name, resolvedRefs).map { bound =>
-          TransformExpression(bound, resolvedRefs)
+        loadV2FunctionOpt(catalog, name, catalystArgs).map { bound =>
+          TransformExpression(bound, catalystArgs)
         }
       }
-    case _ =>
-      throw new AnalysisException(s"Transform $trans is not currently supported")
   }
 
   private def loadV2FunctionOpt(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 69eeab426ed..41de502e021 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.physical
 
+import scala.annotation.tailrec
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -361,6 +362,25 @@ object KeyGroupedPartitioning {
       partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
     KeyGroupedPartitioning(expressions, partitionValues.size, Some(partitionValues))
   }
+
+  def supportsExpressions(expressions: Seq[Expression]): Boolean = {
+    def isSupportedTransform(transform: TransformExpression): Boolean = {
+      transform.children.size == 1 && isReference(transform.children.head)
+    }
+
+    @tailrec
+    def isReference(e: Expression): Boolean = e match {
+      case _: Attribute => true
+      case g: GetStructField => isReference(g.child)
+      case _ => false
+    }
+
+    expressions.forall {
+      case t: TransformExpression if isSupportedTransform(t) => true
+      case e: Expression if isReference(e) => true
+      case _ => false
+    }
+  }
 }
 
 /**
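The new `KeyGroupedPartitioning.supportsExpressions` helper above decides which clustering expressions a scan may report as key-grouped: plain (possibly nested) field references, and transforms over exactly one reference. A sketch of the contract, reusing this patch's test-only `BucketFunction`/`TruncateFunction` objects purely for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, TruncateFunction}
import org.apache.spark.sql.types.{IntegerType, StringType}

val id = AttributeReference("id", IntegerType)()
val data = AttributeReference("data", StringType)()

// bucket(8, id): a single reference argument, so key-grouped reporting is allowed.
KeyGroupedPartitioning.supportsExpressions(
  Seq(TransformExpression(BucketFunction, Seq(id), Some(8))))          // true

// truncate(data, 2): two arguments, one a literal, so it is not (yet) allowed.
KeyGroupedPartitioning.supportsExpressions(
  Seq(TransformExpression(TruncateFunction, Seq(data, Literal(2)))))   // false
```

This check is what makes `DataSourceV2ScanExecBase` below fall back to `UnknownPartitioning` for the new multi-arg transforms instead of reporting a `KeyGroupedPartitioning` the rest of planning cannot use.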
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index f139399ed76..7da6c1480e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -83,6 +83,7 @@ abstract class InMemoryBaseTable(
     case _: HoursTransform =>
     case _: BucketTransform =>
     case _: SortedBucketTransform =>
+    case NamedTransform("truncate", Seq(_: NamedReference, _: Literal[_])) =>
     case t if !allowUnsupportedTransforms =>
       throw new IllegalArgumentException(s"Transform $t is not a supported transform")
   }
@@ -177,6 +178,13 @@ abstract class InMemoryBaseTable(
         var dataTypeHashCode = 0
         valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
         ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
+      case NamedTransform("truncate", Seq(ref: NamedReference, length: Literal[_])) =>
+        extractor(ref.fieldNames, cleanedSchema, row) match {
+          case (str: UTF8String, StringType) =>
+            str.substring(0, length.value.asInstanceOf[Int])
+          case (v, t) =>
+            throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+        }
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index e6d7cddc71b..fa4ae171df5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -91,11 +91,18 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
   }
 
   override def outputPartitioning: physical.Partitioning = {
-    if (partitions.length == 1) SinglePartition
-    else groupedPartitions.map { partitionValues =>
-      KeyGroupedPartitioning(keyGroupedPartitioning.get,
-        partitionValues.size, Some(partitionValues.map(_._1)))
-    }.getOrElse(super.outputPartitioning)
+    if (partitions.length == 1) {
+      SinglePartition
+    } else {
+      keyGroupedPartitioning match {
+        case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
+          groupedPartitions.map { partitionValues =>
+            KeyGroupedPartitioning(exprs, partitionValues.size, Some(partitionValues.map(_._1)))
+          }.getOrElse(super.outputPartitioning)
+        case _ =>
+          super.outputPartitioning
+      }
+    }
   }
 
   @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] =
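One note before the suite changes below: v2 functions named in a distribution or ordering are resolved through the table's `FunctionCatalog`, so both suites register the new unbound function up front. A minimal sketch of that registration, assuming `catalog` is the suites' in-memory test catalog (which supports `createFunction`):

```scala
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.functions.UnboundTruncateFunction

// Make "truncate" resolvable; without this, loadV2FunctionOpt finds nothing
// and the requested distribution/ordering cannot be bound.
catalog.createFunction(
  Identifier.of(Array.empty, UnboundTruncateFunction.name()), UnboundTruncateFunction)
```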
val df = sql(s"SELECT * FROM testcat.ns.$table") + val distribution = physical.ClusteredDistribution( + Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2))))) + + checkQueryPlan(df, distribution, physical.UnknownPartitioning(0)) + } + /** * Check whether the query plan from `df` has the expected `distribution`, `ordering` and * `partitioning`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index 7966add7738..b262e405d4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, UnboundBucketFunction, UnboundStringSelfFunction} +import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, TruncateFunction, UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.expressions.LogicalExpressions._ @@ -45,7 +45,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase import testImplicits._ before { - Seq(UnboundBucketFunction, UnboundStringSelfFunction).foreach { f => + Seq(UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction).foreach { f => catalog.createFunction(Identifier.of(Array.empty, f.name()), f) } } @@ -1041,19 +1041,36 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase distributionStrictlyRequired: Boolean = true, dataSkewed: Boolean = false, coalesce: Boolean = false): Unit = { + + val stringSelfTransform = ApplyTransform( + "string_self", + Seq(FieldReference("data"))) + val truncateTransform = ApplyTransform( + "truncate", + Seq(stringSelfTransform, LiteralValue(2, IntegerType))) + val tableOrdering = Array[SortOrder]( - sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort( + stringSelfTransform, + SortDirection.DESCENDING, + NullOrdering.NULLS_FIRST), sort( BucketTransform(LiteralValue(10, IntegerType), Seq(FieldReference("id"))), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) ) - val tableDistribution = Distributions.clustered(Array( - ApplyTransform("string_self", Seq(FieldReference("data"))))) + val tableDistribution = Distributions.clustered(Array(truncateTransform)) + + val stringSelfExpr = ApplyFunctionExpression( + StringSelfFunction, + Seq(attr("data"))) + val truncateExpr = ApplyFunctionExpression( + TruncateFunction, + Seq(stringSelfExpr, Literal(2))) val writeOrdering = Seq( catalyst.expressions.SortOrder( - attr("data"), + stringSelfExpr, catalyst.expressions.Descending, catalyst.expressions.NullsFirst, Seq.empty @@ -1066,8 +1083,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase ) ) - val writePartitioningExprs = Seq( - ApplyFunctionExpression(StringSelfFunction, 
Seq(attr("data")))) + val writePartitioningExprs = Seq(truncateExpr) val writePartitioning = if (!coalesce) { clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 9277e8d059f..6ea48aff2a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -99,3 +99,22 @@ object StringSelfFunction extends ScalarFunction[UTF8String] { input.getUTF8String(0) } } + +object UnboundTruncateFunction extends UnboundFunction { + override def bind(inputType: StructType): BoundFunction = TruncateFunction + override def description(): String = name() + override def name(): String = "truncate" +} + +object TruncateFunction extends ScalarFunction[UTF8String] { + override def inputTypes(): Array[DataType] = Array(StringType, IntegerType) + override def resultType(): DataType = StringType + override def name(): String = "truncate" + override def canonicalName(): String = name() + override def toString: String = name() + override def produceResult(input: InternalRow): UTF8String = { + val str = input.getUTF8String(0) + val length = input.getInt(1) + str.substring(0, length) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org