This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 399fa75 feat: expose PyWindowFrame (#509)
399fa75 is described below
commit 399fa758ccb1dc785929123350144902a9b6c502
Author: Dan Lovell <[email protected]>
AuthorDate: Tue Oct 17 16:59:41 2023 -0400
feat: expose PyWindowFrame (#509)
* feat: expose PyWindowFrame
* fix: PyWindowFrame: return Err instead of panicking
* test: test PyWindowFrame creation
---
datafusion/__init__.py | 2 +
datafusion/tests/test_dataframe.py | 41 +++++++++++++-
src/functions.rs | 20 ++++++-
src/lib.rs | 2 +
src/udaf.rs | 30 +++++++++-
src/window_frame.rs | 110 +++++++++++++++++++++++++++++++++++++
6 files changed, 200 insertions(+), 5 deletions(-)
diff --git a/datafusion/__init__.py b/datafusion/__init__.py
index bb1beac..4a495b4 100644
--- a/datafusion/__init__.py
+++ b/datafusion/__init__.py
@@ -33,6 +33,7 @@ from ._internal import (
SessionConfig,
RuntimeConfig,
ScalarUDF,
+ WindowFrame,
)
from .common import (
@@ -98,6 +99,7 @@ __all__ = [
"Expr",
"AggregateUDF",
"ScalarUDF",
+ "WindowFrame",
"column",
"literal",
"TableScan",
diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index ce7d89e..c9b0f07 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -21,7 +21,14 @@ import pyarrow.parquet as pq
import pytest
from datafusion import functions as f
-from datafusion import DataFrame, SessionContext, column, literal, udf
+from datafusion import (
+ DataFrame,
+ SessionContext,
+ WindowFrame,
+ column,
+ literal,
+ udf,
+)
@pytest.fixture
@@ -304,6 +311,38 @@ def test_window_functions(df):
assert table.sort_by("a").to_pydict() == expected
+@pytest.mark.parametrize(
+ ("units", "start_bound", "end_bound"),
+ [
+ (units, start_bound, end_bound)
+ for units in ("rows", "range")
+ for start_bound in (None, 0, 1)
+ for end_bound in (None, 0, 1)
+ ]
+ + [
+ ("groups", 0, 0),
+ ],
+)
+def test_valid_window_frame(units, start_bound, end_bound):
+ WindowFrame(units, start_bound, end_bound)
+
+
+@pytest.mark.parametrize(
+ ("units", "start_bound", "end_bound"),
+ [
+ ("invalid-units", 0, None),
+ ("invalid-units", None, 0),
+ ("invalid-units", None, None),
+ ("groups", None, 0),
+ ("groups", 0, None),
+ ("groups", None, None),
+ ],
+)
+def test_invalid_window_frame(units, start_bound, end_bound):
+ with pytest.raises(RuntimeError):
+ WindowFrame(units, start_bound, end_bound)
+
+
def test_get_dataframe(tmp_path):
ctx = SessionContext()
diff --git a/src/functions.rs b/src/functions.rs
index e509aff..42203d7 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -17,9 +17,12 @@
use pyo3::{prelude::*, wrap_pyfunction};
+use crate::context::PySessionContext;
use crate::errors::DataFusionError;
use crate::expr::conditional_expr::PyCaseBuilder;
use crate::expr::PyExpr;
+use crate::window_frame::PyWindowFrame;
+use datafusion::execution::FunctionRegistry;
use datafusion_common::Column;
use datafusion_expr::expr::Alias;
use datafusion_expr::{
@@ -27,7 +30,7 @@ use datafusion_expr::{
expr::{AggregateFunction, ScalarFunction, Sort, WindowFunction},
lit,
window_function::find_df_window_func,
- BuiltinScalarFunction, Expr, WindowFrame,
+ BuiltinScalarFunction, Expr,
};
#[pyfunction]
@@ -130,13 +133,24 @@ fn window(
args: Vec<PyExpr>,
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PyExpr>>,
+ window_frame: Option<PyWindowFrame>,
+ ctx: Option<PySessionContext>,
) -> PyResult<PyExpr> {
- let fun = find_df_window_func(name);
+ let fun = find_df_window_func(name).or_else(|| {
+ ctx.and_then(|ctx| {
+ ctx.ctx
+ .udaf(name)
+ .map(|fun| datafusion_expr::WindowFunction::AggregateUDF(fun))
+ .ok()
+ })
+ });
if fun.is_none() {
return Err(DataFusionError::Common("window function not found".to_string()).into());
}
let fun = fun.unwrap();
- let window_frame = WindowFrame::new(order_by.is_some());
+ let window_frame = window_frame
+ .unwrap_or_else(|| PyWindowFrame::new("rows", None, Some(0)).unwrap())
+ .into();
Ok(PyExpr {
expr: datafusion_expr::Expr::WindowFunction(WindowFunction {
fun,
diff --git a/src/lib.rs b/src/lib.rs
index 2512aef..b9bd576 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -54,6 +54,7 @@ mod udaf;
#[allow(clippy::borrow_deref_ref)]
mod udf;
pub mod utils;
+mod window_frame;
#[cfg(feature = "mimalloc")]
#[global_allocator]
@@ -83,6 +84,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<context::PySessionContext>()?;
m.add_class::<dataframe::PyDataFrame>()?;
m.add_class::<udf::PyScalarUDF>()?;
+ m.add_class::<window_frame::PyWindowFrame>()?;
m.add_class::<udaf::PyAggregateUDF>()?;
m.add_class::<config::PyConfig>()?;
m.add_class::<sql::logical::PyLogicalPlan>()?;
diff --git a/src/udaf.rs b/src/udaf.rs
index 3b70aeb..6450f03 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -17,7 +17,7 @@
use std::sync::Arc;
-use pyo3::{prelude::*, types::PyTuple};
+use pyo3::{prelude::*, types::PyBool, types::PyTuple};
use datafusion::arrow::array::{Array, ArrayRef};
use datafusion::arrow::datatypes::DataType;
@@ -93,6 +93,34 @@ impl Accumulator for RustAccumulator {
fn size(&self) -> usize {
std::mem::size_of_val(self)
}
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ Python::with_gil(|py| {
+ // 1. cast args to Pyarrow array
+ let py_args = values
+ .iter()
+ .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+ .collect::<Vec<_>>();
+ let py_args = PyTuple::new(py, py_args);
+
+ // 2. call function
+ self.accum
+ .as_ref(py)
+ .call_method1("retract_batch", py_args)
+ .map_err(|e| DataFusionError::Execution(format!("{e}")))?;
+
+ Ok(())
+ })
+ }
+
+ fn supports_retract_batch(&self) -> bool {
+ Python::with_gil(|py| {
+ let x: Result<&PyAny, PyErr> =
+ self.accum.as_ref(py).call_method0("supports_retract_batch");
+ let x: &PyAny = x.unwrap_or(PyBool::new(py, false));
+ x.extract().unwrap_or(false)
+ })
+ }
}
pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
diff --git a/src/window_frame.rs b/src/window_frame.rs
new file mode 100644
index 0000000..b8f414e
--- /dev/null
+++ b/src/window_frame.rs
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{DataFusionError, ScalarValue};
+use datafusion_expr::window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+use pyo3::prelude::*;
+use std::fmt::{Display, Formatter};
+
+use crate::errors::py_datafusion_err;
+
+#[pyclass(name = "WindowFrame", module = "datafusion", subclass)]
+#[derive(Clone)]
+pub struct PyWindowFrame {
+ frame: WindowFrame,
+}
+
+impl From<PyWindowFrame> for WindowFrame {
+ fn from(frame: PyWindowFrame) -> Self {
+ frame.frame
+ }
+}
+
+impl From<WindowFrame> for PyWindowFrame {
+ fn from(frame: WindowFrame) -> PyWindowFrame {
+ PyWindowFrame { frame }
+ }
+}
+
+impl Display for PyWindowFrame {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ write!(
+ f,
+ "OVER ({} BETWEEN {} AND {})",
+ self.frame.units, self.frame.start_bound, self.frame.end_bound
+ )
+ }
+}
+
+#[pymethods]
+impl PyWindowFrame {
+ #[new(unit, start_bound, end_bound)]
+ pub fn new(units: &str, start_bound: Option<u64>, end_bound: Option<u64>) -> PyResult<Self> {
+ let units = units.to_ascii_lowercase();
+ let units = match units.as_str() {
+ "rows" => WindowFrameUnits::Rows,
+ "range" => WindowFrameUnits::Range,
+ "groups" => WindowFrameUnits::Groups,
+ _ => {
+ return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
+ "{:?}",
+ units,
+ ))));
+ }
+ };
+ let start_bound = match start_bound {
+ Some(start_bound) => {
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound)))
+ }
+ None => match units {
+ WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+ WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+ WindowFrameUnits::Groups => {
+ return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
+ "{:?}",
+ units,
+ ))));
+ }
+ },
+ };
+ let end_bound = match end_bound {
+ Some(end_bound) => WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))),
+ None => match units {
+ WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
+ WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
+ WindowFrameUnits::Groups => {
+ return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
+ "{:?}",
+ units,
+ ))));
+ }
+ },
+ };
+ Ok(PyWindowFrame {
+ frame: WindowFrame {
+ units,
+ start_bound,
+ end_bound,
+ },
+ })
+ }
+
+ /// Get a String representation of this window frame
+ fn __repr__(&self) -> String {
+ format!("{}", self)
+ }
+}