This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 90f5b5b Add Window Functions for use with function builder (#808)
90f5b5b is described below
commit 90f5b5b355d79b56ae3607b7b0cdeb09b67e5121
Author: Tim Saucer <[email protected]>
AuthorDate: Mon Sep 2 10:51:00 2024 -0400
Add Window Functions for use with function builder (#808)
* Add window function as template for others and function builder
* Adding docstrings
* Change last_value to use function builder instead of explicitly passing
values
* Allow any value for lead function default value and add unit test
* Add lead window function and unit tests
* Temporarily commenting out deprecated functions in documenation so
builder will pass
* Expose row_number window function
* Add rank window function
* Add percent rank and dense rank
* Add cume_dist
* Add ntile window function
* Add comment to update when upstream merges
* Window frame required calling inner value
* Add unit test for avg as window function
* Working on documentation for window functions
* Add pyo build config file to git ignore since this is user specific
* Add examples to docstring
* Optionally add window function parameters during function call
* Update sort and order_by to apply automatic ordering if any other
expression is given
* Update unit tests to be cleaner and use default sort on expressions
* Ignore vscode folder specific settings
* Window frames should only apply to aggregate functions used as window
functions. Also pass in scalar pyarrow values so we can set a range other than
a uint
* Remove deprecated warning until we actually have a way to use all
functions without calling window()
* Built in window functions do not have any impact by setting
null_treatment so remove from user facing
* Update user documentation on how to pass parameters for different window
functions and what their impacts are
* Make first_value and last_value identical in the interface
---
.gitignore | 4 +
.../user-guide/common-operations/aggregations.rst | 2 +
.../user-guide/common-operations/windows.rst | 187 ++++++++--
python/datafusion/dataframe.py | 7 +-
python/datafusion/expr.py | 113 +++++-
python/datafusion/functions.py | 390 ++++++++++++++++++++-
python/datafusion/tests/test_dataframe.py | 182 +++++++---
python/datafusion/tests/test_functions.py | 1 +
src/dataframe.rs | 3 +-
src/expr.rs | 110 +++++-
src/expr/window.rs | 12 +-
src/functions.rs | 176 ++++++++--
12 files changed, 1059 insertions(+), 128 deletions(-)
diff --git a/.gitignore b/.gitignore
index 0030b90..aaeaaa5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@ target
/docs/temp
/docs/build
.DS_Store
+.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
@@ -31,3 +32,6 @@ apache-rat-*.jar
CHANGELOG.md.bak
docs/mdbook/book
+
+.pyo3_build_config
+
diff --git a/docs/source/user-guide/common-operations/aggregations.rst
b/docs/source/user-guide/common-operations/aggregations.rst
index b920212..7ad4022 100644
--- a/docs/source/user-guide/common-operations/aggregations.rst
+++ b/docs/source/user-guide/common-operations/aggregations.rst
@@ -15,6 +15,8 @@
.. specific language governing permissions and limitations
.. under the License.
+.. _aggregation:
+
Aggregation
============
diff --git a/docs/source/user-guide/common-operations/windows.rst
b/docs/source/user-guide/common-operations/windows.rst
index 5ef3c98..6091768 100644
--- a/docs/source/user-guide/common-operations/windows.rst
+++ b/docs/source/user-guide/common-operations/windows.rst
@@ -15,13 +15,16 @@
.. specific language governing permissions and limitations
.. under the License.
+.. _window_functions:
+
Window Functions
================
-In this section you will learn about window functions. A window function
utilizes values from one or multiple rows to
-produce a result for each individual row, unlike an aggregate function that
provides a single value for multiple rows.
+In this section you will learn about window functions. A window function
utilizes values from one or
+multiple rows to produce a result for each individual row, unlike an aggregate
function that
+provides a single value for multiple rows.
-The functionality of window functions in DataFusion is supported by the
dedicated :py:func:`~datafusion.functions.window` function.
+The window functions are available in the :py:mod:`~datafusion.functions`
module.
We'll use the pokemon dataset (from Ritchie Vink) in the following examples.
@@ -40,20 +43,25 @@ We'll use the pokemon dataset (from Ritchie Vink) in the
following examples.
ctx = SessionContext()
df = ctx.read_csv("pokemon.csv")
-Here is an example that shows how to compare each pokemons’s attack power with
the average attack power in its ``"Type 1"``
+Here is an example that shows how you can compare each pokemon's speed to the
speed of the
+previous row in the DataFrame.
.. ipython:: python
df.select(
col('"Name"'),
- col('"Attack"'),
- f.alias(
- f.window("avg", [col('"Attack"')], partition_by=[col('"Type 1"')]),
- "Average Attack",
- )
+ col('"Speed"'),
+ f.lag(col('"Speed"')).alias("Previous Speed")
)
-You can also control the order in which rows are processed by window functions
by providing
+Setting Parameters
+------------------
+
+
+Ordering
+^^^^^^^^
+
+You can control the order in which rows are processed by window functions by
providing
a list of ``order_by`` functions for the ``order_by`` parameter.
.. ipython:: python
@@ -61,33 +69,150 @@ a list of ``order_by`` functions for the ``order_by``
parameter.
df.select(
col('"Name"'),
col('"Attack"'),
- f.alias(
- f.window(
- "rank",
- [],
- partition_by=[col('"Type 1"')],
- order_by=[f.order_by(col('"Attack"'))],
- ),
- "rank",
- ),
+ col('"Type 1"'),
+ f.rank(
+ partition_by=[col('"Type 1"')],
+ order_by=[col('"Attack"').sort(ascending=True)],
+ ).alias("rank"),
+ ).sort(col('"Type 1"'), col('"Attack"'))
+
+Partitions
+^^^^^^^^^^
+
+A window function can take a list of ``partition_by`` columns similar to an
+:ref:`Aggregation Function<aggregation>`. This will cause the window values to
be evaluated
+independently for each of the partitions. In the example above, we found the
rank of each
+Pokemon per ``Type 1`` partitions. We can see the first couple of each
partition if we do
+the following:
+
+.. ipython:: python
+
+ df.select(
+ col('"Name"'),
+ col('"Attack"'),
+ col('"Type 1"'),
+ f.rank(
+ partition_by=[col('"Type 1"')],
+ order_by=[col('"Attack"').sort(ascending=True)],
+ ).alias("rank"),
+ ).filter(col("rank") < lit(3)).sort(col('"Type 1"'), col("rank"))
+
+Window Frame
+^^^^^^^^^^^^
+
+When using aggregate functions, the Window Frame defines the rows over
which it operates.
+If you do not specify a Window Frame, the frame will be set depending on the
following
+criteria.
+
+* If an ``order_by`` clause is set, the default window frame is defined as the
rows between
+ unbounded preceding and the current row.
+* If an ``order_by`` is not set, the default frame is defined as the rows
between unbounded
+ and unbounded following (the entire partition).
+
+Window Frames are defined by three parameters: unit type, starting bound, and
ending bound.
+
+The unit types available are:
+
+* Rows: The starting and ending boundaries are defined by the number of rows
relative to the
+ current row.
+* Range: When using Range, the ``order_by`` clause must have exactly one term.
The boundaries
+ are defined by how close the rows are to the value of the expression in the
``order_by``
+ parameter.
+* Groups: A "group" is the set of all rows that have equivalent values for all
terms in the
+ ``order_by`` clause.
+
+In this example we perform a "rolling average" of the speed of the current
Pokemon and the
+two preceding rows.
+
+.. ipython:: python
+
+ from datafusion.expr import WindowFrame
+
+ df.select(
+ col('"Name"'),
+ col('"Speed"'),
+ f.window("avg",
+ [col('"Speed"')],
+ order_by=[col('"Speed"')],
+ window_frame=WindowFrame("rows", 2, 0)
+ ).alias("Previous Speed")
+ )
+
+Null Treatment
+^^^^^^^^^^^^^^
+
+When using aggregate functions as window functions, it is often useful to
specify how null values
+should be treated. In order to do this you need to use the builder function.
In future releases
+we expect this to be simplified in the interface.
+
+One common usage for handling nulls is the case where you want to find the
last value up to the
+current row. In the following example we demonstrate how setting the null
treatment to ignore
+nulls will fill in with the value of the most recent non-null row. To do this,
we also will set
+the window frame so that we only process up to the current row.
+
+In this example, we filter down to one specific type of Pokemon that does have
some entries in
+its ``Type 2`` column that are null.
+
+.. ipython:: python
+
+ from datafusion.common import NullTreatment
+
+ df.filter(col('"Type 1"') == lit("Bug")).select(
+ '"Name"',
+ '"Type 2"',
+ f.window("last_value", [col('"Type 2"')])
+ .window_frame(WindowFrame("rows", None, 0))
+ .order_by(col('"Speed"'))
+ .null_treatment(NullTreatment.IGNORE_NULLS)
+ .build()
+ .alias("last_wo_null"),
+ f.window("last_value", [col('"Type 2"')])
+ .window_frame(WindowFrame("rows", None, 0))
+ .order_by(col('"Speed"'))
+ .null_treatment(NullTreatment.RESPECT_NULLS)
+ .build()
+ .alias("last_with_null")
+ )
+
+Aggregate Functions
+-------------------
+
+You can use any :ref:`Aggregation Function<aggregation>` as a window function.
Currently
+aggregate functions must use the deprecated
+:py:func:`datafusion.functions.window` API but this should be resolved in
+DataFusion 42.0 (`Issue Link
<https://github.com/apache/datafusion-python/issues/833>`_). Here
is an example that shows how to compare each pokemon's attack power with the
average attack
+power in its ``"Type 1"`` using the :py:func:`datafusion.functions.avg`
function.
+
+.. ipython:: python
+ :okwarning:
+
+ df.select(
+ col('"Name"'),
+ col('"Attack"'),
+ col('"Type 1"'),
+ f.window("avg", [col('"Attack"')])
+ .partition_by(col('"Type 1"'))
+ .build()
+ .alias("Average Attack"),
)
+Available Functions
+-------------------
+
The possible window functions are:
1. Rank Functions
- - rank
- - dense_rank
- - row_number
- - ntile
+ - :py:func:`datafusion.functions.rank`
+ - :py:func:`datafusion.functions.dense_rank`
+ - :py:func:`datafusion.functions.ntile`
+ - :py:func:`datafusion.functions.row_number`
2. Analytical Functions
- - cume_dist
- - percent_rank
- - lag
- - lead
- - first_value
- - last_value
- - nth_value
+ - :py:func:`datafusion.functions.cume_dist`
+ - :py:func:`datafusion.functions.percent_rank`
+ - :py:func:`datafusion.functions.lag`
+ - :py:func:`datafusion.functions.lead`
3. Aggregate Functions
- - All aggregate functions can be used as window functions.
+ - All :ref:`Aggregation Functions<aggregation>` can be used as window
functions.
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 4f17601..0e7d82e 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -123,11 +123,10 @@ class DataFrame:
df = df.select("a", col("b"), col("a").alias("alternate_a"))
"""
- exprs = [
- arg.expr if isinstance(arg, Expr) else Expr.column(arg).expr
- for arg in exprs
+ exprs_internal = [
+ Expr.column(arg).expr if isinstance(arg, str) else arg.expr for
arg in exprs
]
- return DataFrame(self.df.select(*exprs))
+ return DataFrame(self.df.select(*exprs_internal))
def filter(self, *predicates: Expr) -> DataFrame:
"""Return a DataFrame for which ``predicate`` evaluates to ``True``.
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 71fcf39..c7272bb 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -23,8 +23,8 @@ See :ref:`Expressions` in the online documentation for more
details.
from __future__ import annotations
from ._internal import expr as expr_internal, LogicalPlan
-from datafusion.common import RexType, DataTypeMap
-from typing import Any
+from datafusion.common import NullTreatment, RexType, DataTypeMap
+from typing import Any, Optional
import pyarrow as pa
# The following are imported from the internal representation. We may choose to
@@ -355,6 +355,10 @@ class Expr:
"""Returns ``True`` if this expression is null."""
return Expr(self.expr.is_null())
+ def is_not_null(self) -> Expr:
+ """Returns ``True`` if this expression is not null."""
+ return Expr(self.expr.is_not_null())
+
def cast(self, to: pa.DataType[Any]) -> Expr:
"""Cast to a new data type."""
return Expr(self.expr.cast(to))
@@ -405,12 +409,107 @@ class Expr:
"""Compute the output column name based on the provided logical
plan."""
return self.expr.column_name(plan)
+ def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
+ """Set the ordering for a window or aggregate function.
+
+ This function will create an :py:class:`ExprFuncBuilder` that can be
used to
+ set parameters for either window or aggregate functions. If used on
any other
+ type of expression, an error will be generated when ``build()`` is
called.
+ """
+ return ExprFuncBuilder(self.expr.order_by(list(e.expr for e in exprs)))
+
+ def filter(self, filter: Expr) -> ExprFuncBuilder:
+ """Filter an aggregate function.
+
+ This function will create an :py:class:`ExprFuncBuilder` that can be
used to
+ set parameters for either window or aggregate functions. If used on
any other
+ type of expression, an error will be generated when ``build()`` is
called.
+ """
+ return ExprFuncBuilder(self.expr.filter(filter.expr))
+
+ def distinct(self) -> ExprFuncBuilder:
+ """Only evaluate distinct values for an aggregate function.
+
+ This function will create an :py:class:`ExprFuncBuilder` that can be
used to
+ set parameters for either window or aggregate functions. If used on
any other
+ type of expression, an error will be generated when ``build()`` is
called.
+ """
+ return ExprFuncBuilder(self.expr.distinct())
+
+ def null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder:
+ """Set the treatment for ``null`` values for a window or aggregate
function.
+
+ This function will create an :py:class:`ExprFuncBuilder` that can be
used to
+ set parameters for either window or aggregate functions. If used on
any other
+ type of expression, an error will be generated when ``build()`` is
called.
+ """
+ return ExprFuncBuilder(self.expr.null_treatment(null_treatment))
+
+ def partition_by(self, *partition_by: Expr) -> ExprFuncBuilder:
+ """Set the partitioning for a window function.
+
+ This function will create an :py:class:`ExprFuncBuilder` that can be
used to
+ set parameters for either window or aggregate functions. If used on
any other
+ type of expression, an error will be generated when ``build()`` is
called.
+ """
+ return ExprFuncBuilder(
+ self.expr.partition_by(list(e.expr for e in partition_by))
+ )
+
+ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
+ """Set the frame fora window function.
+
+ This function will create an :py:class:`ExprFuncBuilder` that can be
used to
+ set parameters for either window or aggregate functions. If used on
any other
+ type of expression, an error will be generated when ``build()`` is
called.
+ """
+ return
ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))
+
+
+class ExprFuncBuilder:
+ def __init__(self, builder: expr_internal.ExprFuncBuilder):
+ self.builder = builder
+
+ def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
+ """Set the ordering for a window or aggregate function.
+
+ Values given in ``exprs`` must be sort expressions. You can convert
any other
+ expression to a sort expression using `.sort()`.
+ """
+ return ExprFuncBuilder(self.builder.order_by(list(e.expr for e in
exprs)))
+
+ def filter(self, filter: Expr) -> ExprFuncBuilder:
+ """Filter values during aggregation."""
+ return ExprFuncBuilder(self.builder.filter(filter.expr))
+
+ def distinct(self) -> ExprFuncBuilder:
+ """Only evaluate distinct values during aggregation."""
+ return ExprFuncBuilder(self.builder.distinct())
+
+ def null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder:
+ """Set how nulls are treated for either window or aggregate
functions."""
+ return ExprFuncBuilder(self.builder.null_treatment(null_treatment))
+
+ def partition_by(self, *partition_by: Expr) -> ExprFuncBuilder:
+ """Set partitioning for window functions."""
+ return ExprFuncBuilder(
+ self.builder.partition_by(list(e.expr for e in partition_by))
+ )
+
+ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
+ """Set window frame for window functions."""
+ return
ExprFuncBuilder(self.builder.window_frame(window_frame.window_frame))
+
+ def build(self) -> Expr:
+ """Create an expression from a Function Builder."""
+ return Expr(self.builder.build())
+
class WindowFrame:
"""Defines a window frame for performing window operations."""
def __init__(
- self, units: str, start_bound: int | None, end_bound: int | None
+ self, units: str, start_bound: Optional[Any], end_bound: Optional[Any]
) -> None:
"""Construct a window frame using the given parameters.
@@ -423,6 +522,14 @@ class WindowFrame:
will be set to unbounded. If unit type is ``groups``, this
parameter must be set.
"""
+ if not isinstance(start_bound, pa.Scalar) and start_bound is not None:
+ start_bound = pa.scalar(start_bound)
+ if units == "rows" or units == "groups":
+ start_bound = start_bound.cast(pa.uint64())
+ if not isinstance(end_bound, pa.Scalar) and end_bound is not None:
+ end_bound = pa.scalar(end_bound)
+ if units == "rows" or units == "groups":
+ end_bound = end_bound.cast(pa.uint64())
self.window_frame = expr_internal.WindowFrame(units, start_bound,
end_bound)
def get_frame_units(self) -> str:
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index ec0c110..28201c1 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -27,6 +27,10 @@ from datafusion._internal import functions as f, common
from datafusion.expr import CaseBuilder, Expr, WindowFrame
from datafusion.context import SessionContext
+from typing import Any, Optional
+
+import pyarrow as pa
+
__all__ = [
"abs",
"acos",
@@ -246,7 +250,16 @@ __all__ = [
"var_pop",
"var_samp",
"when",
+ # Window Functions
"window",
+ "lead",
+ "lag",
+ "row_number",
+ "rank",
+ "dense_rank",
+ "percent_rank",
+ "cume_dist",
+ "ntile",
]
@@ -383,7 +396,14 @@ def window(
window_frame: WindowFrame | None = None,
ctx: SessionContext | None = None,
) -> Expr:
- """Creates a new Window function expression."""
+ """Creates a new Window function expression.
+
+ This interface will soon be deprecated. Instead of using this interface,
+ users should call the window functions directly. For example, to perform a
+ lag use::
+
+ df.select(functions.lag(col("a")).partition_by(col("b")).build())
+ """
args = [a.expr for a in args]
partition_by = [e.expr for e in partition_by] if partition_by is not None
else None
order_by = [o.expr for o in order_by] if order_by is not None else None
@@ -1022,12 +1042,12 @@ def struct(*args: Expr) -> Expr:
return Expr(f.struct(*args))
-def named_struct(name_pairs: list[(str, Expr)]) -> Expr:
+def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
"""Returns a struct with the given names and arguments pairs."""
- name_pairs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
+ name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
# flatten
- name_pairs = [x.expr for xs in name_pairs for x in xs]
+ name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
return Expr(f.named_struct(*name_pairs))
@@ -1690,17 +1710,19 @@ def regr_syy(y: Expr, x: Expr, distinct: bool = False)
-> Expr:
def first_value(
arg: Expr,
distinct: bool = False,
- filter: bool = None,
- order_by: Expr | None = None,
- null_treatment: common.NullTreatment | None = None,
+ filter: Optional[bool] = None,
+ order_by: Optional[list[Expr]] = None,
+ null_treatment: Optional[common.NullTreatment] = None,
) -> Expr:
"""Returns the first value in a group of values."""
+ order_by_cols = [e.expr for e in order_by] if order_by is not None else
None
+
return Expr(
f.first_value(
arg.expr,
distinct=distinct,
filter=filter,
- order_by=order_by,
+ order_by=order_by_cols,
null_treatment=null_treatment,
)
)
@@ -1709,17 +1731,23 @@ def first_value(
def last_value(
arg: Expr,
distinct: bool = False,
- filter: bool = None,
- order_by: Expr | None = None,
- null_treatment: common.NullTreatment | None = None,
+ filter: Optional[bool] = None,
+ order_by: Optional[list[Expr]] = None,
+ null_treatment: Optional[common.NullTreatment] = None,
) -> Expr:
- """Returns the last value in a group of values."""
+ """Returns the last value in a group of values.
+
+ To set parameters on this expression, use ``.order_by()``, ``.distinct()``,
+ ``.filter()``, or ``.null_treatment()``.
+ """
+ order_by_cols = [e.expr for e in order_by] if order_by is not None else
None
+
return Expr(
f.last_value(
arg.expr,
distinct=distinct,
filter=filter,
- order_by=order_by,
+ order_by=order_by_cols,
null_treatment=null_treatment,
)
)
@@ -1748,3 +1776,339 @@ def bool_and(arg: Expr, distinct: bool = False) -> Expr:
def bool_or(arg: Expr, distinct: bool = False) -> Expr:
"""Computes the boolean OR of the arguement."""
return Expr(f.bool_or(arg.expr, distinct=distinct))
+
+
+def lead(
+ arg: Expr,
+ shift_offset: int = 1,
+ default_value: Optional[Any] = None,
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create a lead window function.
+
+ Lead operation will return the argument that is in the next
shift_offset-th row in
+ the partition. For example ``lead(col("b"), shift_offset=3,
default_value=5)`` will
+ return the 3rd following value in column ``b``. At the end of the
partition, where
no further values can be returned, it will return the default value of 5.
+
+ Here is an example of both the ``lead`` and
:py:func:`datafusion.functions.lag`
+ functions on a simple DataFrame::
+
+ +--------+------+-----+
+ | points | lead | lag |
+ +--------+------+-----+
+ | 100 | 100 | |
+ | 100 | 50 | 100 |
+ | 50 | 25 | 100 |
+ | 25 | | 50 |
+ +--------+------+-----+
+
+ To set window function parameters use the window builder approach
described in the
:ref:`window_functions` online documentation.
+
+ Args:
+ arg: Value to return
+ shift_offset: Number of rows following the current row.
+ default_value: Value to return if shift_offset row does not exist.
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ if not isinstance(default_value, pa.Scalar) and default_value is not None:
+ default_value = pa.scalar(default_value)
+
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.lead(
+ arg.expr,
+ shift_offset,
+ default_value,
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
+
+
+def lag(
+ arg: Expr,
+ shift_offset: int = 1,
+ default_value: Optional[Any] = None,
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create a lag window function.
+
+ Lag operation will return the argument that is in the previous
shift_offset-th row
+ in the partition. For example ``lag(col("b"), shift_offset=3,
default_value=5)``
will return the 3rd previous value in column ``b``. At the beginning of the
+ partition, where no values can be returned it will return the default
value of 5.
+
+ Here is an example of both the ``lag`` and
:py:func:`datafusion.functions.lead`
+ functions on a simple DataFrame::
+
+ +--------+------+-----+
+ | points | lead | lag |
+ +--------+------+-----+
+ | 100 | 100 | |
+ | 100 | 50 | 100 |
+ | 50 | 25 | 100 |
+ | 25 | | 50 |
+ +--------+------+-----+
+
+ Args:
+ arg: Value to return
+ shift_offset: Number of rows before the current row.
+ default_value: Value to return if shift_offset row does not exist.
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ if not isinstance(default_value, pa.Scalar):
+ default_value = pa.scalar(default_value)
+
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.lag(
+ arg.expr,
+ shift_offset,
+ default_value,
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
+
+
+def row_number(
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create a row number window function.
+
+ Returns the row number of the window function.
+
+ Here is an example of the ``row_number`` on a simple DataFrame::
+
+ +--------+------------+
+ | points | row number |
+ +--------+------------+
+ | 100 | 1 |
+ | 100 | 2 |
+ | 50 | 3 |
+ | 25 | 4 |
+ +--------+------------+
+
+ Args:
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.row_number(
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
+
+
+def rank(
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create a rank window function.
+
+ Returns the rank based upon the window order. Consecutive equal values
will receive
+ the same rank, but the next different value will not be consecutive but
rather the
+ number of rows that precede it plus one. This is similar to Olympic
medals. If two
+ people tie for gold, the next place is bronze. There would be no silver
medal. Here
+ is an example of a dataframe with a window ordered by descending
``points`` and the
+ associated rank.
+
+ You should set ``order_by`` to produce meaningful results::
+
+ +--------+------+
+ | points | rank |
+ +--------+------+
+ | 100 | 1 |
+ | 100 | 1 |
+ | 50 | 3 |
+ | 25 | 4 |
+ +--------+------+
+
+ Args:
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.rank(
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
+
+
+def dense_rank(
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create a dense_rank window function.
+
+ This window function is similar to :py:func:`rank` except that the
returned values
+ will be consecutive. Here is an example of a dataframe with a window
ordered by
+ descending ``points`` and the associated dense rank::
+
+ +--------+------------+
+ | points | dense_rank |
+ +--------+------------+
+ | 100 | 1 |
+ | 100 | 1 |
+ | 50 | 2 |
+ | 25 | 3 |
+ +--------+------------+
+
+ Args:
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.dense_rank(
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
+
+
+def percent_rank(
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create a percent_rank window function.
+
+ This window function is similar to :py:func:`rank` except that the
returned values
+ are the percentage from 0.0 to 1.0 from first to last. Here is an example
of a
+ dataframe with a window ordered by descending ``points`` and the
associated percent
+ rank::
+
+ +--------+--------------+
+ | points | percent_rank |
+ +--------+--------------+
+ | 100 | 0.0 |
+ | 100 | 0.0 |
+ | 50 | 0.666667 |
+ | 25 | 1.0 |
+ +--------+--------------+
+
+ Args:
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.percent_rank(
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
+
+
+def cume_dist(
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create a cumulative distribution window function.
+
+ This window function is similar to :py:func:`rank` except that the
returned values
+ are the ratio of the row number to the total number of rows. Here is an
example of a
+ dataframe with a window ordered by descending ``points`` and the associated
+ cumulative distribution::
+
+ +--------+-----------+
+ | points | cume_dist |
+ +--------+-----------+
+ | 100 | 0.5 |
+ | 100 | 0.5 |
+ | 50 | 0.75 |
+ | 25 | 1.0 |
+ +--------+-----------+
+
+ Args:
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.cume_dist(
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
+
+
+def ntile(
+ groups: int,
+ partition_by: Optional[list[Expr]] = None,
+ order_by: Optional[list[Expr]] = None,
+) -> Expr:
+ """Create an n-tile window function.
+
+ This window function orders the window frame into a given number of groups
based on
+ the ordering criteria. It then returns which group the current row is
assigned to.
+ Here is an example of a dataframe with a window ordered by descending
``points``
+ and the associated n-tile function::
+
+ +--------+-------+
+ | points | ntile |
+ +--------+-------+
+ | 120 | 1 |
+ | 100 | 1 |
+ | 80 | 2 |
+ | 60 | 2 |
+ | 40 | 3 |
+ | 20 | 3 |
+ +--------+-------+
+
+ Args:
+ groups: Number of groups for the n-tile to be divided into.
+ partition_by: Expressions to partition the window frame on.
+ order_by: Set ordering within the window frame.
+ """
+ partition_cols = (
+ [col.expr for col in partition_by] if partition_by is not None else
None
+ )
+ order_cols = [col.expr for col in order_by] if order_by is not None else
None
+
+ return Expr(
+ f.ntile(
+ Expr.literal(groups).expr,
+ partition_by=partition_cols,
+ order_by=order_cols,
+ )
+ )
diff --git a/python/datafusion/tests/test_dataframe.py
b/python/datafusion/tests/test_dataframe.py
index 477bc0f..c2a5f22 100644
--- a/python/datafusion/tests/test_dataframe.py
+++ b/python/datafusion/tests/test_dataframe.py
@@ -84,6 +84,23 @@ def aggregate_df():
return ctx.sql("select c1, sum(c2) from test group by c1")
[email protected]
+def partitioned_df():
+ ctx = SessionContext()
+
+ # create a RecordBatch and a new DataFrame from it
+ batch = pa.RecordBatch.from_arrays(
+ [
+ pa.array([0, 1, 2, 3, 4, 5, 6]),
+ pa.array([7, None, 7, 8, 9, None, 9]),
+ pa.array(["A", "A", "A", "A", "B", "B", "B"]),
+ ],
+ names=["a", "b", "c"],
+ )
+
+ return ctx.create_dataframe([[batch]])
+
+
def test_select(df):
df = df.select(
column("a") + column("b"),
@@ -249,7 +266,7 @@ def test_join():
df = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
df.show()
- df = df.sort(column("l.a").sort(ascending=True))
+ df = df.sort(column("l.a"))
table = pa.Table.from_batches(df.collect())
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
@@ -263,83 +280,162 @@ def test_distinct():
[pa.array([1, 2, 3, 1, 2, 3]), pa.array([4, 5, 6, 4, 5, 6])],
names=["a", "b"],
)
- df_a = (
- ctx.create_dataframe([[batch]])
- .distinct()
- .sort(column("a").sort(ascending=True))
- )
+ df_a = ctx.create_dataframe([[batch]]).distinct().sort(column("a"))
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
names=["a", "b"],
)
- df_b =
ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
+ df_b = ctx.create_dataframe([[batch]]).sort(column("a"))
assert df_a.collect() == df_b.collect()
data_test_window_functions = [
- ("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]),
[2, 1, 3]),
- ("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1,
2]),
+ (
+ "row",
+ f.row_number(order_by=[column("b"),
column("a").sort(ascending=False)]),
+ [4, 2, 3, 5, 7, 1, 6],
+ ),
+ (
+ "row_w_params",
+ f.row_number(
+ order_by=[column("b"), column("a")],
+ partition_by=[column("c")],
+ ),
+ [2, 1, 3, 4, 2, 1, 3],
+ ),
+ ("rank", f.rank(order_by=[column("b")]), [3, 1, 3, 5, 6, 1, 6]),
+ (
+ "rank_w_params",
+ f.rank(order_by=[column("b"), column("a")],
partition_by=[column("c")]),
+ [2, 1, 3, 4, 2, 1, 3],
+ ),
(
"dense_rank",
- f.window("dense_rank", [], order_by=[f.order_by(column("c"))]),
- [2, 1, 2],
+ f.dense_rank(order_by=[column("b")]),
+ [2, 1, 2, 3, 4, 1, 4],
+ ),
+ (
+ "dense_rank_w_params",
+ f.dense_rank(order_by=[column("b"), column("a")],
partition_by=[column("c")]),
+ [2, 1, 3, 4, 2, 1, 3],
),
(
"percent_rank",
- f.window("percent_rank", [], order_by=[f.order_by(column("c"))]),
- [0.5, 0, 0.5],
+ f.round(f.percent_rank(order_by=[column("b")]), literal(3)),
+ [0.333, 0.0, 0.333, 0.667, 0.833, 0.0, 0.833],
+ ),
+ (
+ "percent_rank_w_params",
+ f.round(
+ f.percent_rank(
+ order_by=[column("b"), column("a")], partition_by=[column("c")]
+ ),
+ literal(3),
+ ),
+ [0.333, 0.0, 0.667, 1.0, 0.5, 0.0, 1.0],
),
(
"cume_dist",
- f.window("cume_dist", [], order_by=[f.order_by(column("b"))]),
- [0.3333333333333333, 0.6666666666666666, 1.0],
+ f.round(f.cume_dist(order_by=[column("b")]), literal(3)),
+ [0.571, 0.286, 0.571, 0.714, 1.0, 0.286, 1.0],
+ ),
+ (
+ "cume_dist_w_params",
+ f.round(
+ f.cume_dist(
+ order_by=[column("b"), column("a")], partition_by=[column("c")]
+ ),
+ literal(3),
+ ),
+ [0.5, 0.25, 0.75, 1.0, 0.667, 0.333, 1.0],
),
(
"ntile",
- f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]),
- [1, 1, 2],
+ f.ntile(2, order_by=[column("b")]),
+ [1, 1, 1, 2, 2, 1, 2],
),
(
- "next",
- f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]),
- [5, 6, None],
+ "ntile_w_params",
+ f.ntile(2, order_by=[column("b"), column("a")],
partition_by=[column("c")]),
+ [1, 1, 2, 2, 1, 1, 2],
),
+ ("lead", f.lead(column("b"), order_by=[column("b")]), [7, None, 8, 9, 9,
7, None]),
(
- "previous",
- f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]),
- [None, 4, 5],
+ "lead_w_params",
+ f.lead(
+ column("b"),
+ shift_offset=2,
+ default_value=-1,
+ order_by=[column("b"), column("a")],
+ partition_by=[column("c")],
+ ),
+ [8, 7, -1, -1, -1, 9, -1],
),
+ ("lag", f.lag(column("b"), order_by=[column("b")]), [None, None, 7, 7, 8,
None, 9]),
+ (
+ "lag_w_params",
+ f.lag(
+ column("b"),
+ shift_offset=2,
+ default_value=-1,
+ order_by=[column("b"), column("a")],
+ partition_by=[column("c")],
+ ),
+ [-1, -1, None, 7, -1, -1, None],
+ ),
+ # TODO update all aggregate functions as windows once upstream merges
https://github.com/apache/datafusion-python/issues/833
pytest.param(
"first_value",
- f.window("first_value", [column("a")],
order_by=[f.order_by(column("b"))]),
- [1, 1, 1],
+ f.window(
+ "first_value",
+ [column("a")],
+ order_by=[f.order_by(column("b"))],
+ partition_by=[column("c")],
+ ),
+ [1, 1, 1, 1, 5, 5, 5],
),
pytest.param(
"last_value",
- f.window("last_value", [column("b")],
order_by=[f.order_by(column("b"))]),
- [4, 5, 6],
+ f.window("last_value", [column("a")])
+ .window_frame(WindowFrame("rows", 0, None))
+ .order_by(column("b"))
+ .partition_by(column("c"))
+ .build(),
+ [3, 3, 3, 3, 6, 6, 6],
),
pytest.param(
- "2nd_value",
+ "3rd_value",
f.window(
"nth_value",
- [column("b"), literal(2)],
- order_by=[f.order_by(column("b"))],
+ [column("b"), literal(3)],
+ order_by=[f.order_by(column("a"))],
),
- [None, 5, 5],
+ [None, None, 7, 7, 7, 7, 7],
+ ),
+ pytest.param(
+ "avg",
+ f.round(f.window("avg", [column("b")], order_by=[column("a")]),
literal(3)),
+ [7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
),
]
@pytest.mark.parametrize("name,expr,result", data_test_window_functions)
-def test_window_functions(df, name, expr, result):
- df = df.select(column("a"), column("b"), column("c"), f.alias(expr, name))
-
+def test_window_functions(partitioned_df, name, expr, result):
+ df = partitioned_df.select(
+ column("a"), column("b"), column("c"), f.alias(expr, name)
+ )
+ df.sort(column("a")).show()
table = pa.Table.from_batches(df.collect())
- expected = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8], name: result}
+ expected = {
+ "a": [0, 1, 2, 3, 4, 5, 6],
+ "b": [7, None, 7, 8, 9, None, 9],
+ "c": ["A", "A", "A", "A", "B", "B", "B"],
+ name: result,
+ }
assert table.sort_by("a").to_pydict() == expected
@@ -512,9 +608,9 @@ def test_intersect():
[pa.array([3]), pa.array([6])],
names=["a", "b"],
)
- df_c =
ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
+ df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
- df_a_i_b = df_a.intersect(df_b).sort(column("a").sort(ascending=True))
+ df_a_i_b = df_a.intersect(df_b).sort(column("a"))
assert df_c.collect() == df_a_i_b.collect()
@@ -538,9 +634,9 @@ def test_except_all():
[pa.array([1, 2]), pa.array([4, 5])],
names=["a", "b"],
)
- df_c =
ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
+ df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
- df_a_e_b = df_a.except_all(df_b).sort(column("a").sort(ascending=True))
+ df_a_e_b = df_a.except_all(df_b).sort(column("a"))
assert df_c.collect() == df_a_e_b.collect()
@@ -573,9 +669,9 @@ def test_union(ctx):
[pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])],
names=["a", "b"],
)
- df_c =
ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
+ df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
- df_a_u_b = df_a.union(df_b).sort(column("a").sort(ascending=True))
+ df_a_u_b = df_a.union(df_b).sort(column("a"))
assert df_c.collect() == df_a_u_b.collect()
@@ -597,9 +693,9 @@ def test_union_distinct(ctx):
[pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])],
names=["a", "b"],
)
- df_c =
ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
+ df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
- df_a_u_b = df_a.union(df_b, True).sort(column("a").sort(ascending=True))
+ df_a_u_b = df_a.union(df_b, True).sort(column("a"))
assert df_c.collect() == df_a_u_b.collect()
assert df_c.collect() == df_a_u_b.collect()
diff --git a/python/datafusion/tests/test_functions.py
b/python/datafusion/tests/test_functions.py
index e5429bd..fe092c4 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -963,6 +963,7 @@ def test_first_last_value(df):
assert result.column(3) == pa.array(["!"])
assert result.column(4) == pa.array([6])
assert result.column(5) == pa.array([datetime(2020, 7, 2)])
+ df.show()
def test_binary_string_functions(df):
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 22b0522..d7abab4 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -39,6 +39,7 @@ use pyo3::types::{PyCapsule, PyTuple};
use tokio::task::JoinHandle;
use crate::errors::py_datafusion_err;
+use crate::expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
@@ -150,7 +151,7 @@ impl PyDataFrame {
#[pyo3(signature = (*exprs))]
fn sort(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
- let exprs = exprs.into_iter().map(|e| e.into()).collect();
+ let exprs = to_sort_expressions(exprs);
let df = self.df.as_ref().clone().sort(exprs)?;
Ok(Self::new(df))
}
diff --git a/src/expr.rs b/src/expr.rs
index 04bfc85..697682d 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -16,10 +16,11 @@
// under the License.
use datafusion_expr::utils::exprlist_to_fields;
-use datafusion_expr::LogicalPlan;
+use datafusion_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
use pyo3::{basic::CompareOp, prelude::*};
use std::convert::{From, Into};
use std::sync::Arc;
+use window::PyWindowFrame;
use arrow::pyarrow::ToPyArrow;
use datafusion::arrow::datatypes::{DataType, Field};
@@ -32,7 +33,7 @@ use datafusion_expr::{
lit, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast,
};
-use crate::common::data_type::{DataTypeMap, RexType};
+use crate::common::data_type::{DataTypeMap, NullTreatment, RexType};
use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err,
DataFusionError};
use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
@@ -281,6 +282,10 @@ impl PyExpr {
self.expr.clone().is_null().into()
}
+ pub fn is_not_null(&self) -> PyExpr {
+ self.expr.clone().is_not_null().into()
+ }
+
pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
// self.expr.cast_to() requires DFSchema to validate that the cast
// is supported, omit that for now
@@ -510,6 +515,107 @@ impl PyExpr {
pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
self._column_name(&plan.plan()).map_err(py_runtime_err)
}
+
+ // Expression Function Builder functions
+
+ pub fn order_by(&self, order_by: Vec<PyExpr>) -> PyExprFuncBuilder {
+ self.expr
+ .clone()
+ .order_by(to_sort_expressions(order_by))
+ .into()
+ }
+
+ pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
+ self.expr.clone().filter(filter.expr.clone()).into()
+ }
+
+ pub fn distinct(&self) -> PyExprFuncBuilder {
+ self.expr.clone().distinct().into()
+ }
+
+ pub fn null_treatment(&self, null_treatment: NullTreatment) ->
PyExprFuncBuilder {
+ self.expr
+ .clone()
+ .null_treatment(Some(null_treatment.into()))
+ .into()
+ }
+
+ pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder
{
+ let partition_by = partition_by.iter().map(|e|
e.expr.clone()).collect();
+ self.expr.clone().partition_by(partition_by).into()
+ }
+
+ pub fn window_frame(&self, window_frame: PyWindowFrame) ->
PyExprFuncBuilder {
+ self.expr.clone().window_frame(window_frame.into()).into()
+ }
+}
+
+#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
+#[derive(Debug, Clone)]
+pub struct PyExprFuncBuilder {
+ pub builder: ExprFuncBuilder,
+}
+
+impl From<ExprFuncBuilder> for PyExprFuncBuilder {
+ fn from(builder: ExprFuncBuilder) -> Self {
+ Self { builder }
+ }
+}
+
+pub fn to_sort_expressions(order_by: Vec<PyExpr>) -> Vec<Expr> {
+ order_by
+ .iter()
+ .map(|e| e.expr.clone())
+ .map(|e| match e {
+ Expr::Sort(_) => e,
+ _ => e.sort(true, true),
+ })
+ .collect()
+}
+
+#[pymethods]
+impl PyExprFuncBuilder {
+ pub fn order_by(&self, order_by: Vec<PyExpr>) -> PyExprFuncBuilder {
+ self.builder
+ .clone()
+ .order_by(to_sort_expressions(order_by))
+ .into()
+ }
+
+ pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
+ self.builder.clone().filter(filter.expr.clone()).into()
+ }
+
+ pub fn distinct(&self) -> PyExprFuncBuilder {
+ self.builder.clone().distinct().into()
+ }
+
+ pub fn null_treatment(&self, null_treatment: NullTreatment) ->
PyExprFuncBuilder {
+ self.builder
+ .clone()
+ .null_treatment(Some(null_treatment.into()))
+ .into()
+ }
+
+ pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder
{
+ let partition_by = partition_by.iter().map(|e|
e.expr.clone()).collect();
+ self.builder.clone().partition_by(partition_by).into()
+ }
+
+ pub fn window_frame(&self, window_frame: PyWindowFrame) ->
PyExprFuncBuilder {
+ self.builder
+ .clone()
+ .window_frame(window_frame.into())
+ .into()
+ }
+
+ pub fn build(&self) -> PyResult<PyExpr> {
+ self.builder
+ .clone()
+ .build()
+ .map(|expr| expr.into())
+ .map_err(|err| err.into())
+ }
}
impl PyExpr {
diff --git a/src/expr/window.rs b/src/expr/window.rs
index 7866511..7eb5860 100644
--- a/src/expr/window.rs
+++ b/src/expr/window.rs
@@ -168,7 +168,11 @@ fn not_window_function_err(expr: Expr) -> PyErr {
impl PyWindowFrame {
#[new]
#[pyo3(signature=(unit, start_bound, end_bound))]
- pub fn new(unit: &str, start_bound: Option<u64>, end_bound: Option<u64>)
-> PyResult<Self> {
+ pub fn new(
+ unit: &str,
+ start_bound: Option<ScalarValue>,
+ end_bound: Option<ScalarValue>,
+ ) -> PyResult<Self> {
let units = unit.to_ascii_lowercase();
let units = match units.as_str() {
"rows" => WindowFrameUnits::Rows,
@@ -182,9 +186,7 @@ impl PyWindowFrame {
}
};
let start_bound = match start_bound {
- Some(start_bound) => {
-
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound)))
- }
+ Some(start_bound) => WindowFrameBound::Preceding(start_bound),
None => match units {
WindowFrameUnits::Range =>
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Rows =>
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
@@ -197,7 +199,7 @@ impl PyWindowFrame {
},
};
let end_bound = match end_bound {
- Some(end_bound) =>
WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))),
+ Some(end_bound) => WindowFrameBound::Following(end_bound),
None => match units {
WindowFrameUnits::Rows =>
WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range =>
WindowFrameBound::Following(ScalarValue::UInt64(None)),
diff --git a/src/functions.rs b/src/functions.rs
index 2525636..aed4de4 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -16,13 +16,16 @@
// under the License.
use datafusion::functions_aggregate::all_default_aggregate_functions;
+use datafusion_expr::window_function;
use datafusion_expr::ExprFunctionExt;
+use datafusion_expr::WindowFrame;
use pyo3::{prelude::*, wrap_pyfunction};
use crate::common::data_type::NullTreatment;
use crate::context::PySessionContext;
use crate::errors::DataFusionError;
use crate::expr::conditional_expr::PyCaseBuilder;
+use crate::expr::to_sort_expressions;
use crate::expr::window::PyWindowFrame;
use crate::expr::PyExpr;
use datafusion::execution::FunctionRegistry;
@@ -316,18 +319,15 @@ pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct:
bool) -> PyResult<PyEx
}
}
-#[pyfunction]
-pub fn first_value(
- expr: PyExpr,
+fn add_builder_fns_to_aggregate(
+ agg_fn: Expr,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
- // If we initialize the UDAF with order_by directly, then it gets
over-written by the builder
- let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
-
- // luckily, I can guarantee initializing a builder with an `order_by`
default of empty vec
+ // Since ExprFuncBuilder::new() is private, we can guarantee initializing
+ // a builder with an `order_by` default of empty vec
let order_by = order_by
.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>())
.unwrap_or_default();
@@ -348,32 +348,30 @@ pub fn first_value(
}
#[pyfunction]
-pub fn last_value(
+pub fn first_value(
expr: PyExpr,
distinct: bool,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
- let agg_fn = functions_aggregate::expr_fn::last_value(vec![expr.expr]);
-
- // luckily, I can guarantee initializing a builder with an `order_by`
default of empty vec
- let order_by = order_by
- .map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>())
- .unwrap_or_default();
- let mut builder = agg_fn.order_by(order_by);
-
- if distinct {
- builder = builder.distinct();
- }
+ // If we initialize the UDAF with order_by directly, then it gets
over-written by the builder
+ let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
- if let Some(filter) = filter {
- builder = builder.filter(filter.expr);
- }
+ add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by,
null_treatment)
+}
- builder =
builder.null_treatment(null_treatment.map(DFNullTreatment::from));
+#[pyfunction]
+pub fn last_value(
+ expr: PyExpr,
+ distinct: bool,
+ filter: Option<PyExpr>,
+ order_by: Option<Vec<PyExpr>>,
+ null_treatment: Option<NullTreatment>,
+) -> PyResult<PyExpr> {
+ let agg_fn = functions_aggregate::expr_fn::last_value(vec![expr.expr]);
- Ok(builder.build()?.into())
+ add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by,
null_treatment)
}
#[pyfunction]
@@ -618,9 +616,11 @@ fn window(
ctx: Option<PySessionContext>,
) -> PyResult<PyExpr> {
let fun = find_window_fn(name, ctx)?;
+
let window_frame = window_frame
- .unwrap_or_else(|| PyWindowFrame::new("rows", None, Some(0)).unwrap())
- .into();
+ .map(|w| w.into())
+ .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty())));
+
Ok(PyExpr {
expr: datafusion_expr::Expr::WindowFunction(WindowFunction {
fun,
@@ -634,6 +634,10 @@ fn window(
.unwrap_or_default()
.into_iter()
.map(|x| x.expr)
+ .map(|e| match e {
+ Expr::Sort(_) => e,
+ _ => e.sort(true, true),
+ })
.collect::<Vec<_>>(),
window_frame,
null_treatment: None,
@@ -890,6 +894,116 @@ aggregate_function!(array_agg,
functions_aggregate::array_agg::array_agg_udaf);
aggregate_function!(max, functions_aggregate::min_max::max_udaf);
aggregate_function!(min, functions_aggregate::min_max::min_udaf);
+fn add_builder_fns_to_window(
+ window_fn: Expr,
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ // Since ExprFuncBuilder::new() is private, set an empty partition and then
+ // override later if appropriate.
+ let mut builder = window_fn.partition_by(vec![]);
+
+ if let Some(partition_cols) = partition_by {
+ builder = builder.partition_by(
+ partition_cols
+ .into_iter()
+ .map(|col| col.clone().into())
+ .collect(),
+ );
+ }
+
+ if let Some(order_by_cols) = order_by {
+ let order_by_cols = to_sort_expressions(order_by_cols);
+ builder = builder.order_by(order_by_cols);
+ }
+
+ builder.build().map(|e| e.into()).map_err(|err| err.into())
+}
+
+#[pyfunction]
+pub fn lead(
+ arg: PyExpr,
+ shift_offset: i64,
+ default_value: Option<ScalarValue>,
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ let window_fn = window_function::lead(arg.expr, Some(shift_offset),
default_value);
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
+#[pyfunction]
+pub fn lag(
+ arg: PyExpr,
+ shift_offset: i64,
+ default_value: Option<ScalarValue>,
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ let window_fn = window_function::lag(arg.expr, Some(shift_offset),
default_value);
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
+#[pyfunction]
+pub fn row_number(
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ let window_fn = window_function::row_number();
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
+#[pyfunction]
+pub fn rank(partition_by: Option<Vec<PyExpr>>, order_by: Option<Vec<PyExpr>>)
-> PyResult<PyExpr> {
+ let window_fn = window_function::rank();
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
+#[pyfunction]
+pub fn dense_rank(
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ let window_fn = window_function::dense_rank();
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
+#[pyfunction]
+pub fn percent_rank(
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ let window_fn = window_function::percent_rank();
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
+#[pyfunction]
+pub fn cume_dist(
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ let window_fn = window_function::cume_dist();
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
+#[pyfunction]
+pub fn ntile(
+ arg: PyExpr,
+ partition_by: Option<Vec<PyExpr>>,
+ order_by: Option<Vec<PyExpr>>,
+) -> PyResult<PyExpr> {
+ let window_fn = window_function::ntile(arg.into());
+
+ add_builder_fns_to_window(window_fn, partition_by, order_by)
+}
+
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(abs))?;
m.add_wrapped(wrap_pyfunction!(acos))?;
@@ -1075,5 +1189,15 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) ->
PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_slice))?;
m.add_wrapped(wrap_pyfunction!(flatten))?;
+ // Window Functions
+ m.add_wrapped(wrap_pyfunction!(lead))?;
+ m.add_wrapped(wrap_pyfunction!(lag))?;
+ m.add_wrapped(wrap_pyfunction!(row_number))?;
+ m.add_wrapped(wrap_pyfunction!(rank))?;
+ m.add_wrapped(wrap_pyfunction!(dense_rank))?;
+ m.add_wrapped(wrap_pyfunction!(percent_rank))?;
+ m.add_wrapped(wrap_pyfunction!(cume_dist))?;
+ m.add_wrapped(wrap_pyfunction!(ntile))?;
+
Ok(())
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]