This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 0cb0fa31397 [SPARK-38591][SQL][FOLLOW-UP] Fix ambiguous references for sorted cogroups 0cb0fa31397 is described below commit 0cb0fa313979e1b82ddd711a05d8c4e78cf6c9f5 Author: Enrico Minack <git...@enrico.minack.dev> AuthorDate: Mon Jan 30 10:06:07 2023 +0800 [SPARK-38591][SQL][FOLLOW-UP] Fix ambiguous references for sorted cogroups The sort orders for the left and right cogroups must be resolved against the left and right plans, respectively. Otherwise, an ambiguous reference exception can be thrown. ```Scala leftGroupedDf.cogroupSorted(rightGroupedDf)($"time")($"time") { ... } ``` Grouped DataFrames `leftGroupedDf` and `rightGroupedDf` both contain a column `"time"`. The left and right sort orders `$"time"` are ambiguous when resolved against all children. Each must be resolved exclusively against its own child: the left sort order against the left child, and the right sort order against the right child. This fixes errors like [AMBIGUOUS_REFERENCE] Reference `time` is ambiguous, could be: [`time`, `time`]. Tested in `AnalysisSuite` at the `Analyzer` level, and end-to-end (E2E) in `DatasetSuite`. Closes #39744 from EnricoMi/branch-sorted-groups-ambiguous-reference.
Authored-by: Enrico Minack <git...@enrico.minack.dev> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 607e753f9cf390cce293cef22a682e8a2d63e86b) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 31 +++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 46 +++++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 127 ++++++++++++++++++--- 3 files changed, 186 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dd1743f4554..48ea0460725 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1593,6 +1593,37 @@ class Analyzer(override val catalogManager: CatalogManager) Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } + case mg: MapGroups if mg.dataOrder.exists(!_.resolved) => + // Resolve against `AppendColumns`'s children, instead of `AppendColumns`, + // because `AppendColumns`'s serializer might produce conflict attribute + // names leading to ambiguous references exception. + val planForResolve = mg.child match { + case appendColumns: AppendColumns => appendColumns.child + case plan => plan + } + val resolvedOrder = mg.dataOrder + .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder]) + mg.copy(dataOrder = resolvedOrder) + + // Left and right sort expression have to be resolved against the respective child plan only + case cg: CoGroup if cg.leftOrder.exists(!_.resolved) || cg.rightOrder.exists(!_.resolved) => + // Resolve against `AppendColumns`'s children, instead of `AppendColumns`, + // because `AppendColumns`'s serializer might produce conflict attribute + // names leading to ambiguous references exception. 
+ val (leftPlanForResolve, rightPlanForResolve) = Seq(cg.left, cg.right).map { + case appendColumns: AppendColumns => appendColumns.child + case plan => plan + } match { + case Seq(left, right) => (left, right) + } + + val resolvedLeftOrder = cg.leftOrder + .map(resolveExpressionByPlanOutput(_, leftPlanForResolve).asInstanceOf[SortOrder]) + val resolvedRightOrder = cg.rightOrder + .map(resolveExpressionByPlanOutput(_, rightPlanForResolve).asInstanceOf[SortOrder]) + + cg.copy(leftOrder = resolvedLeftOrder, rightOrder = resolvedRightOrder) + // Skips plan which contains deserializer expressions, as they should be resolved by another // rule: ResolveDeserializer. case plan if containsDeserializer(plan.expressions) => plan diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6dfbf12bbd7..e6cd0699468 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{AliasIdentifier, QueryPlanningTracker, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -1329,4 +1329,48 @@ class AnalysisSuite extends AnalysisTest with Matchers { args = Map("param1" -> Literal(10), "param2" -> Literal(20))), parsePlan("SELECT c FROM a WHERE c < 20")) } + + test("SPARK-38591: resolve left and right CoGroup sort order on respective side only") { + def func(k: Int, left: Iterator[Int], right: Iterator[Int]): 
Iterator[Int] = { + Iterator.empty + } + + implicit val intEncoder = ExpressionEncoder[Int] + + val left = testRelation2.select($"e").analyze + val right = testRelation3.select($"e").analyze + val leftWithKey = AppendColumns[Int, Int]((x: Int) => x, left) + val rightWithKey = AppendColumns[Int, Int]((x: Int) => x, right) + val order = SortOrder($"e", Ascending) + + val cogroup = leftWithKey.cogroup[Int, Int, Int, Int]( + rightWithKey, + func, + leftWithKey.newColumns, + rightWithKey.newColumns, + left.output, + right.output, + order :: Nil, + order :: Nil + ) + + // analyze the plan + val actualPlan = getAnalyzer.executeAndCheck(cogroup, new QueryPlanningTracker) + val cg = actualPlan.collectFirst { + case cg: CoGroup => cg + } + // assert sort order reference only their respective plan + assert(cg.isDefined) + cg.foreach { cg => + assert(cg.leftOrder != cg.rightOrder) + + assert(cg.leftOrder.flatMap(_.references).nonEmpty) + assert(cg.leftOrder.flatMap(_.references).forall(cg.left.output.contains)) + assert(!cg.leftOrder.flatMap(_.references).exists(cg.right.output.contains)) + + assert(cg.rightOrder.flatMap(_.references).nonEmpty) + assert(cg.rightOrder.flatMap(_.references).forall(cg.right.output.contains)) + assert(!cg.rightOrder.flatMap(_.references).exists(cg.left.output.contains)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8b48d7e7827..70db5c1a655 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -573,12 +573,40 @@ class DatasetSuite extends QueryTest "a", "30", "b", "3", "c", "1") } + test("groupBy, flatMapSorted") { + val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) + .toDF("key", "seq", "value") + val grouped = ds.groupBy($"key").as[String, (String, Int, Int)] + val aggregated = 
grouped.flatMapSortedGroups($"seq", expr("length(key)"), $"value") { + (g, iter) => Iterator(g, iter.mkString(", ")) + } + + checkDatasetUnorderly( + aggregated, + "a", "(a,1,10), (a,2,20)", + "b", "(b,1,2), (b,2,1)", + "c", "(c,1,1)" + ) + + // Star is not allowed as group sort column + checkError( + exception = intercept[AnalysisException] { + grouped.flatMapSortedGroups($"*") { + (g, iter) => Iterator(g, iter.mkString(", ")) + } + }, + errorClass = "_LEGACY_ERROR_TEMP_1020", + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + } + test("groupBy function, flatMapSorted") { val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) .toDF("key", "seq", "value") - val grouped = ds.groupByKey(v => (v.getString(0), "word")) - val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)")) { - (g, iter) => Iterator(g._1, iter.mkString(", ")) + // groupByKey Row => String adds key columns `value` to the dataframe + val grouped = ds.groupByKey(v => v.getString(0)) + // $"value" here is expected to not reference the key column + val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)"), $"value") { + (g, iter) => Iterator(g, iter.mkString(", ")) } checkDatasetUnorderly( @@ -587,14 +615,42 @@ class DatasetSuite extends QueryTest "b", "[b,1,2], [b,2,1]", "c", "[c,1,1]" ) + + // Star is not allowed as group sort column + checkError( + exception = intercept[AnalysisException] { + grouped.flatMapSortedGroups($"*") { + (g, iter) => Iterator(g, iter.mkString(", ")) + } + }, + errorClass = "_LEGACY_ERROR_TEMP_1020", + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + } + + test("groupBy, flatMapSorted desc") { + val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) + .toDF("key", "seq", "value") + val grouped = ds.groupBy($"key").as[String, (String, Int, Int)] + val aggregated = grouped.flatMapSortedGroups($"seq".desc, expr("length(key)"), $"value") { + (g, iter) => Iterator(g, 
iter.mkString(", ")) + } + + checkDatasetUnorderly( + aggregated, + "a", "(a,2,20), (a,1,10)", + "b", "(b,2,1), (b,1,2)", + "c", "(c,1,1)" + ) } test("groupBy function, flatMapSorted desc") { val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) .toDF("key", "seq", "value") - val grouped = ds.groupByKey(v => (v.getString(0), "word")) - val aggregated = grouped.flatMapSortedGroups($"seq".desc, expr("length(key)")) { - (g, iter) => Iterator(g._1, iter.mkString(", ")) + // groupByKey Row => String adds key columns `value` to the dataframe + val grouped = ds.groupByKey(v => v.getString(0)) + // $"value" here is expected to not reference the key column + val aggregated = grouped.flatMapSortedGroups($"seq".desc, expr("length(key)"), $"value") { + (g, iter) => Iterator(g, iter.mkString(", ")) } checkDatasetUnorderly( @@ -759,11 +815,11 @@ class DatasetSuite extends QueryTest 1 -> "a", 2 -> "bc", 3 -> "d") } - test("cogroup sorted") { + test("cogroup with groupBy and sorted") { val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> "ijk").toDS() val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> "y").toDS() - val groupedLeft = left.groupByKey(_._1) - val groupedRight = right.groupByKey(_._1) + val groupedLeft = left.groupBy($"_1").as[Int, (Int, String)] + val groupedRight = right.groupBy($"_1").as[Int, (Int, String)] val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 5 -> "hello#xzy") val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 -> "hello#xzy") @@ -771,18 +827,18 @@ class DatasetSuite extends QueryTest val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 -> "hello#xyz") val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzijkabc#wa", 5 -> "hello#zyx") - val leftOrder = Seq(left("_2")) - val rightOrder = Seq(right("_2")) - val leftDescOrder = Seq(left("_2").desc) - val rightDescOrder = Seq(right("_2").desc) + val ascOrder = 
Seq($"_2") + val descOrder = Seq($"_2".desc) + val exprOrder = Seq(substring($"_2", 0, 1)) val none = Seq.empty Seq( ("neither", none, none, neitherSortedExpected), - ("left", leftOrder, none, leftSortedExpected), - ("right", none, rightOrder, rightSortedExpected), - ("both", leftOrder, rightOrder, bothSortedExpected), - ("both desc", leftDescOrder, rightDescOrder, bothDescSortedExpected) + ("left", ascOrder, none, leftSortedExpected), + ("right", none, ascOrder, rightSortedExpected), + ("both", ascOrder, ascOrder, bothSortedExpected), + ("expr", exprOrder, exprOrder, bothSortedExpected), + ("both desc", descOrder, descOrder, bothDescSortedExpected) ).foreach { case (label, leftOrder, rightOrder, expected) => withClue(s"$label sorted") { val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*) { @@ -795,6 +851,43 @@ class DatasetSuite extends QueryTest } } + test("cogroup with groupBy function and sorted") { + val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> "ijk").toDS() + val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> "y").toDS() + // this groupByKey produces conflicting _1 and _2 columns + // that should be ignored when resolving sort expressions + val groupedLeft = left.groupByKey(row => (row._1, row._1)) + val groupedRight = right.groupByKey(row => (row._1, row._1)) + + val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 5 -> "hello#xzy") + val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 -> "hello#xzy") + val rightSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#aw", 5 -> "hello#xyz") + val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 -> "hello#xyz") + val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzijkabc#wa", 5 -> "hello#zyx") + + val ascOrder = Seq($"_2") + val descOrder = Seq($"_2".desc) + val exprOrder = Seq(substring($"_2", 0, 1)) + val none = Seq.empty + + Seq( + ("neither", none, 
none, neitherSortedExpected), + ("left", ascOrder, none, leftSortedExpected), + ("right", none, ascOrder, rightSortedExpected), + ("both", ascOrder, ascOrder, bothSortedExpected), + ("expr", exprOrder, exprOrder, bothSortedExpected), + ("both desc", descOrder, descOrder, bothDescSortedExpected) + ).foreach { case (label, leftOrder, rightOrder, expected) => + withClue(s"$label sorted") { + val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*) { + (key, left, right) => + Iterator(key._1 -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString)) + } + checkDatasetUnorderly(cogrouped, expected.toList: _*) + } + } + } + test("SPARK-34806: observation on datasets") { val namedObservation = Observation("named") val unnamedObservation = Observation() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org