Yicong-Huang commented on code in PR #55552:
URL: https://github.com/apache/spark/pull/55552#discussion_r3164738062

##########
python/pyspark/sql/tests/pandas/test_pipelined_udf.py:
##########
@@ -0,0 +1,247 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Tests for pipelined Python UDF execution mode.
+
+These tests run with spark.python.udf.pipelined.enabled=true to verify
+correctness of the pipelined data transfer path for various UDF types.
+"""
+
+import os
+import unittest
+
+from pyspark import SparkConf
+from pyspark.sql.functions import col, pandas_udf, udf
+from pyspark.sql.types import (
+    DoubleType,
+    LongType,
+    StringType,
+    StructType,
+    StructField,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import (
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+if have_pandas:
+    import pandas as pd
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message,
+)
+class PipelinedUDFTests(ReusedSQLTestCase):
+    """Tests that run with pipelined mode enabled."""
+
+    @classmethod
+    def conf(cls):
+        return (
+            SparkConf()
+            .set("spark.python.udf.pipelined.enabled", "true")
+            .set("spark.sql.execution.arrow.pyspark.enabled", "true")
+        )
+
+    def test_pipelined_mode_is_active(self):
+        """Verify the pipelined code path is actually being used."""
+
+        @pandas_udf(StringType())
+        def check_env(x: pd.Series) -> pd.Series:
+            flag = os.environ.get("SPARK_PIPELINED_UDF_ACTIVE", "not_set")
+            return pd.Series([flag] * len(x))
+
+        result = self.spark.range(1).select(check_env(col("id"))).first()[0]
+        self.assertEqual(result, "1", "pipelined_process() should set SPARK_PIPELINED_UDF_ACTIVE=1")

Review Comment:
   Brittle: relies on the leaky `os.environ` set in `pipelined_process()` (see worker.py:3598). With worker reuse, a stale env can pass this even when the current task didn't use the pipelined path. Prefer a JVM-side metric/accumulator.

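To illustrate the concern, here is a minimal sketch of the failure mode, assuming a single reused worker process runs a pipelined task and then a sync task. The helper names (`leaky_pipelined_task`, `scoped_pipelined_task`, `sync_task`) are hypothetical and only model the behaviour described above; they are not code from the PR.

```python
import os

FLAG = "SPARK_PIPELINED_UDF_ACTIVE"


def leaky_pipelined_task():
    # Models the flagged behaviour: the flag is set and never cleared.
    os.environ[FLAG] = "1"


def scoped_pipelined_task(observed):
    # Hypothetical fix: restore the previous value when the task ends.
    previous = os.environ.get(FLAG)
    os.environ[FLAG] = "1"
    try:
        # What a UDF would observe while the pipelined task is running.
        observed.append(os.environ.get(FLAG))
    finally:
        if previous is None:
            os.environ.pop(FLAG, None)
        else:
            os.environ[FLAG] = previous


def sync_task():
    # A later non-pipelined task running in the same (reused) process.
    return os.environ.get(FLAG, "not_set")


# Leaky variant: the sync task still sees the flag from the earlier task.
os.environ.pop(FLAG, None)
leaky_pipelined_task()
stale = sync_task()

# Scoped variant: the flag is visible during the task and gone afterwards.
os.environ.pop(FLAG, None)
seen = []
scoped_pipelined_task(seen)
clean = sync_task()
```

With the leaky variant, `stale` is `"1"` even though the second task never entered the pipelined path, which is why the assertion in the test can pass for the wrong reason.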
##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -359,10 +378,115 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       }
 
       // Return an iterator that read lines from the process's stdout
-      val dataIn = new DataInputStream(new BufferedInputStream(
-        new ReaderInputStream(worker, writer, handle,
-          faultHandlerEnabled, idleTimeoutSeconds, killOnIdleTimeout, context),
-        bufferSize))
+      val dataIn = if (usePipelined) {

Review Comment:
   nit: this `if (usePipelined) { ... }` block is ~100 lines inline. Extracting it into a private `createPipelinedDataIn(...)` would help readability.


##########
python/pyspark/worker.py:
##########
@@ -3588,12 +3588,93 @@ def process():
             if hasattr(out_iter, "close"):
                 out_iter.close()
 
+    def pipelined_process():
+        """
+        Pipelined variant of process() that pre-fetches input batches in a background
+        reader thread while the main thread computes the UDF and writes output.
+        This allows input deserialization to overlap with UDF computation.
+        """
+        # Mark that pipelined mode is active so UDFs can verify the code path.
+        os.environ["SPARK_PIPELINED_UDF_ACTIVE"] = "1"

Review Comment:
   The env var leaks across tasks under `spark.python.worker.reuse=true`: it is set but never cleared. A worker first used for a pipelined task keeps `=1` for subsequent tasks, defeating the `test_pipelined_mode_is_active` check. Clear it in a `finally` block, or use a Python module-level flag instead of `os.environ`.


##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -121,6 +121,17 @@ private[spark] object PythonEvalType {
 
 private[spark] object BasePythonRunner extends Logging {
 
+  /**
+   * Shared thread pool for pipelined writer tasks. Using a cached thread pool ensures that
+   * writer threads are reused across tasks, which keeps JIT-compiled code, branch prediction
+   * history, and CPU caches warm.
+   * Bounded by executor cores since each task uses at most one writer thread.
+   */
+  private[python] lazy val pipelinedWriterThreadPool = {
+    val maxThreads = SparkEnv.get.conf.get(EXECUTOR_CORES)

Review Comment:
   `EXECUTOR_CORES` defaults to `1` in `local[*]` mode, so the cached pool queues all writers serially behind one thread. Would that be a problem?


##########
python/pyspark/worker.py:
##########
@@ -3588,12 +3588,93 @@ def process():
             if hasattr(out_iter, "close"):
                 out_iter.close()
 
+    def pipelined_process():
+        """
+        Pipelined variant of process() that pre-fetches input batches in a background
+        reader thread while the main thread computes the UDF and writes output.
+        This allows input deserialization to overlap with UDF computation.
+        """
+        # Mark that pipelined mode is active so UDFs can verify the code path.
+        os.environ["SPARK_PIPELINED_UDF_ACTIVE"] = "1"
+        import queue
+        import threading
+
+        queue_depth = int(os.environ.get("SPARK_PIPELINED_UDF_QUEUE_DEPTH", "2"))
+        _SENTINEL = object()
+        input_queue = queue.Queue(maxsize=queue_depth)
+        reader_error = [None]
+        stop_event = threading.Event()
+
+        def _reader_thread():
+            try:
+                for batch in deserializer.load_stream(infile):
+                    # Some serializers (e.g., ArrowStreamGroupSerializer,
+                    # ArrowStreamAggPandasUDFSerializer) yield lazy iterators
+                    # that still read from infile. Materialize them here so the
+                    # main thread can consume them without touching infile.
+                    if hasattr(batch, "__next__"):
+                        batch = list(batch)

Review Comment:
   Eagerly materializing lazy iterators (grouped/aggregate UDFs) and queueing up to `queue_depth` of them risks OOM on large groups where sync mode wouldn't. I wonder why this did not show up in the ASV peakmem benchmark.

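A rough model of why memory can grow, assuming a bounded `queue.Queue` fed by a reader thread that materializes each group before enqueueing it. The function `peak_resident_groups` and the tiny 10-element "groups" are hypothetical stand-ins, not code from the PR; the sketch only counts how many materialized groups are alive at the same time.

```python
import queue
import threading
import time


def peak_resident_groups(num_groups, queue_depth):
    """Return the largest number of materialized groups alive at once."""
    q = queue.Queue(maxsize=queue_depth)
    sentinel = object()
    lock = threading.Lock()
    live = 0
    peak = 0

    def reader():
        nonlocal live, peak
        for _ in range(num_groups):
            batch = list(range(10))  # stands in for list(lazy_group_iterator)
            with lock:
                live += 1
                peak = max(peak, live)
            q.put(batch)  # blocks once queue_depth groups are waiting
        q.put(sentinel)

    thread = threading.Thread(target=reader, daemon=True)
    thread.start()
    while True:
        item = q.get()
        if item is sentinel:
            break
        time.sleep(0.01)  # a slow UDF lets the reader run ahead
        with lock:
            live -= 1  # the group is released after the UDF consumed it
    thread.join()
    return peak


peak = peak_resident_groups(num_groups=8, queue_depth=2)
```

In this model the peak is bounded by `queue_depth + 2` groups (the queue contents, one group held by the consumer, and one held by the reader while it waits to enqueue), whereas a lazy sync path holds one group at a time. With large groups that multiplier is what could turn into an OOM.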
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -136,17 +136,24 @@ def __init__(self, write_start_stream: bool = False) -> None:
 
     def dump_stream(self, iterator: Iterable["pa.RecordBatch"], stream: IO[bytes]) -> None:
         """Optionally prepend START_ARROW_STREAM, then write batches."""
+        import os

Review Comment:
   nit: `import os` at function scope is unusual. The rest of this file imports at module top level.


##########
python/pyspark/sql/tests/pandas/bench_pipelined_udf.py:
##########
@@ -0,0 +1,286 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Benchmark: Pipelined vs synchronous JVM-Python UDF data transfer.
+
+Compares end-to-end execution time of Python UDFs with
+spark.python.udf.pipelined.enabled = true vs false.
+
+Because spark.python.udf.pipelined.enabled is a SparkConf-level config (read at
+SparkContext startup), each benchmark scenario runs in a separate subprocess with
+its own SparkSession to ensure the config takes effect.
+
+Note: In local[1] mode (single core), pipelined mode may show overhead because
+the writer thread and selector thread compete for the same CPU. The benefit of
+pipeline parallelism is expected on multi-core executors where serialization can
+overlap with output reading.
+
+Usage:
+    cd $SPARK_HOME
+    # Build Spark first (needed for PySpark to find JVM jars):
+    #   build/sbt -Phive package
+    #   cd python && zip -r lib/pyspark.zip pyspark && cd ..
+    python python/pyspark/sql/tests/pandas/bench_pipelined_udf.py \
+        [--rows N] [--iterations N] [--partitions N] [--sleep-ms N]
+"""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+
+
+SPARK_HOME = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../..")
+PIPELINED_CONF = "spark.python.udf.pipelined.enabled"
+QUEUE_DEPTH_CONF = "spark.python.udf.pipelined.queueDepth"
+
+
+# ---- Subprocess worker script template ----
+# Each benchmark scenario is run in a fresh Python process to get a fresh SparkContext.
+WORKER_TEMPLATE = """
+import os, sys, time, json
+sys.path.insert(0, "{spark_home}")
+
+import pandas as pd
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import pandas_udf, col
+from pyspark.sql.types import LongType
+
+spark = (
+    SparkSession.builder.master("{master}")
+    .appName("PipelinedUDFBench")
+    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
+    .config("spark.python.worker.reuse", "true")
+    .config("spark.ui.enabled", "false")
+    .config("spark.sql.shuffle.partitions", "1")
+    .config("{pipelined_conf}", "{pipelined}")
+    .config("{queue_depth_conf}", "{queue_depth}")
+    .getOrCreate()
+)
+
+{udf_code}
+
+df = {make_df_code}
+
+# Warmup
+for _ in range({warmup}):
+    df.write.format("noop").mode("overwrite").save()
+
+# Timed runs
+times = []
+for _ in range({iterations}):
+    start = time.perf_counter()
+    df.write.format("noop").mode("overwrite").save()
+    elapsed = time.perf_counter() - start
+    times.append(elapsed)
+
+# Output results as JSON to stdout
+print("BENCH_RESULT:" + json.dumps(times))
+spark.stop()
+"""
+
+
+def run_subprocess(pipelined, udf_code, make_df_code, args):
+    """Run a benchmark in a fresh subprocess, return list of timing results."""
+    script = WORKER_TEMPLATE.format(
+        spark_home=os.path.abspath(SPARK_HOME),
+        master=args.master,
+        pipelined_conf=PIPELINED_CONF,
+        pipelined="true" if pipelined else "false",
+        queue_depth_conf=QUEUE_DEPTH_CONF,
+        queue_depth=args.queue_depth,
+        udf_code=udf_code,
+        make_df_code=make_df_code,
+        warmup=args.warmup,
+        iterations=args.iterations,
+    )
+    env = os.environ.copy()
+    env["SPARK_HOME"] = os.path.abspath(SPARK_HOME)
+    py4j_zip = os.path.join(os.path.abspath(SPARK_HOME), "python/lib/py4j-0.10.9.9-src.zip")
+    pyspark_path = os.path.join(os.path.abspath(SPARK_HOME), "python")
+    env["PYTHONPATH"] = f"{pyspark_path}:{py4j_zip}:" + env.get("PYTHONPATH", "")
+
+    result = subprocess.run(
+        [sys.executable, "-c", script], capture_output=True, text=True, env=env, timeout=600
+    )
+
+    for line in result.stdout.splitlines():
+        if line.startswith("BENCH_RESULT:"):
+            return json.loads(line[len("BENCH_RESULT:") :])
+
+    print(" ERROR: no BENCH_RESULT in output")
+    print(" STDERR (last 500 chars):", result.stderr[-500:] if result.stderr else "<empty>")
+    return None
+
+
+def print_stats(label, times):
+    if not times:
+        print(f" {label:40s} FAILED")
+        return 0.0
+    avg = sum(times) / len(times)
+    mn = min(times)
+    mx = max(times)
+    print(
+        f" {label:40s} "
+        f"avg = {avg * 1000:8.1f} ms "
+        f"min = {mn * 1000:8.1f} ms "
+        f"max = {mx * 1000:8.1f} ms "
+        f"({len(times)} iters)"
+    )
+    return avg
+
+
+def run_benchmark(label, udf_code, make_df_code, args):
+    """Run sync and pipelined in separate subprocesses, print comparison."""
+    print(f" [{label}]")
+
+    sync_times = run_subprocess(False, udf_code, make_df_code, args)
+    sync_avg = print_stats("sync (pipelined=false)", sync_times)
+
+    pipe_times = run_subprocess(True, udf_code, make_df_code, args)
+    pipe_avg = print_stats("pipelined (pipelined=true)", pipe_times)
+
+    if pipe_avg > 0 and sync_avg > 0:
+        speedup = sync_avg / pipe_avg
+        diff_ms = (sync_avg - pipe_avg) * 1000
+        marker = "faster" if speedup > 1.0 else "slower"
+        print(f" --> pipelined is {speedup:.2f}x {marker} ({diff_ms:+.1f} ms)")
+    print()
+    return sync_avg, pipe_avg
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Benchmark pipelined vs synchronous Python UDF data transfer"
+    )
+    parser.add_argument("--rows", type=int, default=1_000_000,
+                        help="Rows for standard benchmarks (default: 1000000)")
+    parser.add_argument("--large-rows", type=int, default=5_000_000,
+                        help="Rows for large data benchmark (default: 5000000)")
+    parser.add_argument("--iterations", type=int, default=5,
+                        help="Timed iterations per scenario (default: 5)")
+    parser.add_argument("--warmup", type=int, default=2,
+                        help="Warmup iterations (default: 2)")
+    parser.add_argument("--partitions", type=int, default=1,
+                        help="Number of partitions (default: 1)")
+    parser.add_argument("--sleep-ms", type=float, default=10.0,
+                        help="Sleep time in ms per batch for heavy UDF (default: 10.0)")
+    parser.add_argument("--queue-depth", type=int, default=2,
+                        help="Pipelined queue depth (default: 2)")
+    parser.add_argument("--master", type=str, default="local[1]",

Review Comment:
   nit: the default `--master local[1]` does not match the `local[*]` numbers reported in the PR description. Either change the default or update the description to clarify how the table was produced.

##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -136,17 +136,24 @@ def __init__(self, write_start_stream: bool = False) -> None:
 
     def dump_stream(self, iterator: Iterable["pa.RecordBatch"], stream: IO[bytes]) -> None:
         """Optionally prepend START_ARROW_STREAM, then write batches."""
+        import os
+
         iterator = iter(iterator)
         if self._write_start_stream:
             iterator = self._write_stream_start(iterator, stream)
 
         import pyarrow as pa
 
+        pipelined = os.environ.get("SPARK_PIPELINED_UDF") == "1"
         writer = None
         try:
             for batch in iterator:
                 if writer is None:
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                 writer.write_batch(batch)
+                # In pipelined mode, flush after each batch so the JVM can read output
+                # while still sending input, rather than buffering all output.
+                if pipelined:
+                    stream.flush()

Review Comment:
   Gating per-batch flush on a global env var means *any* sync-mode process where `SPARK_PIPELINED_UDF=1` happens to be set pays a syscall per batch. Consider passing as a serializer ctor arg so this can't accidentally affect sync mode.
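
One possible shape for the constructor-argument suggestion, as a sketch only. `BatchStreamSerializer`, its `flush_per_batch` parameter, and the `CountingStream` helper are hypothetical names for illustration; the sketch writes raw bytes instead of Arrow batches so that it runs without pyarrow.

```python
import io


class CountingStream(io.BytesIO):
    """In-memory stream that counts flush() calls (illustrative helper)."""

    def __init__(self):
        super().__init__()
        self.flush_calls = 0

    def flush(self):
        self.flush_calls += 1
        super().flush()


class BatchStreamSerializer:
    """Stand-in for the Arrow stream serializer.

    `flush_per_batch` is a hypothetical constructor argument that replaces
    the process-wide environment variable lookup.
    """

    def __init__(self, flush_per_batch=False):
        self._flush_per_batch = flush_per_batch

    def dump_stream(self, iterator, stream):
        for batch in iterator:
            stream.write(batch)
            # Only the caller that opted in pays the per-batch flush.
            if self._flush_per_batch:
                stream.flush()


batches = [b"a", b"b", b"c"]

sync_stream = CountingStream()
BatchStreamSerializer().dump_stream(batches, sync_stream)

pipelined_stream = CountingStream()
BatchStreamSerializer(flush_per_batch=True).dump_stream(batches, pipelined_stream)
```

Because the flag lives on the serializer instance, a sync-mode serializer never flushes per batch regardless of what the process environment contains, and the pipelined path opts in explicitly where the serializer is constructed.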
