This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d9ca9820384 [SPARK-42168][SQL][PYTHON][FOLLOW-UP] Test FlatMapCoGroupsInPandas with Window function
d9ca9820384 is described below

commit d9ca9820384b84aa5004f4c407d72d3fbc6cbb97
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Fri Jan 27 09:20:08 2023 -0800

    [SPARK-42168][SQL][PYTHON][FOLLOW-UP] Test FlatMapCoGroupsInPandas with Window function
    
    ### What changes were proposed in this pull request?
    This ports tests from #39717 in branch-3.2 to master.
    
    ### Why are the changes needed?
    To make sure this use case is tested.
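    
    For context, the use case is a cogrouped `applyInPandas` where one side computes a window function over the same partition keys in a different order. Below is a minimal sketch of that scenario (a hypothetical standalone repro, not part of this patch; it assumes a local SparkSession with pandas available, and all data and names are illustrative):
    
    ```python
    import pandas as pd
    
    from pyspark.sql import SparkSession
    # note: `sum` here is pyspark's column function, shadowing the builtin
    from pyspark.sql.functions import col, sum
    from pyspark.sql.window import Window
    
    spark = SparkSession.builder.master("local[*]").getOrCreate()
    
    df = spark.createDataFrame(
        [(i, d, d * 10) for i in range(2) for d in range(3)],
        "id long, day long, value long",
    )
    
    # Window partitioned by the same keys as the groupBy below,
    # but in the opposite order -- the combination SPARK-42168 exercises.
    window = Window.partitionBy("day", "id")
    right = df.withColumn("day_sum", sum(col("day")).over(window))
    
    def count_rows(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
        # Emit one row per cogroup with the sizes of both sides.
        return pd.DataFrame([{"lefts": len(left.index), "rights": len(right.index)}])
    
    (
        df.groupBy("id", "day")
        .cogroup(right.groupBy("id", "day"))
        .applyInPandas(count_rows, schema="lefts integer, rights integer")
        .show()
    )
    ```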
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    E2E test in `test_pandas_cogrouped_map.py` and analysis test in `EnsureRequirementsSuite.scala`.
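    
    A sketch of how both suites can be run locally with the standard Spark dev tooling (commands assumed from the usual developer workflow, not part of this patch):
    
    ```bash
    # PySpark E2E test
    python/run-tests --testnames pyspark.sql.tests.pandas.test_pandas_cogrouped_map
    # Scala analysis test
    build/sbt "sql/testOnly org.apache.spark.sql.execution.exchange.EnsureRequirementsSuite"
    ```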
    
    Closes #39752 from EnricoMi/branch-cogroup-window-bug-test.
    
    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Chao Sun <sunc...@apple.com>
---
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  | 54 ++++++++++++++++++++-
 .../exchange/EnsureRequirementsSuite.scala         | 56 ++++++++++++++++++++++
 2 files changed, 109 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index d92a105f5d4..5cbc9e1caa4 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -18,8 +18,9 @@
 import unittest
 from typing import cast
 
-from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
+from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
 from pyspark.sql.types import DoubleType, StructType, StructField, Row
+from pyspark.sql.window import Window
 from pyspark.errors import IllegalArgumentException, PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -365,6 +366,57 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
 
         self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
 
+    def test_with_window_function(self):
+        # SPARK-42168: a window function with the same partition keys but differing key order
+        ids = 2
+        days = 100
+        vals = 10000
+        parts = 10
+
+        id_df = self.spark.range(ids)
+        day_df = self.spark.range(days).withColumnRenamed("id", "day")
+        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
+        df = id_df.join(day_df).join(vals_df)
+
+        left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
+        # SPARK-42132: this bug requires us to alias all columns from df here
+        right_df = (
+            df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
+            .repartition(parts)
+            .cache()
+        )
+
+        # note the column order is different from the groupBy("id", "day") column order below
+        window = Window.partitionBy("day", "id")
+
+        left_grouped_df = left_df.groupBy("id", "day")
+        right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
+            "id", "day"
+        )
+
+        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
+            return pd.DataFrame(
+                [
+                    {
+                        "id": left["id"][0]
+                        if not left.empty
+                        else (right["id"][0] if not right.empty else None),
+                        "day": left["day"][0]
+                        if not left.empty
+                        else (right["day"][0] if not right.empty else None),
+                        "lefts": len(left.index),
+                        "rights": len(right.index),
+                    }
+                ]
+            )
+
+        df = left_grouped_df.cogroup(right_grouped_df).applyInPandas(
+            cogroup, schema="id long, day long, lefts integer, rights integer"
+        )
+
+        actual = df.orderBy("id", "day").take(days)
+        self.assertEqual(actual, [Row(0, day, vals, vals) for day in 
range(days)])
+
     @staticmethod
     def _test_with_key(left, right, isLeft):
         def right_assign_key(key, lft, rgt):
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index bc1fd7a5fa5..844037339ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
@@ -25,9 +27,12 @@ import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class EnsureRequirementsSuite extends SharedSparkSession {
   private val exprA = Literal(1)
@@ -1104,6 +1109,57 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
   }
 
+  test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing 
key order") {
+    val lKey = AttributeReference("key", IntegerType)()
+    val lKey2 = AttributeReference("key2", IntegerType)()
+
+    val rKey = AttributeReference("key", IntegerType)()
+    val rKey2 = AttributeReference("key2", IntegerType)()
+    val rValue = AttributeReference("value", IntegerType)()
+
+    val left = DummySparkPlan()
+    val right = WindowExec(
+      Alias(
+        WindowExpression(
+          Sum(rValue).toAggregateExpression(),
+          WindowSpecDefinition(
+            Seq(rKey2, rKey),
+            Nil,
+          SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
+          )
+        ), "sum")() :: Nil,
+      Seq(rKey2, rKey),
+      Nil,
+      DummySparkPlan()
+    )
+
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("value", IntegerType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+
+    val flatMapCoGroup = FlatMapCoGroupsInPandasExec(
+      Seq(lKey, lKey2),
+      Seq(rKey, rKey2),
+      pythonUdf,
+      AttributeReference("value", IntegerType)() :: Nil,
+      left,
+      right
+    )
+
+    val result = EnsureRequirements.apply(flatMapCoGroup)
+    result match {
+      case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
+        SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
+        assert(leftKeys === Seq(lKey, lKey2))
+        assert(rightKeys === Seq(rKey, rKey2))
+        assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
+        assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
+      case other => fail(other.toString)
+    }
+  }
+
   def bucket(numBuckets: Int, expr: Expression): TransformExpression = {
     TransformExpression(BucketFunction, Seq(expr), Some(numBuckets))
   }

