Repository: spark
Updated Branches:
  refs/heads/master 52738d4e0 -> 6d06ff6f7

[SPARK-17514] df.take(1) and df.limit(1).collect() should perform the same in Python

## What changes were proposed in this pull request?

In PySpark, `df.take(1)` runs a single-stage job which computes only one partition of the DataFrame, while `df.limit(1).collect()` computes all partitions and runs a two-stage job. This difference in performance is confusing.

The reason why `limit(1).collect()` is so much slower is that `collect()` internally maps to `df.rdd.<some-pyspark-conversions>.toLocalIterator`, which causes Spark SQL to build a query where a global limit appears in the middle of the plan; this, in turn, ends up being executed inefficiently because limits in the middle of plans are now implemented by repartitioning to a single task rather than by running a `take()` job on the driver (this was done in #7334, a patch which was a prerequisite to allowing partition-local limits to be pushed beneath unions, etc.).

In order to fix this performance problem I think that we should generalize the fix from SPARK-10731 / #8876 so that `DataFrame.collect()` also delegates to the Scala implementation and shares the same performance properties. This patch modifies `DataFrame.collect()` to first collect all results to the driver and then pass them to Python, allowing this query to be planned using Spark's `CollectLimit` optimizations.

## How was this patch tested?

Added a regression test in `sql/tests.py` which asserts that the expected number of jobs, stages, and tasks are run for both queries.

Author: Josh Rosen <joshro...@databricks.com>

Closes #15068 from JoshRosen/pyspark-collect-limit.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d06ff6f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d06ff6f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d06ff6f

Branch: refs/heads/master
Commit: 6d06ff6f7e2dd72ba8fe96cd875e83eda6ebb2a9
Parents: 52738d4
Author: Josh Rosen <joshro...@databricks.com>
Authored: Wed Sep 14 10:10:01 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Wed Sep 14 10:10:01 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                    |  5 +----
 python/pyspark/sql/tests.py                        | 18 ++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  8 ++++++--
 .../sql/execution/python/EvaluatePython.scala      | 13 +------------
 4 files changed, 26 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e5eac91..0f7d8fb 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -357,10 +357,7 @@ class DataFrame(object):
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
-        with SCCallSiteSync(self._sc) as css:
-            port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
-                self._jdf, num)
-            return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+        return self.limit(num).collect()
 
     @since(1.3)
     def foreach(self, f):
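
With `take(num)` now delegating to `limit(num).collect()`, both entry points share the
Scala-side planning and should run a single one-task job. A minimal sketch of how to
observe this from an interactive PySpark session (not part of the patch; it assumes the
shell's `spark` and `sc` handles):

    # Sketch only: limit(1).collect() should now run one job with one
    # stage and one task, exactly like take(1).
    df = spark.range(1, 1000, numPartitions=10)

    tracker = sc.statusTracker()
    sc.setJobGroup("collect_limit", description="")
    df.limit(1).collect()

    [job_id] = tracker.getJobIdsForGroup("collect_limit")
    stage_ids = tracker.getJobInfo(job_id).stageIds
    assert len(stage_ids) == 1
    assert tracker.getStageInfo(stage_ids[0]).numTasks == 1

The regression test added below wraps exactly this check in a helper and applies it to
both `df.take(1)` and `df.limit(1).collect()`.
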
http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 769e454..1be0b72 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1862,6 +1862,24 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
             ["1", "2", "2", "2"])
 
+    def test_limit_and_take(self):
+        df = self.spark.range(1, 1000, numPartitions=10)
+
+        def assert_runs_only_one_job_stage_and_task(job_group_name, f):
+            tracker = self.sc.statusTracker()
+            self.sc.setJobGroup(job_group_name, description="")
+            f()
+            jobs = tracker.getJobIdsForGroup(job_group_name)
+            self.assertEqual(1, len(jobs))
+            stages = tracker.getJobInfo(jobs[0]).stageIds
+            self.assertEqual(1, len(stages))
+            self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)
+
+        # Regression test for SPARK-10731: take should delegate to Scala implementation
+        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
+        # Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
+        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *


http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3b3cb82..9cfbdff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,7 +29,7 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
-import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
@@ -2567,8 +2567,12 @@ class Dataset[T] private[sql](
   }
 
   private[sql] def collectToPython(): Int = {
+    EvaluatePython.registerPicklers()
     withNewExecutionId {
-      PythonRDD.collectAndServe(javaToPython.rdd)
+      val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
+      val iter = new SerDeUtil.AutoBatchedPickler(
+        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+      PythonRDD.serveIterator(iter, "serve-DataFrame")
     }
   }
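
The rewritten `collectToPython` collects the executed plan's rows on the driver with
`executeCollect()`, converts them to Java objects via `EvaluatePython.toJava`, and serves
the auto-batched pickled rows to Python over a local socket. Because the whole query,
including any trailing `limit`, is now planned on the Scala side, a root-level limit can
use the `CollectLimit` strategy instead of repartitioning to a single task. A sketch of
how to see this from PySpark (plan output abbreviated and version-dependent):

    # Sketch: a limit at the root of the query is planned as CollectLimit,
    # which fetches rows with a take()-style job on the driver.
    df = spark.range(1, 1000, numPartitions=10)
    df.limit(1).explain()
    # == Physical Plan ==
    # CollectLimit 1
    # +- ...range scan over the 10 input partitions...
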
http://git-wip-us.apache.org/repos/asf/spark/blob/6d06ff6f/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index cf68ed4..724025b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -24,9 +24,8 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}
 
-import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -34,16 +33,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 object EvaluatePython {
-  def takeAndServe(df: DataFrame, n: Int): Int = {
-    registerPicklers()
-    df.withNewExecutionId {
-      val iter = new SerDeUtil.AutoBatchedPickler(
-        df.queryExecution.executedPlan.executeTake(n).iterator.map { row =>
-          EvaluatePython.toJava(row, df.schema)
-        })
-      PythonRDD.serveIterator(iter, s"serve-DataFrame")
-    }
-  }
 
   def needConversionInPython(dt: DataType): Boolean = dt match {
     case DateType | TimestampType => true


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org