This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new d54dc4ae Support string column identifiers for sort/aggregate/window and stricter Expr validation (#1221)
d54dc4ae is described below
commit d54dc4aeeeab41b5a845d383a672630ceb880253
Author: kosiew <[email protected]>
AuthorDate: Wed Sep 17 04:49:57 2025 +0800
Support string column identifiers for sort/aggregate/window and stricter Expr validation (#1221)
* refactor: improve DataFrame expression handling, type checking, and docs
- Refactor expression handling and `_simplify_expression` for stronger
type checking and clearer error handling
- Improve type annotations for `file_sort_order` and `order_by` to
support string inputs
- Refactor DataFrame `filter` method to better validate predicates
- Replace internal error message variable with public constant
- Clarify usage of `col()` and `column()` in DataFrame examples
* refactor: unify expression and sorting logic; improve docs and error
handling
- Update `order_by` handling in Window class for better type support
- Improve type checking in DataFrame expression handling
- Replace `Expr`/`SortExpr` with `SortKey` in file_sort_order and
related functions
- Simplify file_sort_order handling in SessionContext
- Rename `_EXPR_TYPE_ERROR` → `EXPR_TYPE_ERROR` for consistency
- Clarify usage of `col()` vs `column()` in DataFrame examples
- Enhance documentation for file_sort_order in SessionContext
* feat: add ensure_expr helper for validation; refine expression handling,
sorting, and docs
- Introduce `ensure_expr` helper and improve internal expression
validation
- Update error messages and tests to consistently use `EXPR_TYPE_ERROR`
- Refactor expression handling with `_to_raw_expr`, `_ensure_expr`, and
`SortKey`
- Improve type safety and consistency in sort key definitions and file
sort order
- Add parameterized parquet sorting tests
- Enhance DataFrame docstrings with clearer guidance and usage examples
- Fix minor typos and error message clarity
* Refactor and enhance expression handling, test coverage, and documentation
- Introduced `ensure_expr_list` to validate and flatten nested
expressions, treating strings as atomic
- Updated expression utilities to improve consistency across aggregation
and window functions
- Consolidated and expanded parameterized tests for string equivalence
in ranking and window functions
- Exposed `EXPR_TYPE_ERROR` for consistent error messaging across
modules and tests
- Improved internal sort logic using `expr_internal.SortExpr`
- Clarified expectations for `join_on` expressions in documentation
- Standardized imports and improved test clarity for maintainability
* refactor: update docstring for sort_or_default function to clarify its
purpose
* fix Ruff errors
* refactor: update type hints to use typing.Union for better clarity and
consistency
* fix Ruff errors
* refactor: simplify type hints by removing unnecessary imports for type
checking
* refactor: update type hints for rex_type and types methods to improve
clarity
* refactor: remove unnecessary type ignore comments from rex_type and types
methods
* docs: update section title for clarity on DataFrame method arguments
* docs: clarify description of DataFrame methods accepting column names
* docs: add note to clarify function documentation reference for DataFrame
methods
* docs: remove outdated information about predicate acceptance in DataFrame
class
* refactor: simplify type hint for expr_list parameter in
expr_list_to_raw_expr_list function
* docs: clarify usage of datafusion.col and datafusion.lit in DataFrame
method documentation
* docs: clarify usage of col() and lit() in DataFrame filter examples
* Fix ruff errors
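The stricter validation described in the bullets above can be sketched in isolation. Below is a minimal, self-contained sketch of the pattern this commit follows: `EXPR_TYPE_ERROR`, `ensure_expr`, and `ensure_expr_list` are the names introduced by the patch, while the tiny `Expr` class here is a hypothetical stand-in for datafusion's real wrapper:

```python
from collections.abc import Iterable
from typing import Any, Union

EXPR_TYPE_ERROR = "Use col()/column() or lit()/literal() to construct expressions"


class Expr:
    """Hypothetical stand-in for datafusion's Expr wrapper."""

    def __init__(self, name: str) -> None:
        self.expr = f"raw({name})"  # mimics the internal (Rust-side) expression


def ensure_expr(value: Union[Expr, Any]) -> str:
    # Reject plain strings and any other non-Expr value.
    if not isinstance(value, Expr):
        raise TypeError(EXPR_TYPE_ERROR)
    return value.expr


def ensure_expr_list(exprs: Iterable) -> list:
    # Flatten nested iterables of expressions. String-like values are treated
    # as atomic so they fall through to ensure_expr and raise the standard error.
    out = []
    for e in exprs:
        if isinstance(e, Iterable) and not isinstance(e, (Expr, str, bytes, bytearray)):
            out.extend(ensure_expr_list(e))
        else:
            out.append(ensure_expr(e))
    return out
```

In this sketch, `ensure_expr_list([Expr("a"), [Expr("b")]])` flattens the nested list, while a bare `"a"` raises `TypeError` with the `EXPR_TYPE_ERROR` message, matching the behavior the tests in the patch assert against.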
---
docs/source/user-guide/dataframe/index.rst | 50 +++++++
python/datafusion/context.py | 63 +++++----
python/datafusion/dataframe.py | 130 ++++++++++-------
python/datafusion/expr.py | 147 ++++++++++++++++---
python/datafusion/functions.py | 123 ++++++++++++----
python/tests/test_dataframe.py | 217 ++++++++++++++++++++++++++++-
python/tests/test_expr.py | 26 ++++
7 files changed, 633 insertions(+), 123 deletions(-)
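The other half of the change is accepting string column identifiers wherever a sort key is expected (`sort`, aggregate grouping, window `order_by`, `file_sort_order`). The normalization reduces to one step, sketched here with hypothetical dataclass stand-ins; the real conversion in the diff below produces Rust-backed `SortExpr` objects via `sort_or_default`:

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Expr:  # hypothetical stand-in for datafusion.Expr
    name: str


@dataclass
class SortExpr:  # hypothetical stand-in for datafusion.SortExpr
    expr: Expr
    ascending: bool = True


# The patch introduces this alias in datafusion.expr
SortKey = Union[Expr, SortExpr, str]


def sort_or_default(key: SortKey) -> SortExpr:
    """Normalize one sort key: a bare string becomes a column reference."""
    if isinstance(key, SortExpr):
        return key
    if isinstance(key, str):
        key = Expr(key)  # string column identifier -> column expression
    if isinstance(key, Expr):
        return SortExpr(key)  # default ordering (ascending, for this sketch)
    msg = f"Expected Expr or column name, found: {type(key).__name__}"
    raise TypeError(msg)


def sort_list_to_raw_sort_list(keys):
    # A single key is promoted to a one-element list, mirroring the patch.
    if isinstance(keys, (Expr, SortExpr, str)):
        keys = [keys]
    return [sort_or_default(k) for k in keys]
```

With this normalization in place, `df.sort("ts")` and `df.sort(col("ts").sort())` route through the same conversion path, which is what the parameterized equivalence tests added in `test_dataframe.py` exercise.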
diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst
index f69485af..1387db0b 100644
--- a/docs/source/user-guide/dataframe/index.rst
+++ b/docs/source/user-guide/dataframe/index.rst
@@ -126,6 +126,56 @@ DataFusion's DataFrame API offers a wide range of operations:
# Drop columns
df = df.drop("temporary_column")
+Column Names as Function Arguments
+----------------------------------
+
+Some ``DataFrame`` methods accept column names when an argument refers to an
+existing column. These include:
+
+* :py:meth:`~datafusion.DataFrame.select`
+* :py:meth:`~datafusion.DataFrame.sort`
+* :py:meth:`~datafusion.DataFrame.drop`
+* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
+* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
+
+See the full function documentation for details on any specific function.
+
+Note that :py:meth:`~datafusion.DataFrame.join_on` expects ``col()``/``column()`` expressions rather than plain strings.
+
+For such methods, you can pass column names directly:
+
+.. code-block:: python
+
+ from datafusion import col, functions as f
+
+ df.sort('id')
+ df.aggregate('id', [f.count(col('value'))])
+
+The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``:
+
+.. code-block:: python
+
+ from datafusion import col, column, functions as f
+
+ df.sort(col('id'))
+ df.aggregate(column('id'), [f.count(col('value'))])
+
+Note that ``column()`` is an alias of ``col()``, so you can use either name; the example above shows both in action.
+
+Whenever an argument represents an expression—such as in
+:py:meth:`~datafusion.DataFrame.filter` or
+:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference
+columns. The comparison and arithmetic operators on ``Expr`` will automatically
+convert any non-``Expr`` value into a literal expression, so writing
+
+.. code-block:: python
+
+ from datafusion import col
+ df.filter(col("age") > 21)
+
+is equivalent to using ``lit(21)`` explicitly. Use ``lit()`` (also available
+as ``literal()``) when you need to construct a literal expression directly.
+
Terminal Operations
-------------------
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index bce51d64..b6e728b5 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -22,16 +22,16 @@ from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Protocol
-import pyarrow as pa
-
try:
from warnings import deprecated # Python 3.13+
except ImportError:
from typing_extensions import deprecated # Python 3.12
+import pyarrow as pa
+
from datafusion.catalog import Catalog, CatalogProvider, Table
from datafusion.dataframe import DataFrame
-from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
+from datafusion.expr import SortKey, sort_list_to_raw_sort_list
from datafusion.record_batch import RecordBatchStream
from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
@@ -39,12 +39,14 @@ from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
from ._internal import SessionConfig as SessionConfigInternal
from ._internal import SessionContext as SessionContextInternal
from ._internal import SQLOptions as SQLOptionsInternal
+from ._internal import expr as expr_internal
if TYPE_CHECKING:
import pathlib
+ from collections.abc import Sequence
import pandas as pd
- import polars as pl
+ import polars as pl # type: ignore[import]
from datafusion.plan import ExecutionPlan, LogicalPlan
@@ -553,7 +555,7 @@ class SessionContext:
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_extension: str = ".parquet",
schema: pa.Schema | None = None,
- file_sort_order: list[list[Expr | SortExpr]] | None = None,
+ file_sort_order: Sequence[Sequence[SortKey]] | None = None,
) -> None:
"""Register multiple files as a single table.
@@ -567,23 +569,20 @@ class SessionContext:
table_partition_cols: Partition columns.
file_extension: File extension of the provided table.
schema: The data source schema.
- file_sort_order: Sort order for the file.
+ file_sort_order: Sort order for the file. Each sort key can be
+ specified as a column name (``str``), an expression
+ (``Expr``), or a ``SortExpr``.
"""
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
- file_sort_order_raw = (
- [sort_list_to_raw_sort_list(f) for f in file_sort_order]
- if file_sort_order is not None
- else None
- )
self.ctx.register_listing_table(
name,
str(path),
table_partition_cols,
file_extension,
schema,
- file_sort_order_raw,
+ self._convert_file_sort_order(file_sort_order),
)
def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
@@ -808,7 +807,7 @@ class SessionContext:
file_extension: str = ".parquet",
skip_metadata: bool = True,
schema: pa.Schema | None = None,
- file_sort_order: list[list[SortExpr]] | None = None,
+ file_sort_order: Sequence[Sequence[SortKey]] | None = None,
) -> None:
"""Register a Parquet file as a table.
@@ -827,7 +826,9 @@ class SessionContext:
that may be in the file schema. This can help avoid schema
conflicts due to metadata.
schema: The data source schema.
- file_sort_order: Sort order for the file.
+ file_sort_order: Sort order for the file. Each sort key can be
+ specified as a column name (``str``), an expression
+ (``Expr``), or a ``SortExpr``.
"""
if table_partition_cols is None:
table_partition_cols = []
@@ -840,9 +841,7 @@ class SessionContext:
file_extension,
skip_metadata,
schema,
- [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
- if file_sort_order is not None
- else None,
+ self._convert_file_sort_order(file_sort_order),
)
def register_csv(
@@ -1099,7 +1098,7 @@ class SessionContext:
file_extension: str = ".parquet",
skip_metadata: bool = True,
schema: pa.Schema | None = None,
- file_sort_order: list[list[Expr | SortExpr]] | None = None,
+ file_sort_order: Sequence[Sequence[SortKey]] | None = None,
) -> DataFrame:
"""Read a Parquet source into a
:py:class:`~datafusion.dataframe.Dataframe`.
@@ -1116,7 +1115,9 @@ class SessionContext:
schema: An optional schema representing the parquet files. If None,
the parquet reader will try to infer it based on data in the
file.
- file_sort_order: Sort order for the file.
+ file_sort_order: Sort order for the file. Each sort key can be
+ specified as a column name (``str``), an expression
+ (``Expr``), or a ``SortExpr``.
Returns:
DataFrame representation of the read Parquet files
@@ -1124,11 +1125,7 @@ class SessionContext:
if table_partition_cols is None:
table_partition_cols = []
table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
- file_sort_order = (
- [sort_list_to_raw_sort_list(f) for f in file_sort_order]
- if file_sort_order is not None
- else None
- )
+ file_sort_order = self._convert_file_sort_order(file_sort_order)
return DataFrame(
self.ctx.read_parquet(
str(path),
@@ -1179,6 +1176,24 @@ class SessionContext:
"""Execute the ``plan`` and return the results."""
return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+ @staticmethod
+ def _convert_file_sort_order(
+ file_sort_order: Sequence[Sequence[SortKey]] | None,
+ ) -> list[list[expr_internal.SortExpr]] | None:
+ """Convert nested ``SortKey`` sequences into raw sort expressions.
+
+ Each ``SortKey`` can be a column name string, an ``Expr``, or a
+ ``SortExpr`` and will be converted using
+ :func:`datafusion.expr.sort_list_to_raw_sort_list`.
+ """
+ # Convert each ``SortKey`` in the provided sort order to the low-level
+ # representation expected by the Rust bindings.
+ return (
+ [sort_list_to_raw_sort_list(f) for f in file_sort_order]
+ if file_sort_order is not None
+ else None
+ )
+
@staticmethod
def _convert_table_partition_cols(
table_partition_cols: list[tuple[str, str | pa.DataType]],
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 181c29db..68e6fe5a 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -22,6 +22,7 @@ See :ref:`user_guide_concepts` in the online documentation for more information.
from __future__ import annotations
import warnings
+from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
@@ -40,20 +41,25 @@ except ImportError:
from datafusion._internal import DataFrame as DataFrameInternal
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
-from datafusion.expr import Expr, SortExpr, sort_or_default
+from datafusion.expr import (
+ Expr,
+ SortKey,
+ ensure_expr,
+ ensure_expr_list,
+ expr_list_to_raw_expr_list,
+ sort_list_to_raw_sort_list,
+)
from datafusion.plan import ExecutionPlan, LogicalPlan
from datafusion.record_batch import RecordBatchStream
if TYPE_CHECKING:
import pathlib
- from typing import Callable, Sequence
+ from typing import Callable
import pandas as pd
import polars as pl
import pyarrow as pa
- from datafusion._internal import expr as expr_internal
-
from enum import Enum
@@ -401,9 +407,7 @@ class DataFrame:
df = df.select("a", col("b"), col("a").alias("alternate_a"))
"""
- exprs_internal = [
- Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
- ]
+ exprs_internal = expr_list_to_raw_expr_list(exprs)
return DataFrame(self.df.select(*exprs_internal))
def drop(self, *columns: str) -> DataFrame:
@@ -421,9 +425,17 @@ class DataFrame:
"""Return a DataFrame for which ``predicate`` evaluates to ``True``.
Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered
- out. If more than one predicate is provided, these predicates will be
- combined as a logical AND. If more complex logic is required, see the
- logical operations in :py:mod:`~datafusion.functions`.
+ out. If more than one predicate is provided, these predicates will be
+ combined as a logical AND. Each ``predicate`` must be an
+ :class:`~datafusion.expr.Expr` created using helper functions such as
+ :func:`datafusion.col` or :func:`datafusion.lit`.
+ If more complex logic is required, see the logical operations in
+ :py:mod:`~datafusion.functions`.
+
+ Example::
+
+ from datafusion import col, lit
+ df.filter(col("a") > lit(1))
Args:
predicates: Predicate expression(s) to filter the DataFrame.
@@ -433,12 +445,20 @@ class DataFrame:
"""
df = self.df
for p in predicates:
- df = df.filter(p.expr)
+ df = df.filter(ensure_expr(p))
return DataFrame(df)
def with_column(self, name: str, expr: Expr) -> DataFrame:
"""Add an additional column to the DataFrame.
+ The ``expr`` must be an :class:`~datafusion.expr.Expr` constructed with
+ :func:`datafusion.col` or :func:`datafusion.lit`.
+
+ Example::
+
+ from datafusion import col, lit
+ df.with_column("b", col("a") + lit(1))
+
Args:
name: Name of the column to add.
expr: Expression to compute the column.
@@ -446,23 +466,27 @@ class DataFrame:
Returns:
DataFrame with the new column.
"""
- return DataFrame(self.df.with_column(name, expr.expr))
+ return DataFrame(self.df.with_column(name, ensure_expr(expr)))
def with_columns(
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
) -> DataFrame:
"""Add columns to the DataFrame.
- By passing expressions, iteratables of expressions, or named expressions. To
- pass named expressions use the form name=Expr.
+ By passing expressions, iterables of expressions, or named expressions.
+ All expressions must be :class:`~datafusion.expr.Expr` objects created via
+ :func:`datafusion.col` or :func:`datafusion.lit`.
+ To pass named expressions use the form ``name=Expr``.
- Example usage: The following will add 4 columns labeled a, b, c, and d::
+ Example usage: The following will add 4 columns labeled ``a``, ``b``, ``c``,
+ and ``d``::
+ from datafusion import col, lit
df = df.with_columns(
- lit(0).alias('a'),
- [lit(1).alias('b'), lit(2).alias('c')],
+ col("x").alias("a"),
+ [lit(1).alias("b"), col("y").alias("c")],
d=lit(3)
- )
+ )
Args:
exprs: Either a single expression or an iterable of expressions to
add.
@@ -471,24 +495,10 @@ class DataFrame:
Returns:
DataFrame with the new columns added.
"""
-
- def _simplify_expression(
- *exprs: Expr | Iterable[Expr], **named_exprs: Expr
- ) -> list[expr_internal.Expr]:
- expr_list = []
- for expr in exprs:
- if isinstance(expr, Expr):
- expr_list.append(expr.expr)
- elif isinstance(expr, Iterable):
- expr_list.extend(inner_expr.expr for inner_expr in expr)
- else:
- raise NotImplementedError
- if named_exprs:
- for alias, expr in named_exprs.items():
- expr_list.append(expr.alias(alias).expr)
- return expr_list
-
- expressions = _simplify_expression(*exprs, **named_exprs)
+ expressions = ensure_expr_list(exprs)
+ for alias, expr in named_exprs.items():
+ ensure_expr(expr)
+ expressions.append(expr.alias(alias).expr)
return DataFrame(self.df.with_columns(expressions))
@@ -510,37 +520,47 @@ class DataFrame:
return DataFrame(self.df.with_column_renamed(old_name, new_name))
def aggregate(
- self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
+ self,
+ group_by: Sequence[Expr | str] | Expr | str,
+ aggs: Sequence[Expr] | Expr,
) -> DataFrame:
"""Aggregates the rows of the current DataFrame.
Args:
- group_by: List of expressions to group by.
- aggs: List of expressions to aggregate.
+ group_by: Sequence of expressions or column names to group by.
+ aggs: Sequence of expressions to aggregate.
Returns:
DataFrame after aggregation.
"""
- group_by = group_by if isinstance(group_by, list) else [group_by]
- aggs = aggs if isinstance(aggs, list) else [aggs]
+ group_by_list = (
+ list(group_by)
+ if isinstance(group_by, Sequence) and not isinstance(group_by, (Expr, str))
+ else [group_by]
+ )
+ aggs_list = (
+ list(aggs)
+ if isinstance(aggs, Sequence) and not isinstance(aggs, Expr)
+ else [aggs]
+ )
- group_by = [e.expr for e in group_by]
- aggs = [e.expr for e in aggs]
- return DataFrame(self.df.aggregate(group_by, aggs))
+ group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
+ aggs_exprs = ensure_expr_list(aggs_list)
+ return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
- def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
- """Sort the DataFrame by the specified sorting expressions.
+ def sort(self, *exprs: SortKey) -> DataFrame:
+ """Sort the DataFrame by the specified sorting expressions or column names.
Note that any expression can be turned into a sort expression by
- calling its` ``sort`` method.
+ calling its ``sort`` method.
Args:
- exprs: Sort expressions, applied in order.
+ exprs: Sort expressions or column names, applied in order.
Returns:
DataFrame after sorting.
"""
- exprs_raw = [sort_or_default(expr) for expr in exprs]
+ exprs_raw = sort_list_to_raw_sort_list(exprs)
return DataFrame(self.df.sort(*exprs_raw))
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
@@ -752,8 +772,14 @@ class DataFrame:
) -> DataFrame:
"""Join two :py:class:`DataFrame` using the specified expressions.
- On expressions are used to support in-equality predicates. Equality
- predicates are correctly optimized
+ Join predicates must be :class:`~datafusion.expr.Expr` objects, typically
+ built with :func:`datafusion.col`. On expressions are used to support
+ in-equality predicates. Equality predicates are correctly optimized.
+
+ Example::
+
+ from datafusion import col
+ df.join_on(other_df, col("id") == col("other_id"))
Args:
right: Other DataFrame to join with.
@@ -764,7 +790,7 @@ class DataFrame:
Returns:
DataFrame after join.
"""
- exprs = [expr.expr for expr in on_exprs]
+ exprs = [ensure_expr(expr) for expr in on_exprs]
return DataFrame(self.df.join_on(right.df, exprs, how))
def explain(self, verbose: bool = False, analyze: bool = False) -> None:
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index b5156040..5d1180bd 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -22,7 +22,8 @@ See :ref:`Expressions` in the online documentation for more
details.
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, ClassVar, Optional
+import typing as _typing
+from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Sequence
import pyarrow as pa
@@ -31,14 +32,23 @@ try:
except ImportError:
from typing_extensions import deprecated # Python 3.12
-from datafusion.common import DataTypeMap, NullTreatment, RexType
+from datafusion.common import NullTreatment
from ._internal import expr as expr_internal
from ._internal import functions as functions_internal
if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+ # Type-only imports
+ from datafusion.common import DataTypeMap, RexType
from datafusion.plan import LogicalPlan
+
+# Standard error message for invalid expression types
+# Mention both alias forms of column and literal helpers
+EXPR_TYPE_ERROR = "Use col()/column() or lit()/literal() to construct expressions"
+
# The following are imported from the internal representation. We may choose to
# give these all proper wrappers, or to simply leave as is. These were added
# in order to support passing the `test_imports` unit test.
@@ -126,6 +136,7 @@ Values = expr_internal.Values
WindowExpr = expr_internal.WindowExpr
__all__ = [
+ "EXPR_TYPE_ERROR",
"Aggregate",
"AggregateFunction",
"Alias",
@@ -195,6 +206,7 @@ __all__ = [
"SimilarTo",
"Sort",
"SortExpr",
+ "SortKey",
"Subquery",
"SubqueryAlias",
"TableScan",
@@ -212,19 +224,97 @@ __all__ = [
"WindowExpr",
"WindowFrame",
"WindowFrameBound",
+ "ensure_expr",
+ "ensure_expr_list",
]
+def ensure_expr(value: _typing.Union[Expr, Any]) -> expr_internal.Expr:
+ """Return the internal expression from ``Expr`` or raise ``TypeError``.
+
+ This helper rejects plain strings and other non-:class:`Expr` values so
+ higher level APIs consistently require explicit :func:`~datafusion.col` or
+ :func:`~datafusion.lit` expressions.
+
+ Args:
+ value: Candidate expression or other object.
+
+ Returns:
+ The internal expression representation.
+
+ Raises:
+ TypeError: If ``value`` is not an instance of :class:`Expr`.
+ """
+ if not isinstance(value, Expr):
+ raise TypeError(EXPR_TYPE_ERROR)
+ return value.expr
+
+
+def ensure_expr_list(
+ exprs: Iterable[_typing.Union[Expr, Iterable[Expr]]],
+) -> list[expr_internal.Expr]:
+ """Flatten an iterable of expressions, validating each via ``ensure_expr``.
+
+ Args:
+ exprs: Possibly nested iterable containing expressions.
+
+ Returns:
+ A flat list of raw expressions.
+
+ Raises:
+ TypeError: If any item is not an instance of :class:`Expr`.
+ """
+
+ def _iter(
+ items: Iterable[_typing.Union[Expr, Iterable[Expr]]],
+ ) -> Iterable[expr_internal.Expr]:
+ for expr in items:
+ if isinstance(expr, Iterable) and not isinstance(
+ expr, (Expr, str, bytes, bytearray)
+ ):
+ # Treat string-like objects as atomic to surface standard errors
+ yield from _iter(expr)
+ else:
+ yield ensure_expr(expr)
+
+ return list(_iter(exprs))
+
+
+def _to_raw_expr(value: _typing.Union[Expr, str]) -> expr_internal.Expr:
+ """Convert a Python expression or column name to its raw variant.
+
+ Args:
+ value: Candidate expression or column name.
+
+ Returns:
+ The internal :class:`~datafusion._internal.expr.Expr` representation.
+
+ Raises:
+ TypeError: If ``value`` is neither an :class:`Expr` nor ``str``.
+ """
+ if isinstance(value, str):
+ return Expr.column(value).expr
+ if isinstance(value, Expr):
+ return value.expr
+ error = (
+ "Expected Expr or column name, found:"
+ f" {type(value).__name__}. {EXPR_TYPE_ERROR}."
+ )
+ raise TypeError(error)
+
+
def expr_list_to_raw_expr_list(
expr_list: Optional[list[Expr] | Expr],
) -> Optional[list[expr_internal.Expr]]:
- """Helper function to convert an optional list to raw expressions."""
- if isinstance(expr_list, Expr):
+ """Convert a sequence of expressions or column names to raw expressions."""
+ if isinstance(expr_list, (Expr, str)):
expr_list = [expr_list]
- return [e.expr for e in expr_list] if expr_list is not None else None
+ if expr_list is None:
+ return None
+ return [_to_raw_expr(e) for e in expr_list]
-def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
+def sort_or_default(e: _typing.Union[Expr, SortExpr]) -> expr_internal.SortExpr:
"""Helper function to return a default Sort if an Expr is provided."""
if isinstance(e, SortExpr):
return e.raw_sort
@@ -232,12 +322,21 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
def sort_list_to_raw_sort_list(
- sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr],
+ sort_list: Optional[_typing.Union[Sequence[SortKey], SortKey]],
) -> Optional[list[expr_internal.SortExpr]]:
"""Helper function to return an optional sort list to raw variant."""
- if isinstance(sort_list, (Expr, SortExpr)):
+ if isinstance(sort_list, (Expr, SortExpr, str)):
sort_list = [sort_list]
- return [sort_or_default(e) for e in sort_list] if sort_list is not None else None
+ if sort_list is None:
+ return None
+ raw_sort_list = []
+ for item in sort_list:
+ if isinstance(item, SortExpr):
+ raw_sort_list.append(sort_or_default(item))
+ else:
+ raw_expr = _to_raw_expr(item) # may raise ``TypeError``
+ raw_sort_list.append(sort_or_default(Expr(raw_expr)))
+ return raw_sort_list
class Expr:
@@ -352,7 +451,7 @@ class Expr:
"""Binary not (~)."""
return Expr(self.expr.__invert__())
- def __getitem__(self, key: str | int | slice) -> Expr:
+ def __getitem__(self, key: str | int) -> Expr:
"""Retrieve sub-object.
If ``key`` is a string, returns the subfield of the struct.
@@ -530,13 +629,13 @@ class Expr:
"""Returns ``True`` if this expression is not null."""
return Expr(self.expr.is_not_null())
- def fill_nan(self, value: Any | Expr | None = None) -> Expr:
+ def fill_nan(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr:
"""Fill NaN values with a provided value."""
if not isinstance(value, Expr):
value = Expr.literal(value)
return Expr(functions_internal.nanvl(self.expr, value.expr))
- def fill_null(self, value: Any | Expr | None = None) -> Expr:
+ def fill_null(self, value: Optional[_typing.Union[Any, Expr]] = None) -> Expr:
"""Fill NULL values with a provided value."""
if not isinstance(value, Expr):
value = Expr.literal(value)
@@ -549,7 +648,7 @@ class Expr:
bool: pa.bool_(),
}
- def cast(self, to: pa.DataType[Any] | type[float | int | str | bool]) -> Expr:
+ def cast(self, to: _typing.Union[pa.DataType[Any], type]) -> Expr:
"""Cast to a new data type."""
if not isinstance(to, pa.DataType):
try:
@@ -622,7 +721,7 @@ class Expr:
"""Compute the output column name based on the provided logical
plan."""
return self.expr.column_name(plan._raw_plan)
- def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder:
+ def order_by(self, *exprs: _typing.Union[Expr, SortExpr]) -> ExprFuncBuilder:
"""Set the ordering for a window or aggregate function.
This function will create an :py:class:`ExprFuncBuilder` that can be
used to
@@ -687,7 +786,7 @@ class Expr:
window: Window definition
"""
partition_by_raw = expr_list_to_raw_expr_list(window._partition_by)
- order_by_raw = sort_list_to_raw_sort_list(window._order_by)
+ order_by_raw = window._order_by
window_frame_raw = (
window._window_frame.window_frame
if window._window_frame is not None
@@ -1171,9 +1270,16 @@ class Window:
def __init__(
self,
- partition_by: Optional[list[Expr] | Expr] = None,
+ partition_by: Optional[_typing.Union[list[Expr], Expr]] = None,
window_frame: Optional[WindowFrame] = None,
- order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None,
+ order_by: Optional[
+ _typing.Union[
+ list[_typing.Union[SortExpr, Expr, str]],
+ Expr,
+ SortExpr,
+ str,
+ ]
+ ] = None,
null_treatment: Optional[NullTreatment] = None,
) -> None:
"""Construct a window definition.
@@ -1186,7 +1292,7 @@ class Window:
"""
self._partition_by = partition_by
self._window_frame = window_frame
- self._order_by = order_by
+ self._order_by = sort_list_to_raw_sort_list(order_by)
self._null_treatment = null_treatment
@@ -1244,7 +1350,7 @@ class WindowFrameBound:
"""Constructs a window frame bound."""
self.frame_bound = frame_bound
- def get_offset(self) -> int | None:
+ def get_offset(self) -> Optional[int]:
"""Returns the offset of the window frame."""
return self.frame_bound.get_offset()
@@ -1326,3 +1432,6 @@ class SortExpr:
def __repr__(self) -> str:
"""Generate a string representation of this expression."""
return self.raw_sort.__repr__()
+
+
+SortKey = _typing.Union[Expr, SortExpr, str]
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 7ee4929a..648efef7 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -28,6 +28,7 @@ from datafusion.expr import (
CaseBuilder,
Expr,
SortExpr,
+ SortKey,
WindowFrame,
expr_list_to_raw_expr_list,
sort_list_to_raw_sort_list,
@@ -429,7 +430,7 @@ def window(
name: str,
args: list[Expr],
partition_by: list[Expr] | Expr | None = None,
- order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None,
+ order_by: list[SortKey] | SortKey | None = None,
window_frame: WindowFrame | None = None,
ctx: SessionContext | None = None,
) -> Expr:
@@ -440,6 +441,10 @@ def window(
lag use::
df.select(functions.lag(col("a")).partition_by(col("b")).build())
+
+ The ``order_by`` parameter accepts column names or expressions, e.g.::
+
+ window("lag", [col("a")], order_by="ts")
"""
args = [a.expr for a in args]
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
@@ -1723,7 +1728,7 @@ def array_agg(
expression: Expr,
distinct: bool = False,
filter: Optional[Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Aggregate values into an array.
@@ -1738,7 +1743,11 @@ def array_agg(
expression: Values to combine into an array
distinct: If True, a single entry for each distinct value will be in
the result
filter: If provided, only compute against rows for which the filter is
True
- order_by: Order the resultant array values
+ order_by: Order the resultant array values. Accepts column names or expressions.
+
+ For example::
+
+ df.select(array_agg(col("a"), order_by="b"))
"""
order_by_raw = sort_list_to_raw_sort_list(order_by)
filter_raw = filter.expr if filter is not None else None
@@ -2222,7 +2231,7 @@ def regr_syy(
def first_value(
expression: Expr,
filter: Optional[Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the first value in a group of values.
@@ -2235,8 +2244,13 @@ def first_value(
Args:
expression: Argument to perform bitwise calculation on
filter: If provided, only compute against rows for which the filter is
True
- order_by: Set the ordering of the expression to evaluate
+ order_by: Set the ordering of the expression to evaluate. Accepts
+ column names or expressions.
null_treatment: Assign whether to respect or ignore null values.
+
+ For example::
+
+ df.select(first_value(col("a"), order_by="ts"))
"""
order_by_raw = sort_list_to_raw_sort_list(order_by)
filter_raw = filter.expr if filter is not None else None
@@ -2254,7 +2268,7 @@ def first_value(
def last_value(
expression: Expr,
filter: Optional[Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the last value in a group of values.
@@ -2267,8 +2281,13 @@ def last_value(
Args:
expression: Argument to perform bitwise calculation on
filter: If provided, only compute against rows for which the filter is
True
- order_by: Set the ordering of the expression to evaluate
+ order_by: Set the ordering of the expression to evaluate. Accepts
+ column names or expressions.
null_treatment: Assign whether to respect or ignore null values.
+
+ For example::
+
+ df.select(last_value(col("a"), order_by="ts"))
"""
order_by_raw = sort_list_to_raw_sort_list(order_by)
filter_raw = filter.expr if filter is not None else None
@@ -2287,7 +2306,7 @@ def nth_value(
expression: Expr,
n: int,
filter: Optional[Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the n-th value in a group of values.
@@ -2301,8 +2320,13 @@ def nth_value(
expression: Argument from which to take the n-th value
n: Index of value to return. Starts at 1.
filter: If provided, only compute against rows for which the filter is
True
- order_by: Set the ordering of the expression to evaluate
+ order_by: Set the ordering of the expression to evaluate. Accepts
+ column names or expressions.
null_treatment: Assign whether to respect or ignore null values.
+
+ For example::
+
+ df.select(nth_value(col("a"), 2, order_by="ts"))
"""
order_by_raw = sort_list_to_raw_sort_list(order_by)
filter_raw = filter.expr if filter is not None else None
@@ -2408,7 +2432,7 @@ def lead(
shift_offset: int = 1,
default_value: Optional[Any] = None,
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a lead window function.
@@ -2437,7 +2461,12 @@ def lead(
shift_offset: Number of rows following the current row.
default_value: Value to return if the shift_offset row does not exist.
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ lead(col("b"), order_by="ts")
"""
if not isinstance(default_value, pa.Scalar) and default_value is not None:
default_value = pa.scalar(default_value)
@@ -2461,7 +2490,7 @@ def lag(
shift_offset: int = 1,
default_value: Optional[Any] = None,
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a lag window function.
@@ -2487,7 +2516,12 @@ def lag(
shift_offset: Number of rows before the current row.
default_value: Value to return if the shift_offset row does not exist.
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ lag(col("b"), order_by="ts")
"""
if not isinstance(default_value, pa.Scalar):
default_value = pa.scalar(default_value)
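[Editor's note: the two lines above coerce a plain Python `default_value` into a `pa.scalar` only when it is not already one. A minimal stand-alone sketch of that coercion pattern, with `Scalar` as a hypothetical stand-in for `pyarrow.Scalar`:]

```python
# Illustrative sketch of the default_value coercion used by lead()/lag():
# wrap a raw Python value unless it is already wrapped (or absent).
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class Scalar:  # stand-in for pyarrow.Scalar
    value: Any


def coerce_default(default_value: Optional[Any]) -> Optional[Scalar]:
    # Already-wrapped values and None pass through untouched;
    # everything else gets wrapped exactly once.
    if isinstance(default_value, Scalar) or default_value is None:
        return default_value
    return Scalar(default_value)
```

Note that `lead` guards with `default_value is not None` while `lag` does not; in the real library `pa.scalar(None)` still produces a valid null scalar, so both paths behave.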
@@ -2508,7 +2542,7 @@ def lag(
def row_number(
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a row number window function.
@@ -2527,7 +2561,12 @@ def row_number(
Args:
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ row_number(order_by="points")
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
@@ -2542,7 +2581,7 @@ def row_number(
def rank(
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a rank window function.
@@ -2566,7 +2605,12 @@ def rank(
Args:
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ rank(order_by="points")
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
@@ -2581,7 +2625,7 @@ def rank(
def dense_rank(
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a dense_rank window function.
@@ -2600,7 +2644,12 @@ def dense_rank(
Args:
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ dense_rank(order_by="points")
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
@@ -2615,7 +2664,7 @@ def dense_rank(
def percent_rank(
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a percent_rank window function.
@@ -2635,7 +2684,12 @@ def percent_rank(
Args:
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ percent_rank(order_by="points")
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
@@ -2650,7 +2704,7 @@ def percent_rank(
def cume_dist(
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a cumulative distribution window function.
@@ -2670,7 +2724,12 @@ def cume_dist(
Args:
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ cume_dist(order_by="points")
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
@@ -2686,7 +2745,7 @@ def cume_dist(
def ntile(
groups: int,
partition_by: Optional[list[Expr] | Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Create a n-tile window function.
@@ -2709,7 +2768,12 @@ def ntile(
Args:
groups: Number of groups for the n-tile to be divided into.
partition_by: Expressions to partition the window frame on.
- order_by: Set ordering within the window frame.
+ order_by: Set ordering within the window frame. Accepts
+ column names or expressions.
+
+ For example::
+
+ ntile(3, order_by="points")
"""
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
@@ -2727,7 +2791,7 @@ def string_agg(
expression: Expr,
delimiter: str,
filter: Optional[Expr] = None,
- order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
+ order_by: Optional[list[SortKey] | SortKey] = None,
) -> Expr:
"""Concatenates the input strings.
@@ -2742,7 +2806,12 @@ def string_agg(
expression: Argument whose values are concatenated
delimiter: Text to place between each value of expression
filter: If provided, only compute against rows for which the filter is
True
- order_by: Set the ordering of the expression to evaluate
+ order_by: Set the ordering of the expression to evaluate. Accepts
+ column names or expressions.
+
+ For example::
+
+ df.select(string_agg(col("a"), ",", order_by="b"))
"""
order_by_raw = sort_list_to_raw_sort_list(order_by)
filter_raw = filter.expr if filter is not None else None
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index d00dc9c6..1cf48ec7 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -34,6 +34,9 @@ from datafusion import (
column,
literal,
)
+from datafusion import (
+ col as df_col,
+)
from datafusion import (
functions as f,
)
@@ -43,7 +46,7 @@ from datafusion.dataframe_formatter import (
get_formatter,
reset_formatter,
)
-from datafusion.expr import Window
+from datafusion.expr import EXPR_TYPE_ERROR, Window
from pyarrow.csv import write_csv
MB = 1024 * 1024
@@ -227,6 +230,14 @@ def test_select_mixed_expr_string(df):
assert result.column(1) == pa.array([1, 2, 3])
+def test_select_unsupported(df):
+ with pytest.raises(
+ TypeError,
+ match=f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}",
+ ):
+ df.select(1)
+
+
def test_filter(df):
df1 = df.filter(column("a") > literal(2)).select(
column("a") + column("b"),
@@ -268,6 +279,47 @@ def test_sort(df):
assert table.to_pydict() == expected
+def test_sort_string_and_expression_equivalent(df):
+ from datafusion import col
+
+ result_str = df.sort("a").to_pydict()
+ result_expr = df.sort(col("a")).to_pydict()
+ assert result_str == result_expr
+
+
+def test_sort_unsupported(df):
+ with pytest.raises(
+ TypeError,
+ match=f"Expected Expr or column name.*{re.escape(EXPR_TYPE_ERROR)}",
+ ):
+ df.sort(1)
+
+
+def test_aggregate_string_and_expression_equivalent(df):
+ from datafusion import col
+
+ result_str = df.aggregate("a", [f.count()]).sort("a").to_pydict()
+ result_expr = df.aggregate(col("a"), [f.count()]).sort("a").to_pydict()
+ assert result_str == result_expr
+
+
+def test_aggregate_tuple_group_by(df):
+ result_list = df.aggregate(["a"], [f.count()]).sort("a").to_pydict()
+ result_tuple = df.aggregate(("a",), [f.count()]).sort("a").to_pydict()
+ assert result_tuple == result_list
+
+
+def test_aggregate_tuple_aggs(df):
+ result_list = df.aggregate("a", [f.count()]).sort("a").to_pydict()
+ result_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict()
+ assert result_tuple == result_list
+
+
+def test_filter_string_unsupported(df):
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ df.filter("a > 1")
+
+
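[Editor's note: the validation tests above all match on a shared `EXPR_TYPE_ERROR` constant via `re.escape`. A minimal sketch of how such a validator might look; the message wording and the `RawExpr` stand-in are assumptions, not the actual implementation:]

```python
# Illustrative sketch of an ensure_expr-style validator: unwrap a valid
# Expr, reject anything else (including strings like "a > 1") with a
# TypeError carrying one shared, greppable message constant.
from typing import Any

# Assumed wording; the real constant lives in datafusion.expr.
EXPR_TYPE_ERROR = "Use col()/column() or lit()/literal() to construct an Expr"


class RawExpr:  # stand-in for the Rust-backed expression handle
    pass


class Expr:  # stand-in for datafusion.Expr
    def __init__(self) -> None:
        self.expr = RawExpr()


def ensure_expr(value: Any) -> RawExpr:
    # A valid Expr is unwrapped to its raw handle; everything else fails
    # loudly rather than being silently interpreted.
    if isinstance(value, Expr):
        return value.expr
    raise TypeError(EXPR_TYPE_ERROR)
```

Because every call site raises the same constant, tests like `test_filter_string_unsupported` can assert on `re.escape(EXPR_TYPE_ERROR)` without duplicating message text.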
def test_drop(df):
df = df.drop("c")
@@ -337,6 +389,13 @@ def test_with_column(df):
assert result.column(2) == pa.array([5, 7, 9])
+def test_with_column_invalid_expr(df):
+ with pytest.raises(
+ TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
+ ):
+ df.with_column("c", "a")
+
+
def test_with_columns(df):
df = df.with_columns(
(column("a") + column("b")).alias("c"),
@@ -368,6 +427,17 @@ def test_with_columns(df):
assert result.column(6) == pa.array([5, 7, 9])
+def test_with_columns_invalid_expr(df):
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ df.with_columns("a")
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ df.with_columns(c="a")
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ df.with_columns(["a"])
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ df.with_columns(c=["a"])
+
+
def test_cast(df):
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
expected = pa.schema(
@@ -526,6 +596,29 @@ def test_join_on():
assert table.to_pydict() == expected
+def test_join_on_invalid_expr():
+ ctx = SessionContext()
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2]), pa.array([4, 5])],
+ names=["a", "b"],
+ )
+ df = ctx.create_dataframe([[batch]], "l")
+ df1 = ctx.create_dataframe([[batch]], "r")
+
+ with pytest.raises(
+ TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
+ ):
+ df.join_on(df1, "a")
+
+
+def test_aggregate_invalid_aggs(df):
+ with pytest.raises(
+ TypeError, match=r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
+ ):
+ df.aggregate([], "a")
+
+
def test_distinct():
ctx = SessionContext()
@@ -713,6 +806,13 @@ data_test_window_functions = [
),
[1, 1, 1, 1, 5, 5, 5],
),
+ (
+ "first_value_order_by_string",
+ f.first_value(column("a")).over(
+ Window(partition_by=[column("c")], order_by="b")
+ ),
+ [1, 1, 1, 1, 5, 5, 5],
+ ),
(
"last_value",
f.last_value(column("a")).over(
@@ -755,6 +855,27 @@ def test_window_functions(partitioned_df, name, expr, result):
assert table.sort_by("a").to_pydict() == expected
[email protected]("partition", ["c", df_col("c")])
+def test_rank_partition_by_accepts_string(partitioned_df, partition):
+ """Passing a string to partition_by should match using col()."""
+ df = partitioned_df.select(
+ f.rank(order_by=column("a"), partition_by=partition).alias("r")
+ )
+ table = pa.Table.from_batches(df.sort(column("a")).collect())
+ assert table.column("r").to_pylist() == [1, 2, 3, 4, 1, 2, 3]
+
+
[email protected]("partition", ["c", df_col("c")])
+def test_window_partition_by_accepts_string(partitioned_df, partition):
+ """Window.partition_by accepts string identifiers."""
+ expr = f.first_value(column("a")).over(
+ Window(partition_by=partition, order_by=column("b"))
+ )
+ df = partitioned_df.select(expr.alias("fv"))
+ table = pa.Table.from_batches(df.sort(column("a")).collect())
+ assert table.column("fv").to_pylist() == [1, 1, 1, 1, 5, 5, 5]
+
+
@pytest.mark.parametrize(
("units", "start_bound", "end_bound"),
[
@@ -825,6 +946,69 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
assert df_2.sort(col_a).to_pydict() == expected
+def _build_last_value_df(df):
+ return df.select(
+ f.last_value(column("a"))
+ .over(
+ Window(
+ partition_by=[column("c")],
+ order_by=[column("b")],
+ window_frame=WindowFrame("rows", None, None),
+ )
+ )
+ .alias("expr"),
+ f.last_value(column("a"))
+ .over(
+ Window(
+ partition_by=[column("c")],
+ order_by="b",
+ window_frame=WindowFrame("rows", None, None),
+ )
+ )
+ .alias("str"),
+ )
+
+
+def _build_nth_value_df(df):
+ return df.select(
+ f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])).alias("expr"),
+ f.nth_value(column("b"), 3).over(Window(order_by="a")).alias("str"),
+ )
+
+
+def _build_rank_df(df):
+ return df.select(
+ f.rank(order_by=[column("b")]).alias("expr"),
+ f.rank(order_by="b").alias("str"),
+ )
+
+
+def _build_array_agg_df(df):
+ return df.aggregate(
+ [column("c")],
+ [
+ f.array_agg(column("a"), order_by=[column("a")]).alias("expr"),
+ f.array_agg(column("a"), order_by="a").alias("str"),
+ ],
+ ).sort(column("c"))
+
+
[email protected](
+ ("builder", "expected"),
+ [
+ pytest.param(_build_last_value_df, [3, 3, 3, 3, 6, 6, 6], id="last_value"),
+ pytest.param(_build_nth_value_df, [None, None, 7, 7, 7, 7, 7], id="nth_value"),
+ pytest.param(_build_rank_df, [1, 1, 3, 3, 5, 6, 6], id="rank"),
+ pytest.param(_build_array_agg_df, [[0, 1, 2, 3], [4, 5, 6]], id="array_agg"),
+ ],
+)
+def test_order_by_string_equivalence(partitioned_df, builder, expected):
+ df = builder(partitioned_df)
+ table = pa.Table.from_batches(df.collect())
+ assert table.column("expr").to_pylist() == expected
+ assert table.column("expr").to_pylist() == table.column("str").to_pylist()
+
+
def test_html_formatter_cell_dimension(df, clean_formatter_state):
"""Test configuring the HTML formatter with different options."""
# Configure with custom settings
@@ -2680,3 +2864,34 @@ def test_show_from_empty_batch(capsys) -> None:
ctx.create_dataframe([[batch]]).show()
out = capsys.readouterr().out
assert "| a |" in out
+
+
[email protected]("file_sort_order", [[["a"]], [[df_col("a")]]])
+def test_register_parquet_file_sort_order(ctx, tmp_path, file_sort_order):
+ table = pa.table({"a": [1, 2]})
+ path = tmp_path / "file.parquet"
+ pa.parquet.write_table(table, path)
+ ctx.register_parquet("t", path, file_sort_order=file_sort_order)
+ assert "t" in ctx.catalog().schema().names()
+
+
[email protected]("file_sort_order", [[["a"]], [[df_col("a")]]])
+def test_register_listing_table_file_sort_order(ctx, tmp_path, file_sort_order):
+ table = pa.table({"a": [1, 2]})
+ dir_path = tmp_path / "dir"
+ dir_path.mkdir()
+ pa.parquet.write_table(table, dir_path / "file.parquet")
+ ctx.register_listing_table(
+ "t", dir_path, schema=table.schema, file_sort_order=file_sort_order
+ )
+ assert "t" in ctx.catalog().schema().names()
+
+
[email protected]("file_sort_order", [[["a"]], [[df_col("a")]]])
+def test_read_parquet_file_sort_order(tmp_path, file_sort_order):
+ ctx = SessionContext()
+ table = pa.table({"a": [1, 2]})
+ path = tmp_path / "data.parquet"
+ pa.parquet.write_table(table, path)
+ df = ctx.read_parquet(path, file_sort_order=file_sort_order)
+ assert df.collect()[0].column(0).to_pylist() == [1, 2]
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index cfeb07c1..810d419c 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+import re
from datetime import datetime, timezone
import pyarrow as pa
@@ -28,6 +29,7 @@ from datafusion import (
literal_with_metadata,
)
from datafusion.expr import (
+ EXPR_TYPE_ERROR,
Aggregate,
AggregateFunction,
BinaryExpr,
@@ -47,6 +49,8 @@ from datafusion.expr import (
TransactionEnd,
TransactionStart,
Values,
+ ensure_expr,
+ ensure_expr_list,
)
@@ -880,3 +884,25 @@ def test_literal_metadata(ctx):
for expected_field in expected_schema:
actual_field = result[0].schema.field(expected_field.name)
assert expected_field.metadata == actual_field.metadata
+
+
+def test_ensure_expr():
+ e = col("a")
+ assert ensure_expr(e) is e.expr
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ ensure_expr("a")
+
+
+def test_ensure_expr_list_string():
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ ensure_expr_list("a")
+
+
+def test_ensure_expr_list_bytes():
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ ensure_expr_list(b"a")
+
+
+def test_ensure_expr_list_bytearray():
+ with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
+ ensure_expr_list(bytearray(b"a"))
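[Editor's note: the three tests above pin down why a list validator needs an explicit guard against `str`, `bytes`, and `bytearray` -- all three are iterable, so without the guard a bare `"abc"` would be walked character by character instead of rejected. A hedged sketch of that guard; names and message wording are assumptions, not the shipped code:]

```python
# Illustrative sketch of an ensure_expr_list-style validator with the
# string/bytes guard the tests above exercise.
from collections.abc import Iterable
from typing import Any

# Assumed wording; the real constant lives in datafusion.expr.
EXPR_TYPE_ERROR = "Use col()/column() or lit()/literal() to construct an Expr"


class Expr:  # stand-in for datafusion.Expr
    pass


def ensure_expr_list(values: Any) -> list[Expr]:
    # str/bytes/bytearray are iterable, so they must be rejected before
    # the generic iterability check, not after.
    if isinstance(values, (str, bytes, bytearray)) or not isinstance(
        values, Iterable
    ):
        raise TypeError(EXPR_TYPE_ERROR)
    result = []
    for value in values:
        if not isinstance(value, Expr):
            raise TypeError(EXPR_TYPE_ERROR)
        result.append(value)
    return result
```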
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]