This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d1d97604fbe [SPARK-44180][SQL] DistributionAndOrderingUtils should apply ResolveTimeZone
d1d97604fbe is described below

commit d1d97604fbec2fccedbfa52b02eb1f3428b00d9f
Author: Cheng Pan <cheng...@apache.org>
AuthorDate: Fri Jul 14 09:25:07 2023 +0800

[SPARK-44180][SQL] DistributionAndOrderingUtils should apply ResolveTimeZone

### What changes were proposed in this pull request?

Apply `ResolveTimeZone` to the plan generated by `DistributionAndOrderingUtils#prepareQuery`.

### Why are the changes needed?

SPARK-39607 applied only `typeCoercionRules` to the plan generated by `DistributionAndOrderingUtils#prepareQuery`. That is not enough: the following exception is thrown when a `TimeZoneAwareExpression` participates in the implicit cast.

```
23/06/25 07:30:58 WARN UnsafeProjection: Expr codegen error and falling back to interpreter mode
java.util.NoSuchElementException: None.get
    at scala.None$.get(Option.scala:529)
    at scala.None$.get(Option.scala:527)
    at org.apache.spark.sql.catalyst.expressions.TimeZoneAwareExpression.zoneId(datetimeExpressions.scala:63)
    at org.apache.spark.sql.catalyst.expressions.TimeZoneAwareExpression.zoneId$(datetimeExpressions.scala:63)
    at org.apache.spark.sql.catalyst.expressions.Cast.zoneId$lzycompute(Cast.scala:491)
    at org.apache.spark.sql.catalyst.expressions.Cast.zoneId(Cast.scala:491)
    at org.apache.spark.sql.catalyst.expressions.Cast.castToDateCode(Cast.scala:1655)
    at org.apache.spark.sql.catalyst.expressions.Cast.nullSafeCastFunction(Cast.scala:1335)
    at org.apache.spark.sql.catalyst.expressions.Cast.doGenCode(Cast.scala:1316)
    at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:200)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:195)
    at org.apache.spark.sql.catalyst.expressions.Cast.genCode(Cast.scala:1310)
    at org.apache.spark.sql.catalyst.expressions.objects.InvokeLike.$anonfun$prepareArguments$3(objects.scala:124)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at org.apache.spark.sql.catalyst.expressions.objects.InvokeLike.prepareArguments(objects.scala:123)
    at org.apache.spark.sql.catalyst.expressions.objects.InvokeLike.prepareArguments$(objects.scala:91)
    at org.apache.spark.sql.catalyst.expressions.objects.Invoke.prepareArguments(objects.scala:363)
    at org.apache.spark.sql.catalyst.expressions.objects.Invoke.doGenCode(objects.scala:414)
    at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:200)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:195)
    at org.apache.spark.sql.catalyst.expressions.objects.InvokeLike.$anonfun$prepareArguments$3(objects.scala:124)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at org.apache.spark.sql.catalyst.expressions.objects.InvokeLike.prepareArguments(objects.scala:123)
    at org.apache.spark.sql.catalyst.expressions.objects.InvokeLike.prepareArguments$(objects.scala:91)
    at org.apache.spark.sql.catalyst.expressions.objects.Invoke.prepareArguments(objects.scala:363)
    at org.apache.spark.sql.catalyst.expressions.objects.Invoke.doGenCode(objects.scala:414)
    at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:200)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:195)
    at org.apache.spark.sql.catalyst.expressions.HashExpression.$anonfun$doGenCode$5(hash.scala:304)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at org.apache.spark.sql.catalyst.expressions.HashExpression.doGenCode(hash.scala:303)
    at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:200)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:195)
    at org.apache.spark.sql.catalyst.expressions.Pmod.doGenCode(arithmetic.scala:1068)
    at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:200)
    at scala.Option.getOrElse(Option.scala:189)
    at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:195)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.$anonfun$generateExpressions$2(CodeGenerator.scala:1278)
    at scala.collection.immutable.List.map(List.scala:293)
    at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.generateExpressions(CodeGenerator.scala:1278)
    at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.createCode(GenerateUnsafeProjection.scala:290)
    at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.create(GenerateUnsafeProjection.scala:338)
    at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.generate(GenerateUnsafeProjection.scala:327)
    at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createCodeGeneratedObject(Projection.scala:124)
    at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createCodeGeneratedObject(Projection.scala:120)
    at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:51)
    at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:151)
    at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:161)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.getPartitionKeyExtractor$1(ShuffleExchangeExec.scala:316)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13(ShuffleExchangeExec.scala:384)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13$adapted(ShuffleExchangeExec.scala:383)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:875)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:875)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
    at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:101)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
    at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
    at org.apache.spark.scheduler.Task.run(Task.scala:139)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1529)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:750)
```
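
The `None.get` comes from a `Cast` whose `timeZoneId` was never assigned. As a minimal Catalyst-level sketch (for illustration only, not code from this patch; it assumes the session time zone taken from `SQLConf`), `ResolveTimeZone` is the rule that fills that field in:

```
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.{DateType, TimestampType}

// A Cast between date/time types is a TimeZoneAwareExpression. Its timeZoneId
// starts out as None, and touching zoneId in that state fails with the
// None.get shown above.
val day = AttributeReference("day", DateType)()
val cast = Cast(day, TimestampType) // timeZoneId = None at this point

// ResolveTimeZone walks the plan and copies the session time zone into every
// time-zone-aware expression; this is what prepareQuery now relies on.
val plan = Project(Seq(Alias(cast, "ts")()), LocalRelation(day))
val resolved = ResolveTimeZone(plan)
```
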
### Does this PR introduce _any_ user-facing change?

Yes, it's a bug fix.

### How was this patch tested?

New tests are added.

Closes #41725 from pan3793/SPARK-44180.

Authored-by: Cheng Pan <cheng...@apache.org>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../v2/DistributionAndOrderingUtils.scala          | 14 ++--
 .../WriteDistributionAndOrderingSuite.scala        | 90 +++++++++++++++++-----
 .../catalog/functions/transformFunctions.scala     |  6 +-
 3 files changed, 83 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
index 9b1155ef698..36ee01e1c1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, ResolveTimeZone, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort}
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.connector.catalog.FunctionCatalog
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.distributions._
@@ -83,13 +83,17 @@ object DistributionAndOrderingUtils {
         queryWithDistribution
       }
 
-      // Apply typeCoercionRules since the converted expression from TransformExpression
-      // implemented ImplicitCastInputTypes
-      typeCoercionRules.foldLeft(queryWithDistributionAndOrdering)((plan, rule) => rule(plan))
+      TypeCoercionExecutor.execute(queryWithDistributionAndOrdering)
     case _ =>
       query
   }
 
+  private object TypeCoercionExecutor extends RuleExecutor[LogicalPlan] {
+    override val batches =
+      Batch("Resolve TypeCoercion", FixedPoint(1), typeCoercionRules: _*) ::
+      Batch("Resolve TimeZone", FixedPoint(1), ResolveTimeZone) :: Nil
+  }
+
   private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
     case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
       V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 881e077514f..6cab0e0239d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.sql.connector
 
+
+import java.sql.Date
 import java.util.Collections
 
 import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
-import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, TruncateFunction, UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction}
+import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
 import org.apache.spark.sql.connector.expressions.LogicalExpressions._
@@ -37,8 +40,7 @@ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
-import org.apache.spark.sql.test.SQLTestData.TestData
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types.{DateType, IntegerType, LongType, ObjectType, StringType, StructType, TimestampType}
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.tags.SlowSQLTest
 
@@ -47,7 +49,11 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
   import testImplicits._
 
   before {
-    Seq(UnboundBucketFunction, UnboundStringSelfFunction, UnboundTruncateFunction).foreach { f =>
+    Seq(
+      UnboundYearsFunction,
+      UnboundBucketFunction,
+      UnboundStringSelfFunction,
+      UnboundTruncateFunction).foreach { f =>
       catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
     }
   }
@@ -66,6 +72,7 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
   private val schema = new StructType()
     .add("id", IntegerType)
     .add("data", StringType)
+    .add("day", DateType)
 
   test("ordered distribution and sort with same exprs: append") {
     checkOrderedDistributionAndSortWithSameExprsInVariousCases("append")
@@ -985,8 +992,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
     catalog.createTable(ident, schema, Array.empty, emptyProps, distribution, ordering, None, None)
 
     withTempDir { checkpointDir =>
-      val inputData = ContinuousMemoryStream[(Long, String)]
-      val inputDF = inputData.toDF().toDF("id", "data")
+      val inputData = ContinuousMemoryStream[(Long, String, Date)]
+      val inputDF = inputData.toDF().toDF("id", "data", "day")
 
       val writer = inputDF
         .writeStream
@@ -997,7 +1004,9 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
      val analysisException = intercept[AnalysisException] {
        val query = writer.toTable(tableNameAsString)
 
-        inputData.addData((1, "a"), (2, "b"))
+        inputData.addData(
+          (1, "a", Date.valueOf("2021-01-01")),
+          (2, "b", Date.valueOf("2022-02-02")))
 
        query.processAllAvailable()
        query.stop()
@@ -1011,8 +1020,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
     catalog.createTable(ident, schema, Array.empty[Transform], emptyProps)
 
     withTempDir { checkpointDir =>
-      val inputData = ContinuousMemoryStream[(Long, String)]
-      val inputDF = inputData.toDF().toDF("id", "data")
+      val inputData = ContinuousMemoryStream[(Long, String, Date)]
+      val inputDF = inputData.toDF().toDF("id", "data", "day")
 
       val writer = inputDF
         .writeStream
@@ -1022,12 +1031,17 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
 
       val query = writer.toTable(tableNameAsString)
 
-      inputData.addData((1, "a"), (2, "b"))
+      inputData.addData(
+        (1, "a", Date.valueOf("2021-01-01")),
+        (2, "b", Date.valueOf("2022-02-02")))
 
       query.processAllAvailable()
       query.stop()
 
-      checkAnswer(spark.table(tableNameAsString), Row(1, "a") :: Row(2, "b") :: Nil)
+      checkAnswer(
+        spark.table(tableNameAsString),
+        Row(1, "a", Date.valueOf("2021-01-01")) ::
+        Row(2, "b", Date.valueOf("2022-02-02")) :: Nil)
     }
   }
 
@@ -1085,6 +1099,9 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
     val truncateTransform = ApplyTransform(
       "truncate",
       Seq(stringSelfTransform, LiteralValue(2, IntegerType)))
+    val yearsTransform = ApplyTransform(
+      "years",
+      Seq(FieldReference("day")))
 
     val tableOrdering = Array[SortOrder](
       sort(
@@ -1094,6 +1111,10 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       sort(
         BucketTransform(LiteralValue(10, IntegerType), Seq(FieldReference("id"))),
         SortDirection.DESCENDING,
+        NullOrdering.NULLS_FIRST),
+      sort(
+        yearsTransform,
+        SortDirection.DESCENDING,
         NullOrdering.NULLS_FIRST)
     )
     val tableDistribution = Distributions.clustered(Array(truncateTransform))
@@ -1117,6 +1138,18 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
         catalyst.expressions.Descending,
         catalyst.expressions.NullsFirst,
         Seq.empty
+      ),
+      catalyst.expressions.SortOrder(
+        Invoke(
+          Literal.create(YearsFunction, ObjectType(YearsFunction.getClass)),
+          "invoke",
+          LongType,
+          Seq(Cast(attr("day"), TimestampType, Some("America/Los_Angeles"))),
+          Seq(TimestampType),
+          propagateNull = false),
+        catalyst.expressions.Descending,
+        catalyst.expressions.NullsFirst,
+        Seq.empty
       )
     )
 
@@ -1204,11 +1237,17 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       distributionStrictlyRequired)
 
     val df = if (!dataSkewed) {
-      spark.createDataFrame(Seq((1, "a"), (2, "b"), (3, "c"))).toDF("id", "data")
+      spark.createDataFrame(Seq(
+        (1, "a", Date.valueOf("2021-01-01")),
+        (2, "b", Date.valueOf("2022-02-02")),
+        (3, "c", Date.valueOf("2023-03-03")))
+      ).toDF("id", "data", "day")
     } else {
       spark.sparkContext.parallelize(
-        (1 to 10).map(i => TestData(if (i > 4) 5 else i, i.toString)), 3)
-        .toDF("id", "data")
+        (1 to 10).map {
+          i => (if (i > 4) 5 else i, i.toString, Date.valueOf(s"${2020 + i}-$i-$i"))
+        }, 3)
+        .toDF("id", "data", "day")
     }
 
     val writer = writeTransform(df).writeTo(tableNameAsString)
@@ -1300,8 +1339,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       tableOrdering, tableNumPartitions, tablePartitionSize)
 
     withTempDir { checkpointDir =>
-      val inputData = MemoryStream[(Long, String)]
-      val inputDF = inputData.toDF().toDF("id", "data")
+      val inputData = MemoryStream[(Long, String, Date)]
+      val inputDF = inputData.toDF().toDF("id", "data", "day")
 
       val queryDF = outputMode match {
         case "append" | "update" =>
@@ -1310,8 +1349,11 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
           // add an aggregate for complete mode
           inputDF
             .groupBy("id")
-            .agg(Map("data" -> "count"))
-            .select($"id", $"count(data)".cast("string").as("data"))
+            .agg(Map("data" -> "count", "day" -> "max"))
+            .select(
+              $"id",
+              $"count(data)".cast("string").as("data"),
+              $"max(day)".cast("date").as("day"))
       }
 
       val writer = writeTransform(queryDF)
@@ -1322,7 +1364,9 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       def executeCommand(): SparkPlan = execute {
         val query = writer.toTable(tableNameAsString)
 
-        inputData.addData((1, "a"), (2, "b"))
+        inputData.addData(
+          (1, "a", Date.valueOf("2021-01-01")),
+          (2, "b", Date.valueOf("2022-02-02")))
 
         query.processAllAvailable()
         query.stop()
@@ -1346,8 +1390,12 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
         maxNumShuffles = if (outputMode != "complete") 1 else 2)
 
       val expectedRows = outputMode match {
-        case "append" | "update" => Row(1, "a") :: Row(2, "b") :: Nil
-        case "complete" => Row(1, "1") :: Row(2, "1") :: Nil
+        case "append" | "update" =>
+          Row(1, "a", Date.valueOf("2021-01-01")) ::
+          Row(2, "b", Date.valueOf("2022-02-02")) :: Nil
+        case "complete" =>
+          Row(1, "1", Date.valueOf("2021-01-01")) ::
+          Row(2, "1", Date.valueOf("2022-02-02")) :: Nil
       }
       checkAnswer(spark.table(tableNameAsString), expectedRows)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 6ea48aff2a2..61895d49c4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.connector.catalog.functions
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -36,11 +38,13 @@ object UnboundYearsFunction extends UnboundFunction {
   override def name(): String = "years"
 }
 
-object YearsFunction extends BoundFunction {
+object YearsFunction extends ScalarFunction[Long] {
   override def inputTypes(): Array[DataType] = Array(TimestampType)
   override def resultType(): DataType = LongType
   override def name(): String = "years"
   override def canonicalName(): String = name()
+
+  def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900
 }
 
 object DaysFunction extends BoundFunction {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org