This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5db87787d5c [SPARK-43781][SQL] Fix IllegalStateException when cogrouping two datasets derived from the same source 5db87787d5c is described below commit 5db87787d5cc1cefb51ec77e49bac7afaa46d300 Author: Jia Fan <fanjiaemi...@qq.com> AuthorDate: Thu Aug 10 23:19:05 2023 +0800 [SPARK-43781][SQL] Fix IllegalStateException when cogrouping two datasets derived from the same source ### What changes were proposed in this pull request? When cogrouping two datasets derived from the same source, e.g.: ```scala val inputType = StructType(Array(StructField("id", LongType, false), StructField("type", StringType, false))) val keyType = StructType(Array(StructField("id", LongType, false))) val inputRows = new java.util.ArrayList[Row]() inputRows.add(Row(1L, "foo")) inputRows.add(Row(1L, "bar")) inputRows.add(Row(2L, "foo")) val input = spark.createDataFrame(inputRows, inputType) val fooGroups = input.filter("type = 'foo'").groupBy("id").as(RowEncoder(keyType), RowEncoder(inputType)) val barGroups = input.filter("type = 'bar'").groupBy("id").as(RowEncoder(keyType), RowEncoder(inputType)) val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq }(RowEncoder(inputType)).collect() ``` The following error is reported: ``` 21:03:27.651 ERROR org.apache.spark.executor.Executor: Exception in task 1.0 in stage 0.0 (TID 1) java.lang.IllegalStateException: Couldn't find id#19L in [id#0L,type#1] at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73) ... ``` The reason is that `DeduplicateRelations` rewrites `LocalRelation` but can't rewrite `left(right)Group` and `left(right)Attr` in `CoGroup`. In fact, `Join` would face the same situation. 
But `Join` regenerates the plan when invoking itself to avoid this situation. Please refer to https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L1089 This PR lets `DeduplicateRelations` handle the `CoGroup` case. ### Why are the changes needed? Fix IllegalStateException when cogrouping two datasets derived from the same source ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Add new test Closes #41554 from Hisoka-X/SPARK-43781_cogrouping_two_datasets. Authored-by: Jia Fan <fanjiaemi...@qq.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../catalyst/analysis/DeduplicateRelations.scala | 39 ++++++++++++++++++++-- .../scala/org/apache/spark/sql/DatasetSuite.scala | 26 +++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index a8f2765b1c4..56ce3765836 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeSet, NamedExpression, OuterReference, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, OuterReference, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -228,7 +228,42 @@ object DeduplicateRelations extends Rule[LogicalPlan] { if (attrMap.isEmpty) { planWithNewChildren } else { - 
planWithNewChildren.rewriteAttrs(attrMap) + def rewriteAttrs[T <: Expression]( + exprs: Seq[T], + attrMap: Map[Attribute, Attribute]): Seq[T] = { + exprs.map { expr => + expr.transformWithPruning(_.containsPattern(ATTRIBUTE_REFERENCE)) { + case a: AttributeReference => attrMap.getOrElse(a, a) + }.asInstanceOf[T] + } + } + + planWithNewChildren match { + // TODO (SPARK-44754): we should handle all special cases here. + case c: CoGroup => + // SPARK-43781: CoGroup is a special case, `rewriteAttrs` will incorrectly update + // some fields that do not need to be updated. We need to update the output + // attributes of CoGroup manually. + val leftAttrMap = attrMap.filter(a => c.left.output.contains(a._2)) + val rightAttrMap = attrMap.filter(a => c.right.output.contains(a._2)) + val newLeftAttr = rewriteAttrs(c.leftAttr, leftAttrMap) + val newRightAttr = rewriteAttrs(c.rightAttr, rightAttrMap) + val newLeftGroup = rewriteAttrs(c.leftGroup, leftAttrMap) + val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap) + val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap) + val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap) + val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer] + .copy(inputAttributes = newLeftGroup) + val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer] + .copy(inputAttributes = newLeftAttr) + val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer] + .copy(inputAttributes = newRightAttr) + c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes, + rightDeserializer = newRightDes, leftGroup = newLeftGroup, + rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr, + leftOrder = newLeftOrder, rightOrder = newRightOrder) + case _ => planWithNewChildren.rewriteAttrs(attrMap) + } } } else { planWithNewSubquery.withNewChildren(newChildren.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c967540541a..e05b545f235 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -916,6 +916,32 @@ class DatasetSuite extends QueryTest } } + test("SPARK-43781: cogroup two datasets derived from the same source") { + val inputType = StructType(Array(StructField("id", LongType, false), + StructField("type", StringType, false))) + val keyType = StructType(Array(StructField("id", LongType, false))) + + val inputRows = new java.util.ArrayList[Row]() + inputRows.add(Row(1L, "foo")) + inputRows.add(Row(1L, "bar")) + inputRows.add(Row(2L, "foo")) + val input = spark.createDataFrame(inputRows, inputType) + val fooGroups = input.filter("type = 'foo'").groupBy("id").as(ExpressionEncoder(keyType), + ExpressionEncoder(inputType)) + val barGroups = input.filter("type = 'bar'").groupBy("id").as(ExpressionEncoder(keyType), + ExpressionEncoder(inputType)) + + val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) => + iterator.toSeq ++ iterator1.toSeq + }(ExpressionEncoder(inputType)).collect() + assert(result.length == 3) + + val result2 = fooGroups.cogroupSorted(barGroups)($"id")($"id") { + case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq + }(ExpressionEncoder(inputType)).collect() + assert(result2.length == 3) + } + test("SPARK-34806: observation on datasets") { val namedObservation = Observation("named") val unnamedObservation = Observation() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org