This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9c4eb99  [SPARK-27870][SQL][PYSPARK] Flush batch timely for pandas UDF (for improving pandas UDFs pipeline)
9c4eb99 is described below

commit 9c4eb99c52803f2488ac3787672aa8d3e4d1544e
Author: WeichenXu <weichen...@databricks.com>
AuthorDate: Fri Jun 7 14:02:43 2019 -0700

[SPARK-27870][SQL][PYSPARK] Flush batch timely for pandas UDF (for improving pandas UDFs pipeline)

## What changes were proposed in this pull request?

Flush batches timely for pandas UDFs. This improves performance when multiple pandas UDF plans are pipelined: when each batch is flushed promptly, downstream pandas UDFs start receiving input as soon as possible, and the pipeline hides the downstream UDFs' computation time.

For example, when the first UDF starts computing on batch-3, the second pipelined UDF can start computing on batch-2, and the third pipelined UDF can start computing on batch-1. If we do not flush each batch in time, the downstream UDFs' pipeline lags behind too much, which may increase the total processing time.

I add a flush in two places:
* Where the JVM process feeds data into the Python worker: on the JVM side, flush after writing each batch.
* Where the JVM process reads data back from the Python worker's output: on the Python worker side, flush after writing each batch.

Without an explicit flush, the default buffer size on both sides is 65536 bytes. Especially in the ML case, where real-time prediction calls for a very small batch size, such a large buffer makes the downstream pandas UDF pipeline lag far behind.

### Note

* This only applies to pandas scalar UDFs.
* We do not flush after every batch: the minimum interval between two flushes is 0.1 seconds. This avoids flushing too frequently when the batch size is small. It works like:

```
last_flush_time = time.time()
for batch in iterator:
    writer.write_batch(batch)
    flush_time = time.time()
    if self.flush_timely and (flush_time - last_flush_time > 0.1):
        stream.flush()
        last_flush_time = flush_time
```

## How was this patch tested?

### Benchmark to make sure the flush does not cause a performance regression

#### Test code:

```
numRows = ...
batchSize = ...
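# Imports assumed for the benchmark snippets in this description (they are not part of
# the original snippets); an active SparkSession bound to the name `spark` is assumed
# as well, e.g. from a pyspark shell.
import time
from pyspark.sql.functions import col, pandas_udf, sum, PandasUDFType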
spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', str(batchSize))
df = spark.range(1, numRows + 1, numPartitions=1).select(col('id').alias('a'))

@pandas_udf("int", PandasUDFType.SCALAR)
def fp1(x):
    return x + 10

beg_time = time.time()
result = df.select(sum(fp1('a'))).head()
print("result: " + str(result[0]))
print("consume time: " + str(time.time() - beg_time))
```

#### Test Result:

params | Consume time (Before) | Consume time (After)
------------ | ----------------------- | ----------------------
numRows=100000000, batchSize=10000 | 23.43s | 24.64s
numRows=100000000, batchSize=1000 | 36.73s | 34.50s
numRows=10000000, batchSize=100 | 35.67s | 32.64s
numRows=1000000, batchSize=10 | 33.60s | 32.11s
numRows=100000, batchSize=1 | 33.36s | 31.82s

### Benchmark of pipelined pandas UDFs

#### Test code:

```
spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '1')
df = spark.range(1, 31, numPartitions=1).select(col('id').alias('a'))

@pandas_udf("int", PandasUDFType.SCALAR)
def fp1(x):
    print("run fp1")
    time.sleep(1)
    return x + 100

@pandas_udf("int", PandasUDFType.SCALAR)
def fp2(x, y):
    print("run fp2")
    time.sleep(1)
    return x + y

beg_time = time.time()
result = df.select(sum(fp2(fp1('a'), col('a')))).head()
print("result: " + str(result[0]))
print("consume time: " + str(time.time() - beg_time))
```

#### Test Result:

**Before**: consume time: 63.57s
**After**: consume time: 32.43s

**So the PR improves performance by letting downstream UDFs get pipelined early.**

Please review https://spark.apache.org/contributing.html before opening a pull request.

Closes #24734 from WeichenXu123/improve_pandas_udf_pipeline.

Lead-authored-by: WeichenXu <weichen...@databricks.com>
Co-authored-by: Xiangrui Meng <m...@databricks.com>
Signed-off-by: gatorsmile <gatorsm...@gmail.com>
---
 python/pyspark/serializers.py                      | 18 ++++++++++++++++--
 python/pyspark/testing/utils.py                    |  3 +++
 python/pyspark/tests/test_serializers.py           | 10 ++++++++++
 .../sql/execution/python/ArrowPythonRunner.scala   | 19 ++++++++++++-------
 4 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 516ee7e..1b17e60 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -58,6 +58,7 @@ import types
 import collections
 import zlib
 import itertools
+import time
 
 if sys.version < '3':
     import cPickle as pickle
@@ -230,11 +231,19 @@ class ArrowStreamSerializer(Serializer):
     def dump_stream(self, iterator, stream):
         import pyarrow as pa
         writer = None
+        last_flush_time = time.time()
         try:
             for batch in iterator:
                 if writer is None:
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                 writer.write_batch(batch)
+                current_time = time.time()
+                # If it takes time to compute each input batch but per-batch data is very small,
+                # the data might stay in the buffer for long and downstream reader cannot read it.
+                # We want to flush timely in this case.
+                if current_time - last_flush_time > 0.1:
+                    stream.flush()
+                    last_flush_time = current_time
         finally:
             if writer is not None:
                 writer.close()
@@ -872,11 +881,16 @@ class ChunkedStream(object):
                 byte_pos = new_byte_pos
                 self.current_pos = 0
 
-    def close(self):
-        # if there is anything left in the buffer, write it out first
+    def flush(self):
         if self.current_pos > 0:
             write_int(self.current_pos, self.wrapped)
             self.wrapped.write(self.buffer[:self.current_pos])
+            self.current_pos = 0
+        self.wrapped.flush()
+
+    def close(self):
+        # If there is anything left in the buffer, write it out first.
+        self.flush()
         # -1 length indicates to the receiving end that we're done.
         write_int(-1, self.wrapped)
         self.wrapped.close()
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 2b42b89..61c342b 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -99,6 +99,9 @@ class ByteArrayOutput(object):
     def write(self, b):
         self.buffer += b
 
+    def flush(self):
+        pass
+
     def close(self):
         pass
 
diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py
index bce9406..498076d 100644
--- a/python/pyspark/tests/test_serializers.py
+++ b/python/pyspark/tests/test_serializers.py
@@ -225,6 +225,16 @@ class SerializersTest(unittest.TestCase):
         # ends with a -1
         self.assertEqual(dest.buffer[-4:], write_int(-1))
 
+    def test_chunked_stream_flush(self):
+        wrapped = ByteArrayOutput()
+        stream = serializers.ChunkedStream(wrapped, 10)
+        stream.write(bytearray([0]))
+        self.assertEqual(len(wrapped.buffer), 0, "small write should be buffered")
+        stream.flush()
+        # Expect buffer size 4 bytes + buffer data 1 byte.
+        self.assertEqual(len(wrapped.buffer), 5, "flush should work")
+        stream.close()
+
 
 if __name__ == "__main__":
     from pyspark.tests.test_serializers import *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 3710218..ddb65a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -78,16 +78,21 @@ class ArrowPythonRunner(
           val arrowWriter = ArrowWriter.create(root)
           val writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
-
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
+          var lastFlushTime = System.currentTimeMillis()
+          inputIterator.foreach { batch =>
+            batch.foreach { row =>
+              arrowWriter.write(row)
             }
-
             arrowWriter.finish()
             writer.writeBatch()
+            val currentTime = System.currentTimeMillis()
+            // If it takes time to compute each input batch but per-batch data is very small,
+            // the data might stay in the buffer for long and downstream reader cannot read it.
+            // We want to flush timely in this case.
+            if (currentTime - lastFlushTime > 100) {
+              dataOut.flush()
+              lastFlushTime = currentTime
+            }
             arrowWriter.reset()
           }
           // end writes footer to the output stream and doesn't clean any resources.
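A note on the `ChunkedStream` change above: each `flush()` (and the final `close()`) writes a complete length-prefixed frame, which is why the new test expects 5 bytes for a 1-byte payload (a 4-byte length followed by the data). The sketch below is a minimal, hypothetical reader for that framing, added purely for illustration; `read_chunks` is not a pyspark API, the real consumer of this stream lives on the JVM side, and the sketch assumes `write_int` emits 4-byte big-endian integers as in `pyspark.serializers`.

```
import struct

def read_chunks(stream):
    # Yield the payload of each length-prefixed chunk that a ChunkedStream-style
    # writer flushed to `stream`. A length of -1 is the end-of-stream sentinel
    # written by close().
    while True:
        length = struct.unpack("!i", stream.read(4))[0]
        if length == -1:
            return
        yield stream.read(length)
```

Seen this way, giving `ChunkedStream` a real `flush()` is presumably what makes the timed `stream.flush()` in `ArrowStreamSerializer.dump_stream` effective when the underlying stream is chunked: without it, a small batch would stay buffered until close instead of reaching the reader as a complete frame.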