This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0cb0fa31397 [SPARK-38591][SQL][FOLLOW-UP] Fix ambiguous references for 
sorted cogroups
0cb0fa31397 is described below

commit 0cb0fa313979e1b82ddd711a05d8c4e78cf6c9f5
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Mon Jan 30 10:06:07 2023 +0800

    [SPARK-38591][SQL][FOLLOW-UP] Fix ambiguous references for sorted cogroups
    
    The sort orders for the left and right cogroups must be resolved against the
    left and right plans, respectively. Otherwise, an ambiguous-reference
    exception can be thrown.
    
    ```Scala
    leftGroupedDf.cogroupSorted(rightGroupedDf)($"time")($"time") { ... }
    ```
    
    The grouped DataFrames `leftGroupedDf` and `rightGroupedDf` both contain a
    column `"time"`. The sort order `$"time"` is therefore ambiguous when
    resolved against all children: the left sort order must be resolved against
    the left child only, and the right sort order against the right child only.
    
    This fixes errors like
    
        [AMBIGUOUS_REFERENCE] Reference `time` is ambiguous, could be: [`time`, 
`time`].
    
    Tested in `AnalysisSuite` on `Analyzer` level, and E2E in `DatasetSuite`.
    
    Closes #39744 from EnricoMi/branch-sorted-groups-ambiguous-reference.
    
    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 607e753f9cf390cce293cef22a682e8a2d63e86b)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  31 +++++
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  46 +++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 127 ++++++++++++++++++---
 3 files changed, 186 insertions(+), 18 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dd1743f4554..48ea0460725 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1593,6 +1593,37 @@ class Analyzer(override val catalogManager: 
CatalogManager)
           Generate(newG.asInstanceOf[Generator], join, outer, qualifier, 
output, child)
         }
 
+      case mg: MapGroups if mg.dataOrder.exists(!_.resolved) =>
+        // Resolve against `AppendColumns`'s children, instead of 
`AppendColumns`,
+        // because `AppendColumns`'s serializer might produce conflict 
attribute
+        // names leading to ambiguous references exception.
+        val planForResolve = mg.child match {
+          case appendColumns: AppendColumns => appendColumns.child
+          case plan => plan
+        }
+        val resolvedOrder = mg.dataOrder
+            .map(resolveExpressionByPlanOutput(_, 
planForResolve).asInstanceOf[SortOrder])
+        mg.copy(dataOrder = resolvedOrder)
+
+      // Left and right sort expression have to be resolved against the 
respective child plan only
+      case cg: CoGroup if cg.leftOrder.exists(!_.resolved) || 
cg.rightOrder.exists(!_.resolved) =>
+        // Resolve against `AppendColumns`'s children, instead of 
`AppendColumns`,
+        // because `AppendColumns`'s serializer might produce conflict 
attribute
+        // names leading to ambiguous references exception.
+        val (leftPlanForResolve, rightPlanForResolve) = Seq(cg.left, 
cg.right).map {
+          case appendColumns: AppendColumns => appendColumns.child
+          case plan => plan
+        } match {
+          case Seq(left, right) => (left, right)
+        }
+
+        val resolvedLeftOrder = cg.leftOrder
+          .map(resolveExpressionByPlanOutput(_, 
leftPlanForResolve).asInstanceOf[SortOrder])
+        val resolvedRightOrder = cg.rightOrder
+          .map(resolveExpressionByPlanOutput(_, 
rightPlanForResolve).asInstanceOf[SortOrder])
+
+        cg.copy(leftOrder = resolvedLeftOrder, rightOrder = resolvedRightOrder)
+
       // Skips plan which contains deserializer expressions, as they should be 
resolved by another
       // rule: ResolveDeserializer.
       case plan if containsDeserializer(plan.expressions) => plan
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 6dfbf12bbd7..e6cd0699468 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.matchers.must.Matchers
 
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{AliasIdentifier, QueryPlanningTracker, 
TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -1329,4 +1329,48 @@ class AnalysisSuite extends AnalysisTest with Matchers {
         args = Map("param1" -> Literal(10), "param2" -> Literal(20))),
       parsePlan("SELECT c FROM a WHERE c < 20"))
   }
+
+  test("SPARK-38591: resolve left and right CoGroup sort order on respective 
side only") {
+    def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] 
= {
+      Iterator.empty
+    }
+
+    implicit val intEncoder = ExpressionEncoder[Int]
+
+    val left = testRelation2.select($"e").analyze
+    val right = testRelation3.select($"e").analyze
+    val leftWithKey = AppendColumns[Int, Int]((x: Int) => x, left)
+    val rightWithKey = AppendColumns[Int, Int]((x: Int) => x, right)
+    val order = SortOrder($"e", Ascending)
+
+    val cogroup = leftWithKey.cogroup[Int, Int, Int, Int](
+      rightWithKey,
+      func,
+      leftWithKey.newColumns,
+      rightWithKey.newColumns,
+      left.output,
+      right.output,
+      order :: Nil,
+      order :: Nil
+    )
+
+    // analyze the plan
+    val actualPlan = getAnalyzer.executeAndCheck(cogroup, new 
QueryPlanningTracker)
+    val cg = actualPlan.collectFirst {
+      case cg: CoGroup => cg
+    }
+    // assert sort order reference only their respective plan
+    assert(cg.isDefined)
+    cg.foreach { cg =>
+      assert(cg.leftOrder != cg.rightOrder)
+
+      assert(cg.leftOrder.flatMap(_.references).nonEmpty)
+      
assert(cg.leftOrder.flatMap(_.references).forall(cg.left.output.contains))
+      
assert(!cg.leftOrder.flatMap(_.references).exists(cg.right.output.contains))
+
+      assert(cg.rightOrder.flatMap(_.references).nonEmpty)
+      
assert(cg.rightOrder.flatMap(_.references).forall(cg.right.output.contains))
+      
assert(!cg.rightOrder.flatMap(_.references).exists(cg.left.output.contains))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8b48d7e7827..70db5c1a655 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -573,12 +573,40 @@ class DatasetSuite extends QueryTest
       "a", "30", "b", "3", "c", "1")
   }
 
+  test("groupBy, flatMapSorted") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.groupBy($"key").as[String, (String, Int, Int)]
+    val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"), 
$"value") {
+      (g, iter) => Iterator(g, iter.mkString(", "))
+    }
+
+    checkDatasetUnorderly(
+      aggregated,
+      "a", "(a,1,10), (a,2,20)",
+      "b", "(b,1,2), (b,2,1)",
+      "c", "(c,1,1)"
+    )
+
+    // Star is not allowed as group sort column
+    checkError(
+      exception = intercept[AnalysisException] {
+        grouped.flatMapSortedGroups($"*") {
+          (g, iter) => Iterator(g, iter.mkString(", "))
+        }
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_1020",
+      parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"))
+  }
+
   test("groupBy function, flatMapSorted") {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
       .toDF("key", "seq", "value")
-    val grouped = ds.groupByKey(v => (v.getString(0), "word"))
-    val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)")) {
-      (g, iter) => Iterator(g._1, iter.mkString(", "))
+    // groupByKey Row => String adds key columns `value` to the dataframe
+    val grouped = ds.groupByKey(v => v.getString(0))
+    // $"value" here is expected to not reference the key column
+    val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"), 
$"value") {
+      (g, iter) => Iterator(g, iter.mkString(", "))
     }
 
     checkDatasetUnorderly(
@@ -587,14 +615,42 @@ class DatasetSuite extends QueryTest
       "b", "[b,1,2], [b,2,1]",
       "c", "[c,1,1]"
     )
+
+    // Star is not allowed as group sort column
+    checkError(
+      exception = intercept[AnalysisException] {
+        grouped.flatMapSortedGroups($"*") {
+          (g, iter) => Iterator(g, iter.mkString(", "))
+        }
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_1020",
+      parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"))
+  }
+
+  test("groupBy, flatMapSorted desc") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.groupBy($"key").as[String, (String, Int, Int)]
+    val aggregated = grouped.flatMapSortedGroups($"seq".desc, 
expr("length(key)"), $"value") {
+      (g, iter) => Iterator(g, iter.mkString(", "))
+    }
+
+    checkDatasetUnorderly(
+      aggregated,
+      "a", "(a,2,20), (a,1,10)",
+      "b", "(b,2,1), (b,1,2)",
+      "c", "(c,1,1)"
+    )
   }
 
   test("groupBy function, flatMapSorted desc") {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
       .toDF("key", "seq", "value")
-    val grouped = ds.groupByKey(v => (v.getString(0), "word"))
-    val aggregated = grouped.flatMapSortedGroups($"seq".desc, 
expr("length(key)")) {
-      (g, iter) => Iterator(g._1, iter.mkString(", "))
+    // groupByKey Row => String adds key columns `value` to the dataframe
+    val grouped = ds.groupByKey(v => v.getString(0))
+    // $"value" here is expected to not reference the key column
+    val aggregated = grouped.flatMapSortedGroups($"seq".desc, 
expr("length(key)"), $"value") {
+      (g, iter) => Iterator(g, iter.mkString(", "))
     }
 
     checkDatasetUnorderly(
@@ -759,11 +815,11 @@ class DatasetSuite extends QueryTest
       1 -> "a", 2 -> "bc", 3 -> "d")
   }
 
-  test("cogroup sorted") {
+  test("cogroup with groupBy and sorted") {
     val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> 
"ijk").toDS()
     val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> 
"y").toDS()
-    val groupedLeft = left.groupByKey(_._1)
-    val groupedRight = right.groupByKey(_._1)
+    val groupedLeft = left.groupBy($"_1").as[Int, (Int, String)]
+    val groupedRight = right.groupBy($"_1").as[Int, (Int, String)]
 
     val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 
5 -> "hello#xzy")
     val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 
-> "hello#xzy")
@@ -771,18 +827,18 @@ class DatasetSuite extends QueryTest
     val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 
-> "hello#xyz")
     val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> 
"xyzijkabc#wa", 5 -> "hello#zyx")
 
-    val leftOrder = Seq(left("_2"))
-    val rightOrder = Seq(right("_2"))
-    val leftDescOrder = Seq(left("_2").desc)
-    val rightDescOrder = Seq(right("_2").desc)
+    val ascOrder = Seq($"_2")
+    val descOrder = Seq($"_2".desc)
+    val exprOrder = Seq(substring($"_2", 0, 1))
     val none = Seq.empty
 
     Seq(
       ("neither", none, none, neitherSortedExpected),
-      ("left", leftOrder, none, leftSortedExpected),
-      ("right", none, rightOrder, rightSortedExpected),
-      ("both", leftOrder, rightOrder, bothSortedExpected),
-      ("both desc", leftDescOrder, rightDescOrder, bothDescSortedExpected)
+      ("left", ascOrder, none, leftSortedExpected),
+      ("right", none, ascOrder, rightSortedExpected),
+      ("both", ascOrder, ascOrder, bothSortedExpected),
+      ("expr", exprOrder, exprOrder, bothSortedExpected),
+      ("both desc", descOrder, descOrder, bothDescSortedExpected)
     ).foreach { case (label, leftOrder, rightOrder, expected) =>
       withClue(s"$label sorted") {
         val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: 
_*)(rightOrder: _*) {
@@ -795,6 +851,43 @@ class DatasetSuite extends QueryTest
     }
   }
 
+  test("cogroup with groupBy function and sorted") {
+    val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> 
"ijk").toDS()
+    val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> 
"y").toDS()
+    // this groupByKey produces conflicting _1 and _2 columns
+    // that should be ignored when resolving sort expressions
+    val groupedLeft = left.groupByKey(row => (row._1, row._1))
+    val groupedRight = right.groupByKey(row => (row._1, row._1))
+
+    val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 
5 -> "hello#xzy")
+    val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 
-> "hello#xzy")
+    val rightSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#aw", 5 
-> "hello#xyz")
+    val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 
-> "hello#xyz")
+    val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> 
"xyzijkabc#wa", 5 -> "hello#zyx")
+
+    val ascOrder = Seq($"_2")
+    val descOrder = Seq($"_2".desc)
+    val exprOrder = Seq(substring($"_2", 0, 1))
+    val none = Seq.empty
+
+    Seq(
+      ("neither", none, none, neitherSortedExpected),
+      ("left", ascOrder, none, leftSortedExpected),
+      ("right", none, ascOrder, rightSortedExpected),
+      ("both", ascOrder, ascOrder, bothSortedExpected),
+      ("expr", exprOrder, exprOrder, bothSortedExpected),
+      ("both desc", descOrder, descOrder, bothDescSortedExpected)
+    ).foreach { case (label, leftOrder, rightOrder, expected) =>
+      withClue(s"$label sorted") {
+        val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: 
_*)(rightOrder: _*) {
+          (key, left, right) =>
+            Iterator(key._1 -> (left.map(_._2).mkString + "#" + 
right.map(_._2).mkString))
+        }
+        checkDatasetUnorderly(cogrouped, expected.toList: _*)
+      }
+    }
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to