This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 37517df68c5 [SPARK-40571][SS][TESTS] Construct a new test case for applyInPandasWithState to verify fault-tolerance semantic with random python worker failures 37517df68c5 is described below commit 37517df68c5805a2dcff5c0c41ea273eae92ed0c Author: Jungtaek Lim <kabhwan.opensou...@gmail.com> AuthorDate: Tue Sep 27 17:41:28 2022 +0900 [SPARK-40571][SS][TESTS] Construct a new test case for applyInPandasWithState to verify fault-tolerance semantic with random python worker failures ### What changes were proposed in this pull request? This PR proposes a new test case for applyInPandasWithState to verify that the fault-tolerance semantic is not broken despite random python worker failures. If the sink provides end-to-end exactly-once, the query should respect the guarantee. Otherwise, the query should respect stateful exactly-once, but at-least-once in terms of outputs. The test leverages file stream sink which is end-to-end exactly-once, but to make the verification simpler, we just verify whether the stateful exactly-once is guaranteed despite python worker failures. ### Why are the changes needed? This strengthens the test coverage, especially the fault-tolerance semantic. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New test added. Manually ran `./python/run-tests --testnames 'pyspark.sql.tests.test_pandas_grouped_map_with_state'` 10 times and all succeeded. Closes #38008 from HeartSaVioR/SPARK-40571. 
Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../tests/test_pandas_grouped_map_with_state.py | 149 ++++++++++++++++++++- 1 file changed, 147 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py index 7eb3bb92b84..8671cc8519c 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py @@ -15,9 +15,16 @@ # limitations under the License. # +import random +import shutil +import string +import sys +import tempfile + import unittest from typing import cast +from pyspark import SparkConf from pyspark.sql.streaming.state import GroupStateTimeout, GroupState from pyspark.sql.types import ( LongType, @@ -33,6 +40,7 @@ from pyspark.testing.sqlutils import ( pandas_requirement_message, pyarrow_requirement_message, ) +from pyspark.testing.utils import eventually if have_pandas: import pandas as pd @@ -46,8 +54,23 @@ if have_pyarrow: cast(str, pandas_requirement_message or pyarrow_requirement_message), ) class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): + @classmethod + def conf(cls): + cfg = SparkConf() + cfg.set("spark.sql.shuffle.partitions", "5") + return cfg + def test_apply_in_pandas_with_state_basic(self): - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + input_path = tempfile.mkdtemp() + + def prepare_test_resource(): + with open(input_path + "/text-test.txt", "w") as fw: + fw.write("hello\n") + fw.write("this\n") + + prepare_test_resource() + + df = self.spark.readStream.format("text").load(input_path) for q in self.spark.streams.active: q.stop() @@ -71,7 +94,7 @@ class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): def check_results(batch_df, _): self.assertEqual( - set(batch_df.collect()), + set(batch_df.sort("key").collect()), 
{Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")}, ) @@ -90,6 +113,128 @@ class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): self.assertTrue(q.isActive) q.processAllAvailable() + def test_apply_in_pandas_with_state_python_worker_random_failure(self): + input_path = tempfile.mkdtemp() + output_path = tempfile.mkdtemp() + checkpoint_loc = tempfile.mkdtemp() + + shutil.rmtree(output_path) + shutil.rmtree(checkpoint_loc) + + def prepare_test_resource(): + data_range = list(string.ascii_lowercase) + for i in range(5): + picked_data = [ + data_range[random.randrange(0, len(data_range) - 1)] for x in range(100) + ] + + with open(input_path + "/part-%i.txt" % i, "w") as fw: + for data in picked_data: + fw.write(data + "\n") + + def run_query(): + df = ( + self.spark.readStream.format("text") + .option("maxFilesPerTrigger", "1") + .load(input_path) + ) + + for q in self.spark.streams.active: + q.stop() + self.assertTrue(df.isStreaming) + + output_type = StructType( + [StructField("value", StringType()), StructField("count", LongType())] + ) + state_type = StructType([StructField("cnt", LongType())]) + + def func(key, pdf_iter, state): + assert isinstance(state, GroupState) + + # user function call will happen at most 26 times + # should be huge enough to not trigger kill in every batches + # but should be also reasonable to trigger kill multiple times across batches + if random.randrange(30) == 1: + sys.exit(1) + + count = state.getOption + if count is None: + count = 0 + else: + count = count[0] + + for pdf in pdf_iter: + count += len(pdf) + + state.update((count,)) + yield pd.DataFrame({"value": [key[0]], "count": [count]}) + + query = ( + df.groupBy(df["value"]) + .applyInPandasWithState( + func, output_type, state_type, "Append", GroupStateTimeout.NoTimeout + ) + .writeStream.queryName("this_query") + .format("json") + .outputMode("append") + .option("path", output_path) + .option("checkpointLocation", checkpoint_loc) + .start() + ) + + 
return query + + prepare_test_resource() + + expected = ( + self.spark.read.format("text") + .load(input_path) + .groupBy("value") + .count() + .sort("value") + .collect() + ) + + q = run_query() + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + + def assert_test(): + nonlocal q + if not q.isActive: + print("query has been terminated, rerunning query...") + + # rerunning query as the query may have been killed by killed python worker + q = run_query() + + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + + curr_status = q.status + if not curr_status["isDataAvailable"] and not curr_status["isTriggerActive"]: + # The query is active but not running due to no further data available + # Check the output now. + result = ( + self.spark.read.schema("value string, count int") + .format("json") + .load(output_path) + .groupBy("value") + .max("count") + .selectExpr("value", "`max(count)` AS count") + .sort("value") + .collect() + ) + + return result == expected + else: + # still processing the data, defer checking the output. + return False + + try: + eventually(assert_test, timeout=120) + finally: + q.stop() + if __name__ == "__main__": from pyspark.sql.tests.test_pandas_grouped_map_with_state import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org