This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5db87787d5c [SPARK-43781][SQL] Fix IllegalStateException when 
cogrouping two datasets derived from the same source
5db87787d5c is described below

commit 5db87787d5cc1cefb51ec77e49bac7afaa46d300
Author: Jia Fan <fanjiaemi...@qq.com>
AuthorDate: Thu Aug 10 23:19:05 2023 +0800

    [SPARK-43781][SQL] Fix IllegalStateException when cogrouping two datasets 
derived from the same source
    
    ### What changes were proposed in this pull request?
    When cogroup two datasets derived from same source, eg:
    ```scala
    val inputType = StructType(Array(StructField("id", LongType, false),
       StructField("type", StringType, false)))
    val keyType = StructType(Array(StructField("id", LongType, false)))
    
    val inputRows = new java.util.ArrayList[Row]()
    inputRows.add(Row(1L, "foo"))
    inputRows.add(Row(1L, "bar"))
    inputRows.add(Row(2L, "foo"))
    val input = spark.createDataFrame(inputRows, inputType)
    val fooGroups = input.filter("type = 
'foo'").groupBy("id").as(RowEncoder(keyType),
       RowEncoder(inputType))
    val barGroups = input.filter("type = 
'bar'").groupBy("id").as(RowEncoder(keyType),
       RowEncoder(inputType))
    
    val result = fooGroups.cogroup(barGroups) { case (row, iterator, iterator1) 
=>
       iterator.toSeq ++ iterator1.toSeq
    }(RowEncoder(inputType)).collect()
    ```
    The following error will be reported:
    ```
    21:03:27.651 ERROR org.apache.spark.executor.Executor: Exception in task 
1.0 in stage 0.0 (TID 1)
    java.lang.IllegalStateException: Couldn't find id#19L in [id#0L,type#1]
            at 
org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
            at 
org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
            ...
    ```
    The reason is that `DeduplicateRelations` rewrites `LocalRelation` but can't 
rewrite `left(right)Group` and `left(right)Attr` in `CoGroup`. In fact, `Join` 
faces the same situation, but `Join` regenerates the plan when invoked, 
to avoid this situation. Please refer to 
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L1089
    
    This PR lets `DeduplicateRelations` handle the `CoGroup` case
    
    ### Why are the changes needed?
    Fix IllegalStateException when cogrouping two datasets derived from the 
same source
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Add new test
    
    Closes #41554 from Hisoka-X/SPARK-43781_cogrouping_two_datasets.
    
    Authored-by: Jia Fan <fanjiaemi...@qq.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/analysis/DeduplicateRelations.scala   | 39 ++++++++++++++++++++--
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 26 +++++++++++++++
 2 files changed, 63 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index a8f2765b1c4..56ce3765836 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, 
AttributeSet, NamedExpression, OuterReference, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, 
OuterReference, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -228,7 +228,42 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
           if (attrMap.isEmpty) {
             planWithNewChildren
           } else {
-            planWithNewChildren.rewriteAttrs(attrMap)
+            def rewriteAttrs[T <: Expression](
+                exprs: Seq[T],
+                attrMap: Map[Attribute, Attribute]): Seq[T] = {
+              exprs.map { expr =>
+                
expr.transformWithPruning(_.containsPattern(ATTRIBUTE_REFERENCE)) {
+                  case a: AttributeReference => attrMap.getOrElse(a, a)
+                }.asInstanceOf[T]
+              }
+            }
+
+            planWithNewChildren match {
+              // TODO (SPARK-44754): we should handle all special cases here.
+              case c: CoGroup =>
+                // SPARK-43781: CoGroup is a special case, `rewriteAttrs` will 
incorrectly update
+                // some fields that do not need to be updated. We need to 
update the output
+                // attributes of CoGroup manually.
+                val leftAttrMap = attrMap.filter(a => 
c.left.output.contains(a._2))
+                val rightAttrMap = attrMap.filter(a => 
c.right.output.contains(a._2))
+                val newLeftAttr = rewriteAttrs(c.leftAttr, leftAttrMap)
+                val newRightAttr = rewriteAttrs(c.rightAttr, rightAttrMap)
+                val newLeftGroup = rewriteAttrs(c.leftGroup, leftAttrMap)
+                val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap)
+                val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap)
+                val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap)
+                val newKeyDes = 
c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
+                  .copy(inputAttributes = newLeftGroup)
+                val newLeftDes = 
c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
+                  .copy(inputAttributes = newLeftAttr)
+                val newRightDes = 
c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
+                  .copy(inputAttributes = newRightAttr)
+                c.copy(keyDeserializer = newKeyDes, leftDeserializer = 
newLeftDes,
+                  rightDeserializer = newRightDes, leftGroup = newLeftGroup,
+                  rightGroup = newRightGroup, leftAttr = newLeftAttr, 
rightAttr = newRightAttr,
+                  leftOrder = newLeftOrder, rightOrder = newRightOrder)
+              case _ => planWithNewChildren.rewriteAttrs(attrMap)
+            }
           }
         } else {
           planWithNewSubquery.withNewChildren(newChildren.toSeq)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c967540541a..e05b545f235 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -916,6 +916,32 @@ class DatasetSuite extends QueryTest
     }
   }
 
+  test("SPARK-43781: cogroup two datasets derived from the same source") {
+    val inputType = StructType(Array(StructField("id", LongType, false),
+      StructField("type", StringType, false)))
+    val keyType = StructType(Array(StructField("id", LongType, false)))
+
+    val inputRows = new java.util.ArrayList[Row]()
+    inputRows.add(Row(1L, "foo"))
+    inputRows.add(Row(1L, "bar"))
+    inputRows.add(Row(2L, "foo"))
+    val input = spark.createDataFrame(inputRows, inputType)
+    val fooGroups = input.filter("type = 
'foo'").groupBy("id").as(ExpressionEncoder(keyType),
+      ExpressionEncoder(inputType))
+    val barGroups = input.filter("type = 
'bar'").groupBy("id").as(ExpressionEncoder(keyType),
+      ExpressionEncoder(inputType))
+
+    val result = fooGroups.cogroup(barGroups) { case (row, iterator, 
iterator1) =>
+      iterator.toSeq ++ iterator1.toSeq
+    }(ExpressionEncoder(inputType)).collect()
+    assert(result.length == 3)
+
+    val result2 = fooGroups.cogroupSorted(barGroups)($"id")($"id") {
+      case (row, iterator, iterator1) => iterator.toSeq ++ iterator1.toSeq
+    }(ExpressionEncoder(inputType)).collect()
+    assert(result2.length == 3)
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to