[ https://issues.apache.org/jira/browse/SPARK-49261?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17875657#comment-17875657 ]
Bruce Robbins commented on SPARK-49261:
---------------------------------------

{quote}It seems to be a correlation between F.lit(6).alias("run_number") and F.round(F.col("total_amount") / 1000, 6). If both lit and scale in round are set to the same number i.e. 6 code fails.{quote}

That's a good summary of the issue. The bug seems to be [here|https://github.com/apache/spark/blob/a885365897acefcf353206aaabd0048e088cc9a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala#L409]. That code replaces both foldable and non-foldable expressions with the corresponding group-by attributes, but I think it should replace only non-foldable expressions. In the case of the round function, it patches the second parameter, which must be a foldable expression, with a non-foldable one. As a result, {{RoundBase#checkInputDataTypes}} fails.

> Correlation between lit and round during grouping
> -------------------------------------------------
>
>                 Key: SPARK-49261
>                 URL: https://issues.apache.org/jira/browse/SPARK-49261
>             Project: Spark
>          Issue Type: Bug
>          Components: PySpark
>    Affects Versions: 3.5.0
>         Environment: Databricks DBR 14.3
> Spark 3.5.0
> Scala 2.12
>            Reporter: Krystian Kulig
>            Priority: Major
>             Fix For: 3.5.0
>
>
> Running the following code:
>
> {code:java}
> import pyspark.sql.functions as F
> from decimal import Decimal
> data = [
>     (1, 100, Decimal("1.1"), "L", True),
>     (2, 200, Decimal("1.2"), "H", False),
>     (2, 300, Decimal("2.345"), "E", False),
> ]
> columns = ["group_a", "id", "amount", "selector_a", "selector_b"]
> df = spark.createDataFrame(data, schema=columns)
> df_final = (
>     df.select(
>         F.lit(6).alias("run_number"),
>         F.lit("AA").alias("run_type"),
>         F.col("group_a"),
>         F.col("id"),
>         F.col("amount"),
>         F.col("selector_a"),
>         F.col("selector_b"),
>     )
>     .withColumn(
>         "amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("amount"),
>         ).otherwise(F.lit(None))
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("id")
>         ).otherwise(F.lit(None))
>     )
> )
> group_by_cols = [
>     "run_number",
>     "group_a",
>     "run_type"
> ]
> df_final = df_final.groupBy(group_by_cols).agg(
>     F.countDistinct("id").alias("count_of_amount"),
>     F.round(F.sum("amount") / 1000, 1).alias("total_amount"),
>     F.sum("amount_c").alias("amount_c"),
>     F.countDistinct("count_of_amount_c").alias(
>         "count_of_amount_c"
>     ),
> )
> df_final = (
>     df_final
>     .withColumn(
>         "total_amount",
>         F.round(F.col("total_amount") / 1000, 6),
>     )
>     .withColumn(
>         "count_of_amount", F.col("count_of_amount").cast("int")
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             F.col("amount_c").isNull(), F.lit(None).cast("int")
>         ).otherwise(F.col("count_of_amount_c").cast("int")),
>     )
> )
> df_final = df_final.select(
>     F.col("total_amount"),
>     "run_number",
>     "group_a",
>     "run_type",
>     "count_of_amount",
>     "amount_c",
>     "count_of_amount_c",
> )
> df_final.show()
> {code}
> Produces error:
> {code:java}
> [[INTERNAL_ERROR](https://docs.microsoft.com/azure/databricks/error-messages/error-classes#internal_error)]
> Couldn't find total_amount#1046 in
> [group_a#984L,count_of_amount#1054,amount_c#1033,count_of_amount_c#1034L]
> SQLSTATE: XX000
> {code}
> With stack trace:
> {code:java}
> org.apache.spark.SparkException: [INTERNAL_ERROR] Couldn't find total_amount#1046 in
[group_a#984L,count_of_amount#1054,amount_c#1033,count_of_amount_c#1034L] > SQLSTATE: XX000 at > org.apache.spark.SparkException$.internalError(SparkException.scala:97) at > org.apache.spark.SparkException$.internalError(SparkException.scala:101) at > org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:81) > at > org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:74) > at > org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:505) > at > org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:83) > at > org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:505) > at > org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:481) > at > org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:449) at > org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:74) > at > org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:97) > at > scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286) at > scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at > scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at > scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at > scala.collection.TraversableLike.map(TraversableLike.scala:286) at > scala.collection.TraversableLike.map$(TraversableLike.scala:279) at > scala.collection.AbstractTraversable.map(Traversable.scala:108) at > org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:97) > at > org.apache.spark.sql.execution.ProjectExec.doConsume(basicPhysicalOperators.scala:74) > at > org.apache.spark.sql.execution.CodegenSupport.consume(WholeStageCodegenExec.scala:202) > at > org.apache.spark.sql.execution.CodegenSupport.consume$(WholeStageCodegenExec.scala:155) > at > org.apache.spark.sql.execution.aggregate.HashAggregateExec.consume(HashAggregateExec.scala:51) > at > org.apache.spark.sql.execution.aggregate.HashAggregateExec.generateResultFunction(HashAggregateExec.scala:411) > at > org.apache.spark.sql.execution.aggregate.HashAggregateExec.doConsumeWithKeys(HashAggregateExec.scala:995) > at > org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doConsume(AggregateCodegenSupport.scala:81) > at > org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doConsume$(AggregateCodegenSupport.scala:77) > at > org.apache.spark.sql.execution.aggregate.HashAggregateExec.doConsume(HashAggregateExec.scala:51) > at > org.apache.spark.sql.execution.CodegenSupport.constructDoConsumeFunction(WholeStageCodegenExec.scala:229) > at > org.apache.spark.sql.execution.CodegenSupport.consume(WholeStageCodegenExec.scala:200) > at > org.apache.spark.sql.execution.CodegenSupport.consume$(WholeStageCodegenExec.scala:155) > at > org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:506) > at > org.apache.spark.sql.execution.InputRDDCodegen.doProduce(WholeStageCodegenExec.scala:493) > at > org.apache.spark.sql.execution.InputRDDCodegen.doProduce$(WholeStageCodegenExec.scala:466) > at > org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:506) > at > org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100) > at > 
org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130) > at > org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385) > at > org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165) > at > org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at > org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95) > at > org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94) > at > org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:506) > at > org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduceWithKeys(HashAggregateExec.scala:629) > at > org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doProduce(AggregateCodegenSupport.scala:73) > at > org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doProduce$(AggregateCodegenSupport.scala:69) > at > org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduce(HashAggregateExec.scala:51) > at > org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100) > at > org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130) > at > org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385) > at > org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165) > at > org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at > org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95) > at > org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94) > at > org.apache.spark.sql.execution.aggregate.HashAggregateExec.produce(HashAggregateExec.scala:51) > at > org.apache.spark.sql.execution.ProjectExec.doProduce(basicPhysicalOperators.scala:59) > at > org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100) > at > org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130) > at > org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385) > at > org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165) > at > org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at > org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95) > at > org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94) > at > org.apache.spark.sql.execution.ProjectExec.produce(basicPhysicalOperators.scala:46) > at > org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:666) > at > org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:729) > at > org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$2(SparkPlan.scala:327) > at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94) > at > org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:327) > at > org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130) > at > org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385) > at > org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165) > 
at > org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381) at > org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:322) at > org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:117) > at > org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:131) > at > org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:94) > at > org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:90) > at > org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:78) > at > org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$1(ResultCacheManager.scala:549) > at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94) > at > org.apache.spark.sql.execution.qrc.ResultCacheManager.collectResult$1(ResultCacheManager.scala:540) > at > org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$2(ResultCacheManager.scala:555) > at > org.apache.spark.sql.execution.adaptive.ResultQueryStageExec.$anonfun$doMaterialize$1(QueryStageExec.scala:663) > at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1175) at > org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$6(SQLExecution.scala:778) > at > com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63) > at > org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$5(SQLExecution.scala:778) > at > com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63) > at > org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$4(SQLExecution.scala:778) > at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62) at > org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$3(SQLExecution.scala:777) > at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62) at > org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$2(SQLExecution.scala:776) > at > org.apache.spark.sql.execution.SQLExecution$.withOptimisticTransaction(SQLExecution.scala:798) > at > org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:775) > at > java.util.concurrent.CompletableFuture$AsyncSupply.run(CompletableFuture.java:1604) > at > org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.$anonfun$run$1(SparkThreadLocalForwardingThreadPoolExecutor.scala:134) > at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23) at > com.databricks.spark.util.IdentityClaim$.withClaim(IdentityClaim.scala:48) at > org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.$anonfun$runWithCaptured$4(SparkThreadLocalForwardingThreadPoolExecutor.scala:91) > at > com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:45) > at > org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:90) > at > org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured$(SparkThreadLocalForwardingThreadPoolExecutor.scala:67) > at > org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:131) > at > org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.run(SparkThreadLocalForwardingThreadPoolExecutor.scala:134) > at > 
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) > at > java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) > at java.lang.Thread.run(Thread.java:750)
> {code}
>
> It seems to be a correlation between *F.lit(6).alias("run_number")* and {*}F.round(F.col("total_amount") / 1000, 6){*}. If both *lit* and the *scale* in *round* are set to the same number, i.e. *6*, the code fails.
> If the numbers are different, everything works.
> Moving *F.lit(6).alias("run_number")* to the final *select* also solves the problem when the numbers in *lit* and the *scale* in *round* are the same.
> Example of the working code:
> {code:java}
> import pyspark.sql.functions as F
> from decimal import Decimal
> data = [
>     (1, 100, Decimal("1.1"), "L", True),
>     (2, 200, Decimal("1.2"), "H", False),
>     (2, 300, Decimal("2.345"), "E", False),
> ]
> columns = ["group_a", "id", "amount", "selector_a", "selector_b"]
> df = spark.createDataFrame(data, schema=columns)
> df_final = (
>     df.select(
>         F.lit(7).alias("run_number"),
>         F.lit("AA").alias("run_type"),
>         F.col("group_a"),
>         F.col("id"),
>         F.col("amount"),
>         F.col("selector_a"),
>         F.col("selector_b"),
>     )
>     .withColumn(
>         "amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("amount"),
>         ).otherwise(F.lit(None))
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             (F.col("selector_b") == False)
>             & (F.col("selector_a").isin(["L", "H", "E"])),
>             F.col("id")
>         ).otherwise(F.lit(None))
>     )
> )
> group_by_cols = [
>     "run_number",
>     "group_a",
>     "run_type"
> ]
> df_final = df_final.groupBy(group_by_cols).agg(
>     F.countDistinct("id").alias("count_of_amount"),
>     F.round(F.sum("amount") / 1000, 1).alias("total_amount"),
>     F.sum("amount_c").alias("amount_c"),
>     F.countDistinct("count_of_amount_c").alias(
>         "count_of_amount_c"
>     ),
> )
> df_final = (
>     df_final
>     .withColumn(
>         "total_amount",
>         F.round(F.col("total_amount") / 1000, 6),
>     )
>     .withColumn(
>         "count_of_amount", F.col("count_of_amount").cast("int")
>     )
>     .withColumn(
>         "count_of_amount_c",
>         F.when(
>             F.col("amount_c").isNull(), F.lit(None).cast("int")
>         ).otherwise(F.col("count_of_amount_c").cast("int")),
>     )
> )
> df_final = df_final.select(
>     F.col("total_amount"),
>     "run_number",
>     "group_a",
>     "run_type",
>     "count_of_amount",
>     "amount_c",
>     "count_of_amount_c",
> )
> df_final.show()
> {code}
> Output:
> {code:java}
> +------------+----------+-------+--------+---------------+--------------------+-----------------+
> |total_amount|run_number|group_a|run_type|count_of_amount|            amount_c|count_of_amount_c|
> +------------+----------+-------+--------+---------------+--------------------+-----------------+
> |    0.000000|         7|      2|      AA|              2|3.545000000000000000|                2|
> |    0.000000|         7|      1|      AA|              1|                NULL|             NULL|
> +------------+----------+-------+--------+---------------+--------------------+-----------------+
> {code}
> Expected behavior:
> Values used in the *lit* function shouldn't interfere with the *scale* parameter of the *round* function.
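
For reference, the trigger described in the comment above can be reduced to a much smaller sketch: two distinct aggregates over different columns force the Expand-based rewrite in RewriteDistinctAggregates, the grouping key includes a literal 6, and a later round uses the same constant 6 as its scale, which round requires to stay foldable. The code below is an untested, illustrative reduction of the reporter's example, not a verified minimal reproduction; the data, column names, and session setup are assumptions.

{code:python}
# Untested, condensed sketch of the reported trigger; data and column names are illustrative.
from decimal import Decimal

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        (1, 100, Decimal("1.1")),
        (2, 200, Decimal("1.2")),
        (2, 300, Decimal("2.345")),
    ],
    ["group_a", "id", "amount"],
)

result = (
    df.groupBy(F.lit(6).alias("run_number"), "group_a")  # literal 6 used as a grouping key
    .agg(
        # Two distinct aggregates over different columns make the optimizer apply the
        # Expand-based rewrite in RewriteDistinctAggregates.
        F.countDistinct("id").alias("count_of_id"),
        F.countDistinct("amount").alias("distinct_amounts"),
        F.sum("amount").alias("total_amount"),
    )
    # The round scale (6) is the same constant as the lit(6) grouping key above.
    .withColumn("total_amount", F.round(F.col("total_amount") / 1000, 6))
)

# Per the report, this pattern fails with an INTERNAL_ERROR when the two constants
# match, and works when they differ.
result.show()
{code}

Consistent with the report, changing either constant (for example lit(7), or a scale of 5) or moving the literal into the final select avoids the error.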