Repository: spark
Updated Branches:
  refs/heads/master f53818d35 -> 9786ce66c
[SPARK-22239][SQL][PYTHON] Enable grouped aggregate pandas UDFs as window functions with unbounded window frames

## What changes were proposed in this pull request?

This PR enables using grouped aggregate pandas UDFs as window functions. The semantics are the same as using a SQL aggregate function as a window function.

```
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql import Window
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)
... def mean_udf(v):
...     return v.mean()
>>> w = Window.partitionBy('id')
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.5|
|  1| 2.0|   1.5|
|  2| 3.0|   6.0|
|  2| 5.0|   6.0|
|  2|10.0|   6.0|
+---+----+------+
```

The scope of this PR is somewhat limited:
(1) It only supports unbounded window frames, so the window essentially acts as a group by.
(2) It only supports aggregation functions, not "transform"-like window functions (n -> n mapping).

Both of these are left as future work. In particular, lifting (1) needs careful thought about how to pass rolling window data to Python efficiently; (2) is a bit easier but requires more changes, so I think it is better to leave it as a separate PR.

(A self-contained version of this example, including the unsupported bounded-frame case, is sketched after the diff below.)

## How was this patch tested?

WindowPandasUDFTests

Author: Li Jin <ice.xell...@gmail.com>

Closes #21082 from icexelloss/SPARK-22239-window-udf.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9786ce66
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9786ce66
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9786ce66

Branch: refs/heads/master
Commit: 9786ce66c52d41b1d58ddedb3a984f561fd09ff3
Parents: f53818d
Author: Li Jin <ice.xell...@gmail.com>
Authored: Wed Jun 13 09:10:52 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Wed Jun 13 09:10:52 2018 +0800

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonRunner.scala  |   2 +
 python/pyspark/rdd.py                           |   1 +
 python/pyspark/sql/functions.py                 |  34 ++-
 python/pyspark/sql/tests.py                     | 238 +++++++++++++++++++
 python/pyspark/worker.py                        |  20 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  11 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  12 +-
 .../sql/catalyst/expressions/PythonUDF.scala    |   6 +-
 .../expressions/windowExpressions.scala         |  33 ++-
 .../sql/catalyst/optimizer/Optimizer.scala      |   7 +-
 .../spark/sql/catalyst/planning/patterns.scala  |  42 +++-
 .../spark/sql/execution/SparkPlanner.scala      |   1 +
 .../spark/sql/execution/SparkStrategies.scala   |  20 +-
 .../execution/python/ExtractPythonUDFs.scala    |   2 +-
 .../execution/python/WindowInPandasExec.scala   | 173 ++++++++++++++
 15 files changed, 580 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 41eac10..ebabedf 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -40,6 +40,7 @@ private[spark] object PythonEvalType {
   val SQL_SCALAR_PANDAS_UDF = 200
val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 + val SQL_WINDOW_AGG_PANDAS_UDF = 203 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -47,6 +48,7 @@ private[spark] object PythonEvalType { case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" + case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" } } http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/python/pyspark/rdd.py ---------------------------------------------------------------------- diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 14d9128..7e7e582 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -74,6 +74,7 @@ class PythonEvalType(object): SQL_SCALAR_PANDAS_UDF = 200 SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 + SQL_WINDOW_AGG_PANDAS_UDF = 203 def portable_hash(x): http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1cdbb8a..a5e3384 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2616,10 +2616,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as - output types. + :class:`MapType` and :class:`StructType` are currently not supported as output types. - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and + :class:`pyspark.sql.Window` + + This example shows using grouped aggregated UDFs with groupby: >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( @@ -2636,7 +2638,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + This example shows using grouped aggregated UDFs as window functions. Note that only + unbounded window frame is supported at the moment: + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> w = Window.partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.5| + | 1| 2.0| 1.5| + | 2| 3.0| 6.0| + | 2| 5.0| 6.0| + | 2|10.0| 6.0| + +---+----+------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. 
Due to optimization, duplicate invocations may be eliminated or the function may even be invoked http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4a3941d..2d7a4f6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5454,6 +5454,15 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG) + result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + def test_invalid_args(self): from pyspark.sql.functions import mean @@ -5479,6 +5488,235 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class WindowPandasUDFTests(ReusedSQLTestCase): + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + return udf(lambda v: v + 1, 'double') + + @property + def pandas_scalar_time_two(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + return pandas_udf(lambda v: v * 2, 'double') + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_max_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max(v): + return v.max() + return max + + @property + def pandas_agg_min_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def min(v): + return v.min() + return min + + @property + def unbounded_window(self): + return Window.partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + @property + def ordered_window(self): + return Window.partitionBy('id').orderBy('v') + + @property + def unpartitioned_window(self): + return Window.partitionBy() + + def test_simple(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max + + df = self.data + w = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_multiple_udfs(self): + from pyspark.sql.functions import max, min, mean + + df = self.data + w = 
self.unbounded_window + + result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_replace_existing(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v', mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) + expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + plus_one = self.python_plus_one + time_two = self.pandas_scalar_time_two + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn( + 'v2', + plus_one(mean_udf(plus_one(df['v'])).over(w))) + expected1 = df.withColumn( + 'v2', + plus_one(mean(plus_one(df['v'])).over(w))) + + result2 = df.withColumn( + 'v2', + time_two(mean_udf(time_two(df['v'])).over(w))) + expected2 = df.withColumn( + 'v2', + time_two(mean(time_two(df['v'])).over(w))) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_without_partitionBy(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unpartitioned_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v2', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_sql_and_udf(self): + from pyspark.sql.functions import max, min, rank, col + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) + expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) + + # Test mixing sql window function and window udf in the same expression + result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) + expected2 = expected1 + + # Test chaining sql aggregate function and udf + result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') + expected3 = expected1 + + # Test mixing sql window function and udf + result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) 
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG) + result1 = df.withColumn('v2', array_udf(df['v']).over(w)) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import mean, pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*not supported within a window function'): + foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) + df.withColumn('v2', foo_udf(df['v']).over(w)) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*Only unbounded window frame is supported.*'): + df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/python/pyspark/worker.py ---------------------------------------------------------------------- diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a30d6bf..38fe2ef 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -128,6 +128,21 @@ def wrap_grouped_agg_pandas_udf(f, return_type): return lambda *a: (wrapped(*a), arrow_return_type) +def wrap_window_agg_pandas_udf(f, return_type): + # This is similar to grouped_agg_pandas_udf, the only difference + # is that window_agg_pandas_udf needs to repeat the return value + # to match window length, where grouped_agg_pandas_udf just returns + # the scalar value. 
+ arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series([result]).repeat(len(series[0])) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -151,6 +166,8 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(func, return_type) else: @@ -195,7 +212,8 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f9947d1..6e3107f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1739,9 +1739,10 @@ class Analyzer( * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions * it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for * all regular expressions. - * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s. - * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts - * it into the plan tree. + * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s + * and [[WindowFunctionType]]s. + * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a + * [[Window]] operator and inserts it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { private def hasWindowFunction(exprs: Seq[Expression]): Boolean = @@ -1901,7 +1902,7 @@ class Analyzer( s"Please file a bug report with this error message, stack trace, and the query.") } else { val spec = distinctWindowSpec.head - (spec.partitionSpec, spec.orderSpec) + (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr)) } }.toSeq @@ -1909,7 +1910,7 @@ class Analyzer( // setting this to the child of the next Window operator. 
val windowOps = groupedWindowExpressions.foldLeft(child) { - case (last, ((partitionSpec, orderSpec), windowExpressions)) => + case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => Window(windowExpressions, partitionSpec, orderSpec, last) } http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a..af256b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") + case _ @ WindowExpression(_: PythonUDF, + WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) + if !frame.isUnbounded => + failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") + case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window - // function. + // function or a Pandas window UDF. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => w + case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } @@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper { case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { - expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } def checkValidAggregateExpression(expr: Expression): Unit = expr match { http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index efd664d..6530b17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -34,10 +34,14 @@ object PythonUDF { e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) } - def isGroupAggPandasUDF(e: Expression): Boolean = { + def isGroupedAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } + + // This is currently same as GroupedAggPandasUDF, but we might support new types in the future, + // e.g, N -> N transform. 
+ def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 9fe2fb2..f957aaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ /** @@ -298,6 +298,37 @@ trait WindowFunction extends Expression { } /** + * Case objects that describe whether a window function is a SQL window function or a Python + * user-defined window function. + */ +sealed trait WindowFunctionType + +object WindowFunctionType { + case object SQL extends WindowFunctionType + case object Python extends WindowFunctionType + + def functionType(windowExpression: NamedExpression): WindowFunctionType = { + val t = windowExpression.collectFirst { + case _: WindowFunction | _: AggregateFunction => SQL + case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python + } + + // Normally a window expression would either have a SQL window function, a SQL + // aggregate function or a python window UDF. However, sometimes the optimizer will replace + // the window function if the value of the window function can be predetermined. + // For example, for query: + // + // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a) + // + // The window function will be replaced by expression literal(0) + // To handle this case, if a window expression doesn't have a regular window function, we + // consider its type to be SQL as literal(0) is also a SQL expression. + t.getOrElse(SQL) + } +} + + +/** * An offset window function is a window function that returns the value of the input column offset * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with * offset -2, will get the value of x 2 rows back in the partition. http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bfa61116..aa992de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -621,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. 
* - If the partition specs and order specs are the same and the window expression are - * independent, collapse into the parent. + * independent and are of the same window function type, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) - if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty && + // This assumes Window contains the same type of window expressions. This is ensured + // by ExtractWindowFunctions. + WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) => w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 626f905..84be677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.planning import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -215,7 +216,7 @@ object PhysicalAggregation { case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg case udf: PythonUDF - if PythonUDF.isGroupAggPandasUDF(udf) && + if PythonUDF.isGroupedAggPandasUDF(udf) && !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -245,7 +246,7 @@ object PhysicalAggregation { equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression - case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => @@ -268,3 +269,40 @@ object PhysicalAggregation { case _ => None } } + +/** + * An extractor used when planning physical execution of a window. This extractor outputs + * the window function type of the logical window. + * + * The input logical window must contain same type of window functions, which is ensured by + * the rule ExtractWindowExpressions in the analyzer. + */ +object PhysicalWindow { + // windowFunctionType, windowExpression, partitionSpec, orderSpec, child + private type ReturnType = + (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) => + + // The window expression should not be empty here, otherwise it's a bug. 
+ if (windowExpressions.isEmpty) { + throw new AnalysisException(s"Window expression is empty in $expr") + } + + val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType) + .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) => + if (t1 != t2) { + // We shouldn't have different window function type here, otherwise it's a bug. + throw new AnalysisException( + s"Found different window function type in $windowExpressions") + } else { + t1 + } + } + + Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child)) + + case _ => None + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 7404887..75f5ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -41,6 +41,7 @@ class SparkPlanner( DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: + Window :: JoinSelection :: InMemoryScans :: BasicOperators :: Nil) http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index be34387..d6951ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -327,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) { throw new AnalysisException( "Streaming aggregation doesn't support group aggregate pandas UDF") } @@ -428,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object Window extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalWindow( + WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) => + execution.window.WindowExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case PhysicalWindow( + WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) => + execution.python.WindowInPandasExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { @@ -548,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, 
child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data, _) => http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 9d56f48..1e09610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupAggPandasUDF(e) || + PythonUDF.isGroupedAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } http://git-wip-us.apache.org/repos/asf/spark/blob/9786ce66/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala new file mode 100644 index 0000000..c76832a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +case class WindowInPandasExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. 
+ */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + // Extract window expressions and window functions + val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) + + val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val context = TaskContext.get() + + val grouped = if (partitionSpec.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, partitionSpec, child.output) + } + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + val inputProj = UnsafeProjection.create(allInputs, child.output) + val pythonInput = grouped.map { case (_, rows) => + rows.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + inputProj(row) + } + } + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, windowInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + val resultProj = createResultProjection(expressions) + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
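For readers who want to try the feature without digging through the diff above, here is a self-contained usage sketch. It assumes pandas and pyarrow are installed alongside PySpark; it simply mirrors the doctest added to functions.py and the bounded-frame error case exercised in WindowPandasUDFTests, and the SparkSession/app name below are illustrative, not part of the commit.

```
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.utils import AnalysisException

# Created here only so the sketch runs outside a pyspark shell.
spark = SparkSession.builder.appName("window-pandas-udf-sketch").getOrCreate()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    return v.mean()

# Unbounded frame: supported. Every row gets the aggregate of its partition.
w = Window.partitionBy("id") \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn("mean_v", mean_udf(df["v"]).over(w)).show()

# Bounded frame (the default frame of an ordered window): rejected by the
# CheckAnalysis rule added in this patch.
ow = Window.partitionBy("id").orderBy("v")
try:
    df.withColumn("mean_v", mean_udf(df["v"]).over(ow))
except AnalysisException as e:
    print(e)  # Only unbounded window frame is supported with Pandas UDFs.
```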
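And a tiny plain-pandas sketch (no Spark required) of the behavior that wrap_window_agg_pandas_udf in worker.py adds on top of the grouped-aggregate wrapper: the scalar aggregate is repeated so one value lines up with every row of the unbounded window. apply_as_window_agg is a hypothetical stand-in for that wrapper, not the actual worker code.

```
import pandas as pd

def apply_as_window_agg(f, *series):
    # Evaluate the aggregate once per window/partition...
    result = f(*series)
    # ...then repeat the scalar so there is one value per input row,
    # mirroring pd.Series([result]).repeat(len(series[0])) in worker.py.
    return pd.Series([result]).repeat(len(series[0]))

v = pd.Series([3.0, 5.0, 10.0])   # the rows of the id == 2 partition above
print(apply_as_window_agg(lambda s: s.mean(), v).tolist())  # [6.0, 6.0, 6.0]
```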