This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d553ffdff8 Improve async_udf example and docs (#16846)
d553ffdff8 is described below
commit d553ffdff88ff62fc0cd29d5bb924771e7c6c904
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jul 24 11:59:03 2025 -0400
Improve async_udf example and docs (#16846)
* Improve async_udf example and docs
* tweak
* Remove random monospace async and version note
* Fix explain plan diff by hard coding parallelism
* rename arguments, use as_string_view_array
* request --> reqwest
---
datafusion-examples/README.md | 1 +
datafusion-examples/examples/async_udf.rs | 302 ++++++++++-----------
datafusion/core/src/execution/context/mod.rs | 2 +-
.../library-user-guide/functions/adding-udfs.md | 65 +++--
4 files changed, 178 insertions(+), 192 deletions(-)
diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index ef5687c867..02f83b9bd0 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -50,6 +50,7 @@ cargo run --example dataframe
- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more
complicated User Defined Scalar Function (UDF)
- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more
complicated User Defined Window Function (UDWF)
- [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a
detailed secondary index that covers the contents of several parquet files
+- [`async_udf.rs`](examples/async_udf.rs): Define and invoke an asynchronous
User Defined Scalar Function (UDF)
- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule
to change a query's semantics (row level access control)
- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog
- [`composed_extension_codec`](examples/composed_extension_codec.rs): Example
of using multiple extension codecs for serialization / deserialization
diff --git a/datafusion-examples/examples/async_udf.rs
b/datafusion-examples/examples/async_udf.rs
index 3037a971df..f1fc3f8885 100644
--- a/datafusion-examples/examples/async_udf.rs
+++ b/datafusion-examples/examples/async_udf.rs
@@ -15,104 +15,104 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::array::{ArrayIter, ArrayRef, AsArray, Int64Array, RecordBatch,
StringArray};
-use arrow::compute::kernels::cmp::eq;
+//! This example shows how to create and use "Async UDFs" in DataFusion.
+//!
+//! Async UDFs allow you to perform asynchronous operations, such as
+//! making network requests. This can be used for tasks like fetching
+//! data from an external API such as an LLM service or an external database.
+
+use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch,
StringArray};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;
+use datafusion::assert_batches_eq;
+use datafusion::common::cast::as_string_view_array;
use datafusion::common::error::Result;
-use datafusion::common::types::{logical_int64, logical_string};
+use datafusion::common::not_impl_err;
use datafusion::common::utils::take_function_args;
-use datafusion::common::{internal_err, not_impl_err};
use datafusion::config::ConfigOptions;
+use datafusion::execution::SessionStateBuilder;
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::{
- ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
- TypeSignatureClass, Volatility,
+ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
-use datafusion::logical_expr_common::signature::Coercion;
-use datafusion::physical_expr_common::datum::apply_cmp;
-use datafusion::prelude::SessionContext;
-use log::trace;
+use datafusion::prelude::{SessionConfig, SessionContext};
use std::any::Any;
use std::sync::Arc;
#[tokio::main]
async fn main() -> Result<()> {
- let ctx: SessionContext = SessionContext::new();
-
- let async_upper = AsyncUpper::new();
- let udf = AsyncScalarUDF::new(Arc::new(async_upper));
- ctx.register_udf(udf.into_scalar_udf());
- let async_equal = AsyncEqual::new();
+ // Use a hard coded parallelism level of 4 so the explain plan
+ // is consistent across machines.
+ let config = SessionConfig::new().with_target_partitions(4);
+ let ctx =
+
SessionContext::from(SessionStateBuilder::new().with_config(config).build());
+
+ // As with regular UDFs, you implement `ScalarUDFImpl`; for an async UDF you
+ // also implement `AsyncScalarUDFImpl` and wrap the result in an `AsyncScalarUDF`.
+ let async_equal = AskLLM::new();
let udf = AsyncScalarUDF::new(Arc::new(async_equal));
+
+ // Async UDFs are registered with the SessionContext, using the same
+ // `register_udf` method as regular UDFs.
ctx.register_udf(udf.into_scalar_udf());
+
+ // Create a table named 'animal' with some sample data
ctx.register_batch("animal", animal()?)?;
- // use Async UDF in the projection
- //
+---------------+----------------------------------------------------------------------------------------+
- // | plan_type | plan
|
- //
+---------------+----------------------------------------------------------------------------------------+
- // | logical_plan | Projection: async_equal(a.id, Int64(1))
|
- // | | SubqueryAlias: a
|
- // | | TableScan: animal projection=[id]
|
- // | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as
async_equal(a.id,Int64(1))] |
- // | | AsyncFuncExec:
async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
- // | | CoalesceBatchesExec: target_batch_size=8192
|
- // | | DataSourceExec: partitions=1,
partition_sizes=[1] |
- // | |
|
- //
+---------------+----------------------------------------------------------------------------------------+
- ctx.sql("explain select async_equal(a.id, 1) from animal a")
+ // You can use the async UDF as normal in SQL queries
+ //
+ // Note: Async UDFs can currently be used in the select list and filter conditions.
+ let results = ctx
+ .sql("select * from animal a where ask_llm(a.name, 'Is this animal
furry?')")
.await?
- .show()
+ .collect()
.await?;
- // +----------------------------+
- // | async_equal(a.id,Int64(1)) |
- // +----------------------------+
- // | true |
- // | false |
- // | false |
- // | false |
- // | false |
- // +----------------------------+
- ctx.sql("select async_equal(a.id, 1) from animal a")
+ assert_batches_eq!(
+ [
+ "+----+------+",
+ "| id | name |",
+ "+----+------+",
+ "| 1 | cat |",
+ "| 2 | dog |",
+ "+----+------+",
+ ],
+ &results
+ );
+
+ // While the interface is the same for both normal and async UDFs, you can
+ // use `EXPLAIN` output to see that the async UDF uses a special
+ // `AsyncFuncExec` node in the physical plan:
+ let results = ctx
+ .sql("explain select * from animal a where ask_llm(a.name, 'Is this
animal furry?')")
.await?
- .show()
+ .collect()
.await?;
- // use Async UDF in the filter
- //
+---------------+--------------------------------------------------------------------------------------------+
- // | plan_type | plan
|
- //
+---------------+--------------------------------------------------------------------------------------------+
- // | logical_plan | SubqueryAlias: a
|
- // | | Filter: async_equal(animal.id, Int64(1))
|
- // | | TableScan: animal projection=[id, name]
|
- // | physical_plan | CoalesceBatchesExec: target_batch_size=8192
|
- // | | FilterExec: __async_fn_0@2, projection=[id@0,
name@1] |
- // | | RepartitionExec:
partitioning=RoundRobinBatch(12), input_partitions=1 |
- // | | AsyncFuncExec:
async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
- // | | CoalesceBatchesExec: target_batch_size=8192
|
- // | | DataSourceExec: partitions=1,
partition_sizes=[1] |
- // | |
|
- //
+---------------+--------------------------------------------------------------------------------------------+
- ctx.sql("explain select * from animal a where async_equal(a.id, 1)")
- .await?
- .show()
- .await?;
-
- // +----+------+
- // | id | name |
- // +----+------+
- // | 1 | cat |
- // +----+------+
- ctx.sql("select * from animal a where async_equal(a.id, 1)")
- .await?
- .show()
- .await?;
+ assert_batches_eq!(
+ [
+
"+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
+ "| plan_type | plan
|",
+
"+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
+ "| logical_plan | SubqueryAlias: a
|",
+ "| | Filter: ask_llm(CAST(animal.name AS Utf8View),
Utf8View(\"Is this animal furry?\"))
|",
+ "| | TableScan: animal projection=[id, name]
|",
+ "| physical_plan | CoalesceBatchesExec: target_batch_size=8192
|",
+ "| | FilterExec: __async_fn_0@2, projection=[id@0, name@1]
|",
+ "| | RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1 |",
+ "| | AsyncFuncExec:
async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS
Utf8View), Is this animal furry?))] |",
+ "| | CoalesceBatchesExec: target_batch_size=8192
|",
+ "| | DataSourceExec: partitions=1,
partition_sizes=[1]
|",
+ "| |
|",
+
"+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
+ ],
+ &results
+ );
Ok(())
}
+/// Returns a sample `RecordBatch` representing an "animal" table with two
+/// columns, `id` and `name`.
fn animal() -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
@@ -127,118 +127,45 @@ fn animal() -> Result<RecordBatch> {
Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?)
}
+/// An async UDF that simulates asking a large language model (LLM) service a
+/// question based on the content of two columns. The UDF will return a boolean
+/// indicating whether the LLM thinks the first argument matches the question in
+/// the second argument.
+///
+/// Since this is a simplified example, it does not call an LLM service, but
+/// could be extended to do so in a real-world scenario.
#[derive(Debug)]
-pub struct AsyncUpper {
- signature: Signature,
-}
-
-impl Default for AsyncUpper {
- fn default() -> Self {
- Self::new()
- }
-}
-
-impl AsyncUpper {
- pub fn new() -> Self {
- Self {
- signature: Signature::new(
- TypeSignature::Coercible(vec![Coercion::Exact {
- desired_type: TypeSignatureClass::Native(logical_string()),
- }]),
- Volatility::Volatile,
- ),
- }
- }
-}
-
-#[async_trait]
-impl ScalarUDFImpl for AsyncUpper {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn name(&self) -> &str {
- "async_upper"
- }
-
- fn signature(&self) -> &Signature {
- &self.signature
- }
-
- fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
- Ok(DataType::Utf8)
- }
-
- fn invoke_with_args(&self, _args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
- not_impl_err!("AsyncUpper can only be called from async contexts")
- }
-}
-
-#[async_trait]
-impl AsyncScalarUDFImpl for AsyncUpper {
- fn ideal_batch_size(&self) -> Option<usize> {
- Some(10)
- }
-
- async fn invoke_async_with_args(
- &self,
- args: ScalarFunctionArgs,
- _option: &ConfigOptions,
- ) -> Result<ArrayRef> {
- trace!("Invoking async_upper with args: {:?}", args);
- let value = &args.args[0];
- let result = match value {
- ColumnarValue::Array(array) => {
- let string_array = array.as_string::<i32>();
- let iter = ArrayIter::new(string_array);
- let result = iter
- .map(|string| string.map(|s| s.to_uppercase()))
- .collect::<StringArray>();
- Arc::new(result) as ArrayRef
- }
- _ => return internal_err!("Expected a string argument, got {:?}",
value),
- };
- Ok(result)
- }
-}
-
-#[derive(Debug)]
-struct AsyncEqual {
+struct AskLLM {
signature: Signature,
}
-impl Default for AsyncEqual {
+impl Default for AskLLM {
fn default() -> Self {
Self::new()
}
}
-impl AsyncEqual {
+impl AskLLM {
pub fn new() -> Self {
Self {
- signature: Signature::new(
- TypeSignature::Coercible(vec![
- Coercion::Exact {
- desired_type:
TypeSignatureClass::Native(logical_int64()),
- },
- Coercion::Exact {
- desired_type:
TypeSignatureClass::Native(logical_int64()),
- },
- ]),
+ signature: Signature::exact(
+ vec![DataType::Utf8View, DataType::Utf8View],
Volatility::Volatile,
),
}
}
}
-#[async_trait]
-impl ScalarUDFImpl for AsyncEqual {
+/// All async UDFs implement the `ScalarUDFImpl` trait, which provides the basic
+/// information for the function, such as its name, signature, and return type.
+impl ScalarUDFImpl for AskLLM {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
- "async_equal"
+ "ask_llm"
}
fn signature(&self) -> &Signature {
@@ -249,19 +176,64 @@ impl ScalarUDFImpl for AsyncEqual {
Ok(DataType::Boolean)
}
+ /// Since this is an async UDF, the `invoke_with_args` method will not be
+ /// called directly.
fn invoke_with_args(&self, _args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
- not_impl_err!("AsyncEqual can only be called from async contexts")
+ not_impl_err!("AskLLM can only be called from async contexts")
}
}
+/// In addition to [`ScalarUDFImpl`], we also need to implement the
+/// [`AsyncScalarUDFImpl`] trait.
#[async_trait]
-impl AsyncScalarUDFImpl for AsyncEqual {
+impl AsyncScalarUDFImpl for AskLLM {
+ /// The `invoke_async_with_args` method is similar to `invoke_with_args`,
+ /// but it returns a `Future` that resolves to the result.
+ ///
+ /// Since this method is `async`, it can perform any `async` operation, such
+ /// as a network request. This method runs on the same tokio `Runtime` that
+ /// is processing the query, so you may wish to make actual network requests
+ /// on a different `Runtime`, as explained in the `thread_pools.rs` example
+ /// in this directory.
async fn invoke_async_with_args(
&self,
args: ScalarFunctionArgs,
_option: &ConfigOptions,
) -> Result<ArrayRef> {
- let [arg1, arg2] = take_function_args(self.name(), &args.args)?;
- apply_cmp(arg1, arg2, eq)?.to_array(args.number_rows)
+ // in a real UDF you would likely want to special case constant
+ // arguments to improve performance, but this example converts the
+ // arguments to arrays for simplicity.
+ let args = ColumnarValue::values_to_arrays(&args.args)?;
+ let [content_column, question_column] =
take_function_args(self.name(), args)?;
+
+ // In a real function, you would use a library such as `reqwest` here to
+ // make an async HTTP request. Credentials and other configurations can
+ // be supplied via the `ConfigOptions` parameter.
+
+ // In this example, we simulate the LLM response by checking each pair of
+ // input arguments against some hardcoded strings.
+ let content_column = as_string_view_array(&content_column)?;
+ let question_column = as_string_view_array(&question_column)?;
+
+ let result_array: BooleanArray = content_column
+ .iter()
+ .zip(question_column.iter())
+ .map(|(a, b)| {
+ // If either value is null, return None
+ let a = a?;
+ let b = b?;
+ // Simulate an LLM response by checking the arguments to some
+ // hardcoded conditions.
+ if a.contains("cat") && b.contains("furry")
+ || a.contains("dog") && b.contains("furry")
+ {
+ Some(true)
+ } else {
+ Some(false)
+ }
+ })
+ .collect();
+
+ Ok(Arc::new(result_array))
}
}
diff --git a/datafusion/core/src/execution/context/mod.rs
b/datafusion/core/src/execution/context/mod.rs
index dbe5c2c00f..ea8850d3b6 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -226,7 +226,7 @@ where
/// # use datafusion::execution::SessionStateBuilder;
/// # use datafusion_execution::runtime_env::RuntimeEnvBuilder;
/// // Configure a 4k batch size
-/// let config = SessionConfig::new() .with_batch_size(4 * 1024);
+/// let config = SessionConfig::new().with_batch_size(4 * 1024);
///
/// // configure a memory limit of 1GB with 20% slop
/// let runtime_env = RuntimeEnvBuilder::new()
diff --git a/docs/source/library-user-guide/functions/adding-udfs.md
b/docs/source/library-user-guide/functions/adding-udfs.md
index cf5624f68d..5c95cb3301 100644
--- a/docs/source/library-user-guide/functions/adding-udfs.md
+++ b/docs/source/library-user-guide/functions/adding-udfs.md
@@ -23,13 +23,22 @@ User Defined Functions (UDFs) are functions that can be
used in the context of D
This page covers how to add UDFs to DataFusion. In particular, it covers how
to add Scalar, Window, and Aggregate UDFs.
-| UDF Type | Description
| Example |
-| ------------ |
--------------------------------------------------------------------------------------------------------------------------------------------------------
| ------------------- |
-| Scalar | A function that takes a row of data and returns a single
value.
| [simple_udf.rs][1] |
-| Window | A function that takes a row of data and returns a single
value, but also has access to the rows around it.
| [simple_udwf.rs][2] |
-| Aggregate | A function that takes a group of rows and returns a single
value.
| [simple_udaf.rs][3] |
-| Table | A function that takes parameters and returns a
`TableProvider` to be used in an query plan.
| [simple_udtf.rs][4] |
-| Async Scalar | A scalar function that natively supports asynchronous
execution, allowing you to perform async operations (such as network or I/O
calls) within the UDF. | [async_udf.rs][5] |
+| UDF Type | Description
| Example(s)
|
+| -------------- |
----------------------------------------------------------------------------------------------------------
| ------------------------------------- |
+| Scalar | A function that takes a row of data and returns a single
value. | [simple_udf.rs] /
[advanced_udf.rs] |
+| Window | A function that takes a row of data and returns a single
value, but also has access to the rows around it. | [simple_udwf.rs] /
[advanced_udwf.rs] |
+| Aggregate | A function that takes a group of rows and returns a single
value. | [simple_udaf.rs] /
[advanced_udaf.rs] |
+| Table | A function that takes parameters and returns a
`TableProvider` to be used in a query plan.                   | [simple_udtf.rs]
|
+| Scalar (async) | A scalar function for performing `async` operations (such
as network or I/O calls) within the UDF. | [async_udf.rs]
|
+
+[simple_udf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
+[advanced_udf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+[simple_udwf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs
+[advanced_udwf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs
+[simple_udaf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
+[advanced_udaf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
+[simple_udtf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs
+[async_udf.rs]:
https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs
First we'll talk about adding a Scalar UDF end-to-end, then we'll talk about how the other
types of UDFs differ.
@@ -345,9 +354,9 @@ async fn main() {
}
```
-## Adding a Scalar Async UDF
+## Adding an Async Scalar UDF
-A Scalar Async UDF allows you to implement user-defined functions that support
+An Async Scalar UDF allows you to implement user-defined functions that support
asynchronous execution, such as performing network or I/O operations within the
UDF.
@@ -359,22 +368,21 @@ To add a Scalar Async UDF, you need to:
### Adding by `impl AsyncScalarUDFImpl`
```rust
-use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray};
-use arrow_schema::DataType;
-use async_trait::async_trait;
-use datafusion::common::error::Result;
-use datafusion::common::{internal_err, not_impl_err};
-use datafusion::common::types::logical_string;
-use datafusion::config::ConfigOptions;
-use datafusion_expr::ScalarUDFImpl;
-use datafusion::logical_expr::async_udf::AsyncScalarUDFImpl;
-use datafusion::logical_expr::{
- ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
ScalarFunctionArgs
-};
-use datafusion::logical_expr_common::signature::Coercion;
-use log::trace;
-use std::any::Any;
-use std::sync::Arc;
+# use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray};
+# use arrow_schema::DataType;
+# use async_trait::async_trait;
+# use datafusion::common::error::Result;
+# use datafusion::common::{internal_err, not_impl_err};
+# use datafusion::common::types::logical_string;
+# use datafusion::config::ConfigOptions;
+# use datafusion_expr::ScalarUDFImpl;
+# use datafusion::logical_expr::async_udf::AsyncScalarUDFImpl;
+# use datafusion::logical_expr::{
+# ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
ScalarFunctionArgs
+# };
+# use datafusion::logical_expr_common::signature::Coercion;
+# use std::any::Any;
+# use std::sync::Arc;
#[derive(Debug)]
pub struct AsyncUpper {
@@ -419,6 +427,7 @@ impl ScalarUDFImpl for AsyncUpper {
Ok(DataType::Utf8)
}
+ // Note: the normal `invoke_with_args` method is not called for async UDFs
fn invoke_with_args(
&self,
_args: ScalarFunctionArgs,
@@ -434,13 +443,17 @@ impl AsyncScalarUDFImpl for AsyncUpper {
Some(10)
}
+ /// This method is called to execute the async UDF and is similar
+ /// to the normal `invoke_with_args` except it returns an `ArrayRef`
+ /// instead of `ColumnarValue` and is `async`.
async fn invoke_async_with_args(
&self,
args: ScalarFunctionArgs,
_option: &ConfigOptions,
) -> Result<ArrayRef> {
- trace!("Invoking async_upper with args: {:?}", args);
let value = &args.args[0];
+ // This function implements a simple string-to-uppercase conversion,
+ // but the same approach can be used for any async operation such as network calls.
let result = match value {
ColumnarValue::Array(array) => {
let string_array = array.as_string::<i32>();
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]