This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 436ce5f3de3 [SPARK-41261][PYTHON][SS] Fix issue for 
applyInPandasWithState when the columns of grouping keys are not placed in 
order from earliest
436ce5f3de3 is described below

commit 436ce5f3de3321162694a08c377f2de24e4741c3
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Sun Nov 27 11:01:32 2022 +0900

    [SPARK-41261][PYTHON][SS] Fix issue for applyInPandasWithState when the 
columns of grouping keys are not placed in order from earliest
    
    ### What changes were proposed in this pull request?
    
    This PR fixes an issue in applyInPandasWithState, which is triggered when 
the grouping key columns are not placed, in order, at the earliest positions of 
the input columns. If the condition is met, the user function may get an 
"incorrect" value of the key, including `None`.
    
    This is because the projection for the value is shared between the normal 
input row and the row for timed-out state. The projection assumed that the 
schema of the row is the same as the output schema of the child node, whereas 
the row for timed-out state is constructed by concatenating the key row and an 
all-null value row.
    
    This PR creates a separate projection for the row for timed-out state, so 
that the projection can pick up the values for grouping columns correctly.
    
    ### Why are the changes needed?
    
    Without this fix, the user function may get an "incorrect" value of the 
key, including `None`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This feature is not released yet.
    
    ### How was this patch tested?
    
    New test case.
    
    Closes #38798 from HeartSaVioR/SPARK-41261.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../FlatMapGroupsInPandasWithStateExec.scala       | 17 +++--
 .../FlatMapGroupsInPandasWithStateSuite.scala      | 83 ++++++++++++++++++++++
 2 files changed, 95 insertions(+), 5 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index bc1a5ae17e4..1c7ccaf4fe2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -97,6 +97,15 @@ case class FlatMapGroupsInPandasWithStateExec(
   private lazy val unsafeProj = 
UnsafeProjection.create(dedupAttributesWithNull,
     childOutputWithNull)
 
+  // See processTimedOutState: we create a row which contains the actual 
values for grouping key,
+  // but all nulls for value side by intention. The schema for this row is 
different from
+  // child.output, hence we should create another projection to deal with such 
schema.
+  private lazy val valueAttributesWithNull = childOutputWithNull.filterNot { 
attr =>
+    groupingAttributes.exists(_.withNullability(newNullability = true) == attr)
+  }
+  private lazy val unsafeProjForTimedOut = 
UnsafeProjection.create(dedupAttributesWithNull,
+    groupingAttributes ++ valueAttributesWithNull)
+
   override def requiredChildDistribution: Seq[Distribution] =
     StatefulOperatorPartitioning.getCompatibleDistribution(
       groupingAttributes, getStateInfo, conf) :: Nil
@@ -142,12 +151,10 @@ case class FlatMapGroupsInPandasWithStateExec(
           state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < 
timeoutThreshold
         }
 
+        val emptyValueRow = new GenericInternalRow(
+          Array.fill(valueAttributesWithNull.length)(null: Any))
         val processIter = timingOutPairs.map { stateData =>
-          val joinedKeyRow = unsafeProj(
-            new JoinedRow(
-              stateData.keyRow,
-              new 
GenericInternalRow(Array.fill(dedupAttributesWithNull.length)(null: Any))))
-
+          val joinedKeyRow = unsafeProjForTimedOut(new 
JoinedRow(stateData.keyRow, emptyValueRow))
           (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
         }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
index 785cbec9805..a83f7cce4c1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
@@ -803,4 +803,87 @@ class FlatMapGroupsInPandasWithStateSuite extends 
StateStoreMetricsTest {
         total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(1)))
     )
   }
+
+  test("SPARK-41261: applyInPandasWithState - key in user function should be 
correct for " +
+    "timed out state despite of place for key columns") {
+    assume(shouldTestPandasUDFs)
+
+    // Function to maintain the count as state and set the proc. time timeout 
delay of 10 seconds.
+    // It returns the count if changed, or -1 if the state was removed by 
timeout.
+    val pythonScript =
+    """
+      |import pandas as pd
+      |from pyspark.sql.types import StructType, StructField, StringType
+      |
+      |tpe = StructType([
+      |    StructField("key1", StringType()),
+      |    StructField("key2", StringType()),
+      |    StructField("countAsStr", StringType())])
+      |
+      |def func(key, pdf_iter, state):
+      |    ret = None
+      |    if state.hasTimedOut:
+      |        state.remove()
+      |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 
'countAsStr': [str(-1)]})
+      |    else:
+      |        count = state.getOption
+      |        if count is None:
+      |            count = 0
+      |        else:
+      |            count = count[0]
+      |
+      |        for pdf in pdf_iter:
+      |            count += len(pdf)
+      |
+      |        state.update((count, ))
+      |        state.setTimeoutDuration(10000)
+      |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 
'countAsStr': [str(count)]})
+      |""".stripMargin
+    val pythonFunc = TestGroupedMapPandasUDFWithState(
+      name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[String]
+    // schema: val1, key2, val2, key1, val3
+    val inputDataDS = inputData.toDS
+      .withColumnRenamed("value", "val1")
+      .withColumn("key2", $"val1")
+      // the type of columns with string literal will be non-nullable
+      .withColumn("val2", lit("__FAKE__"))
+      .withColumn("key1", lit("__FAKE__"))
+      .withColumn("val3", lit("__FAKE__"))
+    val outputStructType = StructType(
+      Seq(
+        StructField("key1", StringType),
+        StructField("key2", StringType),
+        StructField("countAsStr", StringType)))
+    val stateStructType = StructType(Seq(StructField("count", LongType)))
+    val result =
+      inputDataDS
+        // grouping columns: key1, key2 (swapped order)
+        .groupBy("key1", "key2")
+        .applyInPandasWithState(
+          pythonFunc(
+            inputDataDS("val1"), inputDataDS("key2"), inputDataDS("val2"), 
inputDataDS("key1"),
+            inputDataDS("val3")
+          ).expr.asInstanceOf[PythonUDF],
+          outputStructType,
+          stateStructType,
+          "Update",
+          "ProcessingTimeTimeout")
+
+    testStream(result, Update)(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, "a"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("__FAKE__", "a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(11 * 1000),
+      CheckNewAnswer(("__FAKE__", "a", "-1"), ("__FAKE__", "b", "1")),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(1)))
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to