alamb commented on code in PR #6617: URL: https://github.com/apache/arrow-datafusion/pull/6617#discussion_r1228612672
########## datafusion-examples/examples/simple_udwf.rs: ##########
@@ -0,0 +1,210 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::{AsArray, Float64Array},
+    datatypes::Float64Type,
+};
+use arrow_schema::DataType;
+use datafusion::datasource::file_format::options::CsvReadOptions;
+
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::DataFusionError;
+use datafusion_expr::{
+    partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF,
+};
+
+// create a local execution context with `cars.csv` registered as a table named `cars`
+async fn create_context() -> Result<SessionContext> {
+    // declare a new context. In the Spark API, this corresponds to a new SparkSession
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
+    println!("pwd: {}", std::env::current_dir().unwrap().display());
+    let csv_path = format!("datafusion/core/tests/data/cars.csv");
+    let read_options = CsvReadOptions::default().has_header(true);
+
+    ctx.register_csv("cars", &csv_path, read_options).await?;
+    Ok(ctx)
+}
+
+/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context().await?;
+
+    // register the window function with DataFusion so we can call it
+    ctx.register_udwf(my_average());
+
+    // Use SQL to run the new window function
+    let df = ctx.sql("SELECT * from cars").await?;
+    // print the results
+    df.show().await?;
+
+    // Use SQL to run the new window function
+    // `PARTITION BY car`: each distinct value of car (red and green) should be treated separately
+    // `ORDER BY time`: within each group (red or green) the values will be ordered by time
+    let df = ctx
+        .sql(
+            "SELECT car, \
+            speed, \
+            lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\
+            my_average(speed) OVER (PARTITION BY car ORDER BY time),\
+            time \
+            from cars",
+        )
+        .await?;
+    // print the results
+    df.show().await?;
+
+    // // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: run the window function so that each invocation only sees 5 rows: the 2 before and the 2 after
+    // let df = ctx.sql("SELECT car, \
+    //     speed, \
+    //     lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
+    //     time \
+    //     from cars").await?;
+    // // print the results
+    // df.show().await?;
+
+    // todo show how to run the DataFrame API as well
+
+    Ok(())
+}
+
+// TODO make a helper function like `create_udf` that helps to make these signatures
+
+fn my_average() -> WindowUDF {
+    WindowUDF {
+        name: String::from("my_average"),
+        // it will take 2 arguments -- the column and the window size
+        signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
+        return_type: Arc::new(return_type),
+        partition_evaluator: Arc::new(make_partition_evaluator),
+    }
+}
+
+/// Compute the return type of the function given the argument types
+fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
+    if arg_types.len() != 1 {
+        return Err(DataFusionError::Plan(format!(
+            "my_udwf expects 1 argument, got {}: {:?}",
+            arg_types.len(),
+            arg_types
+        )));
+    }
+    Ok(Arc::new(arg_types[0].clone()))
+}
+
+/// Create a partition evaluator for this argument
+fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
+    Ok(Box::new(MyPartitionEvaluator::new()))
+}

Review Comment:

@stuartcarnie I looked into adding the arguments. The primary issue I encountered is that a `WindowUDF` is specified in terms of structures from `datafusion-expr`, so it does not have access to `PhysicalExpr`s, as those are defined in a different crate.

Here are some possible signatures we could provide. Do you have any feedback on these possibilities?

# Pass in the `Expr`s from the logical plan

This is non-ideal in my mind because the `PartitionEvaluator` is created during execution, where the `Expr`s are normally no longer available.

```rust
/// Factory that creates a PartitionEvaluator for the given window function.
///
/// This function is passed its input arguments so that cases such as
/// constants can be correctly handled.
pub type PartitionEvaluatorFunctionFactory =
    Arc<dyn Fn(&[Expr]) -> Result<Box<dyn PartitionEvaluator>> + Send + Sync>;
```

# Pass in an `ArgType` enum

This is also non-ideal in my mind because it seemingly artificially limits what the user defined window function can special-case (why not `Column`s, for example?).

```rust
enum ArgType {
    /// The argument was a single value
    Scalar(ScalarValue),
    /// The argument is something other than a single value
    Array,
}

/// Factory that creates a PartitionEvaluator for the given window function.
///
/// This function is passed its input arguments so that cases such as
/// constants can be specially handled if desired.
pub type PartitionEvaluatorFunctionFactory =
    Arc<dyn Fn(Vec<ArgType>) -> Result<Box<dyn PartitionEvaluator>> + Send + Sync>;
```
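For illustration, here is a rough sketch of how the example's factory might consume arguments under the second option. It assumes the proposed `ArgType` enum above were adopted; `MyPartitionEvaluator::new_with_window` is a hypothetical constructor used only for this sketch, not an existing API:

```rust
/// Hypothetical factory for `my_average`, written against the proposed
/// `Vec<ArgType>` signature (a sketch only; neither `ArgType` nor this
/// signature exists in DataFusion today).
fn make_partition_evaluator(args: Vec<ArgType>) -> Result<Box<dyn PartitionEvaluator>> {
    // If the second argument (the window size) was a constant in the query,
    // it arrives as a `Scalar` and can be special-cased here.
    let window_size = match args.get(1) {
        Some(ArgType::Scalar(ScalarValue::Int64(Some(n)))) => *n as usize,
        // A non-constant or missing argument falls back to a default window size.
        _ => 5,
    };
    Ok(Box::new(MyPartitionEvaluator::new_with_window(window_size)))
}
```

Either proposal would let `my_average` pick up a constant window size from the query rather than hard-coding it; the difference is mainly whether the factory sees the logical `Expr`s or an already-evaluated summary of the arguments.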