This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 37517df68c5 [SPARK-40571][SS][TESTS] Construct a new test case for 
applyInPandasWithState to verify fault-tolerance semantic with random python 
worker failures
37517df68c5 is described below

commit 37517df68c5805a2dcff5c0c41ea273eae92ed0c
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Tue Sep 27 17:41:28 2022 +0900

    [SPARK-40571][SS][TESTS] Construct a new test case for 
applyInPandasWithState to verify fault-tolerance semantic with random python 
worker failures
    
    ### What changes were proposed in this pull request?
    
    This PR proposes a new test case for applyInPandasWithState to verify that
    its fault-tolerance semantics are not broken by random Python worker failures.
    If the sink provides end-to-end exactly-once, the query should respect that
    guarantee. Otherwise, the query should still guarantee exactly-once for state,
    but only at-least-once for outputs.
    
    The test uses the file stream sink, which is end-to-end exactly-once, but to
    keep the verification simple it only checks that stateful exactly-once is
    guaranteed despite Python worker failures.
    
    ### Why are the changes needed?
    
    This strengthens test coverage, especially of the fault-tolerance semantics.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New test added. Manually ran `./python/run-tests --testnames
    'pyspark.sql.tests.test_pandas_grouped_map_with_state'` 10 times; all runs
    succeeded.
    
    Closes #38008 from HeartSaVioR/SPARK-40571.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../tests/test_pandas_grouped_map_with_state.py    | 149 ++++++++++++++++++++-
 1 file changed, 147 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py 
b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
index 7eb3bb92b84..8671cc8519c 100644
--- a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
@@ -15,9 +15,16 @@
 # limitations under the License.
 #
 
+import random
+import shutil
+import string
+import sys
+import tempfile
+
 import unittest
 from typing import cast
 
+from pyspark import SparkConf
 from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
 from pyspark.sql.types import (
     LongType,
@@ -33,6 +40,7 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import eventually
 
 if have_pandas:
     import pandas as pd
@@ -46,8 +54,23 @@ if have_pyarrow:
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
 class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
+    @classmethod
+    def conf(cls):
+        cfg = SparkConf()
+        cfg.set("spark.sql.shuffle.partitions", "5")
+        return cfg
+
     def test_apply_in_pandas_with_state_basic(self):
-        df = 
self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        input_path = tempfile.mkdtemp()
+
+        def prepare_test_resource():
+            with open(input_path + "/text-test.txt", "w") as fw:
+                fw.write("hello\n")
+                fw.write("this\n")
+
+        prepare_test_resource()
+
+        df = self.spark.readStream.format("text").load(input_path)
 
         for q in self.spark.streams.active:
             q.stop()
@@ -71,7 +94,7 @@ class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
 
         def check_results(batch_df, _):
             self.assertEqual(
-                set(batch_df.collect()),
+                set(batch_df.sort("key").collect()),
                 {Row(key="hello", countAsString="1"), Row(key="this", 
countAsString="1")},
             )
 
@@ -90,6 +113,128 @@ class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
         self.assertTrue(q.isActive)
         q.processAllAvailable()
 
+    def test_apply_in_pandas_with_state_python_worker_random_failure(self):
+        input_path = tempfile.mkdtemp()
+        output_path = tempfile.mkdtemp()
+        checkpoint_loc = tempfile.mkdtemp()
+
+        shutil.rmtree(output_path)
+        shutil.rmtree(checkpoint_loc)
+
+        def prepare_test_resource():
+            data_range = list(string.ascii_lowercase)
+            for i in range(5):
+                picked_data = [
+                    data_range[random.randrange(0, len(data_range) - 1)] for x 
in range(100)
+                ]
+
+                with open(input_path + "/part-%i.txt" % i, "w") as fw:
+                    for data in picked_data:
+                        fw.write(data + "\n")
+
+        def run_query():
+            df = (
+                self.spark.readStream.format("text")
+                .option("maxFilesPerTrigger", "1")
+                .load(input_path)
+            )
+
+            for q in self.spark.streams.active:
+                q.stop()
+            self.assertTrue(df.isStreaming)
+
+            output_type = StructType(
+                [StructField("value", StringType()), StructField("count", 
LongType())]
+            )
+            state_type = StructType([StructField("cnt", LongType())])
+
+            def func(key, pdf_iter, state):
+                assert isinstance(state, GroupState)
+
+                # user function call will happen at most 26 times
+                # should be huge enough to not trigger kill in every batches
+                # but should be also reasonable to trigger kill multiple times 
across batches
+                if random.randrange(30) == 1:
+                    sys.exit(1)
+
+                count = state.getOption
+                if count is None:
+                    count = 0
+                else:
+                    count = count[0]
+
+                for pdf in pdf_iter:
+                    count += len(pdf)
+
+                state.update((count,))
+                yield pd.DataFrame({"value": [key[0]], "count": [count]})
+
+            query = (
+                df.groupBy(df["value"])
+                .applyInPandasWithState(
+                    func, output_type, state_type, "Append", 
GroupStateTimeout.NoTimeout
+                )
+                .writeStream.queryName("this_query")
+                .format("json")
+                .outputMode("append")
+                .option("path", output_path)
+                .option("checkpointLocation", checkpoint_loc)
+                .start()
+            )
+
+            return query
+
+        prepare_test_resource()
+
+        expected = (
+            self.spark.read.format("text")
+            .load(input_path)
+            .groupBy("value")
+            .count()
+            .sort("value")
+            .collect()
+        )
+
+        q = run_query()
+        self.assertEqual(q.name, "this_query")
+        self.assertTrue(q.isActive)
+
+        def assert_test():
+            nonlocal q
+            if not q.isActive:
+                print("query has been terminated, rerunning query...")
+
+                # rerunning query as the query may have been killed by killed 
python worker
+                q = run_query()
+
+                self.assertEqual(q.name, "this_query")
+                self.assertTrue(q.isActive)
+
+            curr_status = q.status
+            if not curr_status["isDataAvailable"] and not 
curr_status["isTriggerActive"]:
+                # The query is active but not running due to no further data 
available
+                # Check the output now.
+                result = (
+                    self.spark.read.schema("value string, count int")
+                    .format("json")
+                    .load(output_path)
+                    .groupBy("value")
+                    .max("count")
+                    .selectExpr("value", "`max(count)` AS count")
+                    .sort("value")
+                    .collect()
+                )
+
+                return result == expected
+            else:
+                # still processing the data, defer checking the output.
+                return False
+
+        try:
+            eventually(assert_test, timeout=120)
+        finally:
+            q.stop()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_grouped_map_with_state import *  # 
noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to