This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 36f6e0facc Use PhysicalExtensionCodec consistently (#10075)
36f6e0facc is described below
commit 36f6e0facc8cdbded3e595cbcb441190d055e56d
Author: Georgi Krastev <[email protected]>
AuthorDate: Mon Apr 15 13:27:03 2024 +0300
Use PhysicalExtensionCodec consistently (#10075)
* Use PhysicalExtensionCodec consistently
* Use PhysicalExtensionCodec consistently also when serializing
* Add a test for window aggregation with UDF codec
* Commit binary incompatible changes
---
datafusion/proto/src/physical_plan/from_proto.rs | 71 +--
datafusion/proto/src/physical_plan/mod.rs | 93 ++--
datafusion/proto/src/physical_plan/to_proto.rs | 504 +++++++++------------
.../proto/tests/cases/roundtrip_physical_plan.rs | 128 +++---
4 files changed, 402 insertions(+), 394 deletions(-)
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index aaca4dc482..81e4c92ffc 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -21,13 +21,11 @@ use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
use std::sync::Arc;
-use crate::common::proto_error;
-use crate::convert_required;
-use crate::logical_plan::{self, csv_writer_options_from_proto};
-use crate::protobuf::physical_expr_node::ExprType;
-use crate::protobuf::{self, copy_to_node};
-
use arrow::compute::SortOptions;
+use chrono::{TimeZone, Utc};
+use object_store::path::Path;
+use object_store::ObjectMeta;
+
use datafusion::arrow::datatypes::Schema;
use datafusion::datasource::file_format::csv::CsvSink;
use datafusion::datasource::file_format::json::JsonSink;
@@ -57,13 +55,15 @@ use
datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result,
ScalarValue};
-
-use chrono::{TimeZone, Utc};
use datafusion_expr::ScalarFunctionDefinition;
-use object_store::path::Path;
-use object_store::ObjectMeta;
-use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
+use crate::common::proto_error;
+use crate::convert_required;
+use crate::logical_plan::{self, csv_writer_options_from_proto};
+use crate::protobuf::physical_expr_node::ExprType;
+use crate::protobuf::{self, copy_to_node};
+
+use super::PhysicalExtensionCodec;
impl From<&protobuf::PhysicalColumn> for Column {
fn from(c: &protobuf::PhysicalColumn) -> Column {
@@ -76,9 +76,10 @@ impl From<&protobuf::PhysicalColumn> for Column {
/// # Arguments
///
/// * `proto` - Input proto with physical sort expression node
-/// * `registry` - A registry knows how to build logical expressions out of
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of
user-defined function names
/// * `input_schema` - The Arrow schema for the input, used for determining
expression data types
/// when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
pub fn parse_physical_sort_expr(
proto: &protobuf::PhysicalSortExprNode,
registry: &dyn FunctionRegistry,
@@ -102,9 +103,10 @@ pub fn parse_physical_sort_expr(
/// # Arguments
///
/// * `proto` - Input proto with vector of physical sort expression node
-/// * `registry` - A registry knows how to build logical expressions out of
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of
user-defined function names
/// * `input_schema` - The Arrow schema for the input, used for determining
expression data types
/// when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
pub fn parse_physical_sort_exprs(
proto: &[protobuf::PhysicalSortExprNode],
registry: &dyn FunctionRegistry,
@@ -123,25 +125,26 @@ pub fn parse_physical_sort_exprs(
///
/// # Arguments
///
-/// * `proto` - Input proto with physical window exprression node.
+/// * `proto` - Input proto with physical window expression node.
/// * `name` - Name of the window expression.
-/// * `registry` - A registry knows how to build logical expressions out of
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of
user-defined function names
/// * `input_schema` - The Arrow schema for the input, used for determining
expression data types
/// when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
pub fn parse_physical_window_expr(
proto: &protobuf::PhysicalWindowExprNode,
registry: &dyn FunctionRegistry,
input_schema: &Schema,
+ codec: &dyn PhysicalExtensionCodec,
) -> Result<Arc<dyn WindowExpr>> {
- let codec = DefaultPhysicalExtensionCodec {};
let window_node_expr =
- parse_physical_exprs(&proto.args, registry, input_schema, &codec)?;
+ parse_physical_exprs(&proto.args, registry, input_schema, codec)?;
let partition_by =
- parse_physical_exprs(&proto.partition_by, registry, input_schema,
&codec)?;
+ parse_physical_exprs(&proto.partition_by, registry, input_schema,
codec)?;
let order_by =
- parse_physical_sort_exprs(&proto.order_by, registry, input_schema,
&codec)?;
+ parse_physical_sort_exprs(&proto.order_by, registry, input_schema,
codec)?;
let window_frame = proto
.window_frame
@@ -187,9 +190,10 @@ where
/// # Arguments
///
/// * `proto` - Input proto with physical expression node
-/// * `registry` - A registry knows how to build logical expressions out of
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of
user-defined function names
/// * `input_schema` - The Arrow schema for the input, used for determining
expression data types
/// when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
pub fn parse_physical_expr(
proto: &protobuf::PhysicalExprNode,
registry: &dyn FunctionRegistry,
@@ -213,6 +217,7 @@ pub fn parse_physical_expr(
registry,
"left",
input_schema,
+ codec,
)?,
logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?,
parse_required_physical_expr(
@@ -220,6 +225,7 @@ pub fn parse_physical_expr(
registry,
"right",
input_schema,
+ codec,
)?,
)),
ExprType::AggregateExpr(_) => {
@@ -241,6 +247,7 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?))
}
ExprType::IsNotNullExpr(e) => {
@@ -249,6 +256,7 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?))
}
ExprType::NotExpr(e) =>
Arc::new(NotExpr::new(parse_required_physical_expr(
@@ -256,6 +264,7 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?)),
ExprType::Negative(e) => {
Arc::new(NegativeExpr::new(parse_required_physical_expr(
@@ -263,6 +272,7 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?))
}
ExprType::InList(e) => in_list(
@@ -271,6 +281,7 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?,
parse_physical_exprs(&e.list, registry, input_schema, codec)?,
&e.negated,
@@ -290,12 +301,14 @@ pub fn parse_physical_expr(
registry,
"when_expr",
input_schema,
+ codec,
)?,
parse_required_physical_expr(
e.then_expr.as_ref(),
registry,
"then_expr",
input_schema,
+ codec,
)?,
))
})
@@ -311,6 +324,7 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?,
convert_required!(e.arrow_type)?,
None,
@@ -321,6 +335,7 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?,
convert_required!(e.arrow_type)?,
)),
@@ -371,12 +386,14 @@ pub fn parse_physical_expr(
registry,
"expr",
input_schema,
+ codec,
)?,
parse_required_physical_expr(
like_expr.pattern.as_deref(),
registry,
"pattern",
input_schema,
+ codec,
)?,
)),
};
@@ -389,9 +406,9 @@ fn parse_required_physical_expr(
registry: &dyn FunctionRegistry,
field: &str,
input_schema: &Schema,
+ codec: &dyn PhysicalExtensionCodec,
) -> Result<Arc<dyn PhysicalExpr>> {
- let codec = DefaultPhysicalExtensionCodec {};
- expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec))
+ expr.map(|e| parse_physical_expr(e, registry, input_schema, codec))
.transpose()?
.ok_or_else(|| {
DataFusionError::Internal(format!("Missing required field
{field:?}"))
@@ -433,15 +450,15 @@ pub fn parse_protobuf_hash_partitioning(
partitioning: Option<&protobuf::PhysicalHashRepartition>,
registry: &dyn FunctionRegistry,
input_schema: &Schema,
+ codec: &dyn PhysicalExtensionCodec,
) -> Result<Option<Partitioning>> {
match partitioning {
Some(hash_part) => {
- let codec = DefaultPhysicalExtensionCodec {};
let expr = parse_physical_exprs(
&hash_part.hash_expr,
registry,
input_schema,
- &codec,
+ codec,
)?;
Ok(Some(Partitioning::Hash(
@@ -456,6 +473,7 @@ pub fn parse_protobuf_hash_partitioning(
pub fn parse_protobuf_file_scan_config(
proto: &protobuf::FileScanExecConf,
registry: &dyn FunctionRegistry,
+ codec: &dyn PhysicalExtensionCodec,
) -> Result<FileScanConfig> {
let schema: Arc<Schema> = Arc::new(convert_required!(proto.schema)?);
let projection = proto
@@ -489,7 +507,7 @@ pub fn parse_protobuf_file_scan_config(
.collect::<Result<Vec<_>>>()?;
// Remove partition columns from the schema after recreating
table_partition_cols
- // because the partition columns are not in the file. They are present to
allow the
+ // because the partition columns are not in the file. They are present to
allow
// the partition column types to be reconstructed after serde.
let file_schema = Arc::new(Schema::new(
schema
@@ -502,12 +520,11 @@ pub fn parse_protobuf_file_scan_config(
let mut output_ordering = vec![];
for node_collection in &proto.output_ordering {
- let codec = DefaultPhysicalExtensionCodec {};
let sort_expr = parse_physical_sort_exprs(
&node_collection.physical_sort_expr_nodes,
registry,
&schema,
- &codec,
+ codec,
)?;
output_ordering.push(sort_expr);
}
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 4d95c847bf..a481e7090f 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -19,22 +19,8 @@ use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
-use self::from_proto::parse_physical_window_expr;
-use self::to_proto::serialize_physical_expr;
-
-use crate::common::{byte_to_string, proto_error, str_to_byte};
-use crate::convert_required;
-use crate::physical_plan::from_proto::{
- parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
- parse_protobuf_file_scan_config,
-};
-use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
-use crate::protobuf::physical_expr_node::ExprType;
-use crate::protobuf::physical_plan_node::PhysicalPlanType;
-use crate::protobuf::repartition_exec_node::PartitionMethod;
-use crate::protobuf::{
- self, window_agg_exec_node, PhysicalPlanNode,
PhysicalSortExprNodeCollection,
-};
+use prost::bytes::BufMut;
+use prost::Message;
use datafusion::arrow::compute::SortOptions;
use datafusion::arrow::datatypes::SchemaRef;
@@ -79,13 +65,28 @@ use datafusion::physical_plan::{
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
use datafusion_expr::ScalarUDF;
-use prost::bytes::BufMut;
-use prost::Message;
+use crate::common::{byte_to_string, proto_error, str_to_byte};
+use crate::convert_required;
+use crate::physical_plan::from_proto::{
+ parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
+ parse_physical_window_expr, parse_protobuf_file_scan_config,
+};
+use crate::physical_plan::to_proto::{
+ serialize_file_scan_config, serialize_maybe_filter,
serialize_physical_aggr_expr,
+ serialize_physical_window_expr,
+};
+use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
+use crate::protobuf::physical_expr_node::ExprType;
+use crate::protobuf::physical_plan_node::PhysicalPlanType;
+use crate::protobuf::repartition_exec_node::PartitionMethod;
+use crate::protobuf::{self, window_agg_exec_node};
+
+use self::to_proto::serialize_physical_expr;
pub mod from_proto;
pub mod to_proto;
-impl AsExecutionPlan for PhysicalPlanNode {
+impl AsExecutionPlan for protobuf::PhysicalPlanNode {
fn try_decode(buf: &[u8]) -> Result<Self>
where
Self: Sized,
@@ -191,6 +192,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
parse_protobuf_file_scan_config(
scan.base_conf.as_ref().unwrap(),
registry,
+ extension_codec,
)?,
scan.has_header,
str_to_byte(&scan.delimiter, "delimiter")?,
@@ -210,6 +212,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
let base_config = parse_protobuf_file_scan_config(
scan.base_conf.as_ref().unwrap(),
registry,
+ extension_codec,
)?;
let predicate = scan
.predicate
@@ -234,6 +237,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config(
scan.base_conf.as_ref().unwrap(),
registry,
+ extension_codec,
)?)))
}
PhysicalPlanType::CoalesceBatches(coalesce_batches) => {
@@ -338,6 +342,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
window_expr,
registry,
input_schema.as_ref(),
+ extension_codec,
)
})
.collect::<Result<Vec<_>, _>>()?;
@@ -1449,14 +1454,17 @@ impl AsExecutionPlan for PhysicalPlanNode {
let filter = exec
.filter_expr()
.iter()
- .map(|expr| expr.to_owned().try_into())
+ .map(|expr| serialize_maybe_filter(expr.to_owned(),
extension_codec))
.collect::<Result<Vec<_>>>()?;
let agg = exec
.aggr_expr()
.iter()
- .map(|expr| expr.to_owned().try_into())
+ .map(|expr| {
+ serialize_physical_aggr_expr(expr.to_owned(),
extension_codec)
+ })
.collect::<Result<Vec<_>>>()?;
+
let agg_names = exec
.aggr_expr()
.iter()
@@ -1556,7 +1564,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
return Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::CsvScan(
protobuf::CsvScanExecNode {
- base_conf: Some(exec.base_config().try_into()?),
+ base_conf: Some(serialize_file_scan_config(
+ exec.base_config(),
+ extension_codec,
+ )?),
has_header: exec.has_header(),
delimiter: byte_to_string(exec.delimiter(),
"delimiter")?,
quote: byte_to_string(exec.quote(), "quote")?,
@@ -1581,7 +1592,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
return Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::ParquetScan(
protobuf::ParquetScanExecNode {
- base_conf: Some(exec.base_config().try_into()?),
+ base_conf: Some(serialize_file_scan_config(
+ exec.base_config(),
+ extension_codec,
+ )?),
predicate,
},
)),
@@ -1592,7 +1606,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
return Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::AvroScan(
protobuf::AvroScanExecNode {
- base_conf: Some(exec.base_config().try_into()?),
+ base_conf: Some(serialize_file_scan_config(
+ exec.base_config(),
+ extension_codec,
+ )?),
},
)),
});
@@ -1688,7 +1705,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
}
if let Some(union) = plan.downcast_ref::<UnionExec>() {
- let mut inputs: Vec<PhysicalPlanNode> = vec![];
+ let mut inputs: Vec<protobuf::PhysicalPlanNode> = vec![];
for input in union.inputs() {
inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan(
input.to_owned(),
@@ -1703,7 +1720,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
}
if let Some(interleave) = plan.downcast_ref::<InterleaveExec>() {
- let mut inputs: Vec<PhysicalPlanNode> = vec![];
+ let mut inputs: Vec<protobuf::PhysicalPlanNode> = vec![];
for input in interleave.inputs() {
inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan(
input.to_owned(),
@@ -1809,11 +1826,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
extension_codec,
)?;
- let window_expr =
- exec.window_expr()
- .iter()
- .map(|e| e.clone().try_into())
-
.collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
+ let window_expr = exec
+ .window_expr()
+ .iter()
+ .map(|e| serialize_physical_window_expr(e.clone(),
extension_codec))
+ .collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
let partition_keys = exec
.partition_keys
@@ -1839,11 +1856,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
extension_codec,
)?;
- let window_expr =
- exec.window_expr()
- .iter()
- .map(|e| e.clone().try_into())
-
.collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
+ let window_expr = exec
+ .window_expr()
+ .iter()
+ .map(|e| serialize_physical_window_expr(e.clone(),
extension_codec))
+ .collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
let partition_keys = exec
.partition_keys
@@ -1901,7 +1918,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
Ok(sort_expr)
})
.collect::<Result<Vec<_>>>()?;
- Some(PhysicalSortExprNodeCollection {
+ Some(protobuf::PhysicalSortExprNodeCollection {
physical_sort_expr_nodes: expr,
})
}
@@ -2044,7 +2061,7 @@ impl PhysicalExtensionCodec for
DefaultPhysicalExtensionCodec {
}
fn into_physical_plan(
- node: &Option<Box<PhysicalPlanNode>>,
+ node: &Option<Box<protobuf::PhysicalPlanNode>>,
registry: &dyn FunctionRegistry,
runtime: &RuntimeEnv,
extension_codec: &dyn PhysicalExtensionCodec,
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index e1574f48fb..b4c23e4d0c 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -22,18 +22,8 @@ use std::{
sync::Arc,
};
-use crate::protobuf::{
- self, copy_to_node, physical_aggregate_expr_node,
physical_window_expr_node,
- scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode,
- PhysicalSortExprNodeCollection, ScalarValue,
-};
-
#[cfg(feature = "parquet")]
use datafusion::datasource::file_format::parquet::ParquetSink;
-
-use datafusion_expr::ScalarFunctionDefinition;
-
-use crate::logical_plan::csv_writer_options_to_proto;
use datafusion::logical_expr::BuiltinScalarFunction;
use datafusion::physical_expr::window::{NthValueKind,
SlidingAggregateWindowExpr};
use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
@@ -71,195 +61,187 @@ use datafusion_common::{
stats::Precision,
DataFusionError, JoinSide, Result,
};
+use datafusion_expr::ScalarFunctionDefinition;
-use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
-
-impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
- type Error = DataFusionError;
-
- fn try_from(a: Arc<dyn AggregateExpr>) -> Result<Self, Self::Error> {
- let codec = DefaultPhysicalExtensionCodec {};
- let expressions = serialize_physical_exprs(a.expressions(), &codec)?;
-
- let ordering_req = a.order_bys().unwrap_or(&[]).to_vec();
- let ordering_req = serialize_physical_sort_exprs(ordering_req,
&codec)?;
-
- if let Some(a) = a.as_any().downcast_ref::<AggregateFunctionExpr>() {
- let name = a.fun().name().to_string();
- return Ok(protobuf::PhysicalExprNode {
- expr_type:
Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
- protobuf::PhysicalAggregateExprNode {
- aggregate_function:
Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
- expr: expressions,
- ordering_req,
- distinct: false,
- },
- )),
- });
- }
+use crate::logical_plan::csv_writer_options_to_proto;
+use crate::protobuf::{
+ self, copy_to_node, physical_aggregate_expr_node,
physical_window_expr_node,
+ scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode,
+ PhysicalSortExprNodeCollection, ScalarValue,
+};
- let AggrFn {
- inner: aggr_function,
- distinct,
- } = aggr_expr_to_aggr_fn(a.as_ref())?;
+use super::PhysicalExtensionCodec;
- Ok(protobuf::PhysicalExprNode {
+pub fn serialize_physical_aggr_expr(
+ aggr_expr: Arc<dyn AggregateExpr>,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::PhysicalExprNode> {
+ let expressions = serialize_physical_exprs(aggr_expr.expressions(),
codec)?;
+ let ordering_req = aggr_expr.order_bys().unwrap_or(&[]).to_vec();
+ let ordering_req = serialize_physical_sort_exprs(ordering_req, codec)?;
+
+ if let Some(a) =
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+ let name = a.fun().name().to_string();
+ return Ok(protobuf::PhysicalExprNode {
expr_type:
Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
protobuf::PhysicalAggregateExprNode {
- aggregate_function: Some(
-
physical_aggregate_expr_node::AggregateFunction::AggrFunction(
- aggr_function as i32,
- ),
- ),
+ aggregate_function:
Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
expr: expressions,
ordering_req,
- distinct,
+ distinct: false,
},
)),
- })
+ });
}
-}
-
-impl TryFrom<Arc<dyn WindowExpr>> for protobuf::PhysicalWindowExprNode {
- type Error = DataFusionError;
- fn try_from(
- window_expr: Arc<dyn WindowExpr>,
- ) -> std::result::Result<Self, Self::Error> {
- let expr = window_expr.as_any();
+ let AggrFn {
+ inner: aggr_function,
+ distinct,
+ } = aggr_expr_to_aggr_fn(aggr_expr.as_ref())?;
+
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
+ protobuf::PhysicalAggregateExprNode {
+ aggregate_function: Some(
+
physical_aggregate_expr_node::AggregateFunction::AggrFunction(
+ aggr_function as i32,
+ ),
+ ),
+ expr: expressions,
+ ordering_req,
+ distinct,
+ },
+ )),
+ })
+}
- let mut args = window_expr.expressions().to_vec();
- let window_frame = window_expr.get_window_frame();
+pub fn serialize_physical_window_expr(
+ window_expr: Arc<dyn WindowExpr>,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::PhysicalWindowExprNode> {
+ let expr = window_expr.as_any();
+ let mut args = window_expr.expressions().to_vec();
+ let window_frame = window_expr.get_window_frame();
- let window_function = if let Some(built_in_window_expr) =
- expr.downcast_ref::<BuiltInWindowExpr>()
+ let window_function = if let Some(built_in_window_expr) =
+ expr.downcast_ref::<BuiltInWindowExpr>()
+ {
+ let expr = built_in_window_expr.get_built_in_func_expr();
+ let built_in_fn_expr = expr.as_any();
+
+ let builtin_fn = if
built_in_fn_expr.downcast_ref::<RowNumber>().is_some() {
+ protobuf::BuiltInWindowFunction::RowNumber
+ } else if let Some(rank_expr) =
built_in_fn_expr.downcast_ref::<Rank>() {
+ match rank_expr.get_type() {
+ RankType::Basic => protobuf::BuiltInWindowFunction::Rank,
+ RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank,
+ RankType::Percent =>
protobuf::BuiltInWindowFunction::PercentRank,
+ }
+ } else if built_in_fn_expr.downcast_ref::<CumeDist>().is_some() {
+ protobuf::BuiltInWindowFunction::CumeDist
+ } else if let Some(ntile_expr) =
built_in_fn_expr.downcast_ref::<Ntile>() {
+ args.insert(
+ 0,
+
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
+ ntile_expr.get_n() as i64,
+ )))),
+ );
+ protobuf::BuiltInWindowFunction::Ntile
+ } else if let Some(window_shift_expr) =
+ built_in_fn_expr.downcast_ref::<WindowShift>()
{
- let expr = built_in_window_expr.get_built_in_func_expr();
- let built_in_fn_expr = expr.as_any();
-
- let builtin_fn = if
built_in_fn_expr.downcast_ref::<RowNumber>().is_some() {
- protobuf::BuiltInWindowFunction::RowNumber
- } else if let Some(rank_expr) =
built_in_fn_expr.downcast_ref::<Rank>() {
- match rank_expr.get_type() {
- RankType::Basic => protobuf::BuiltInWindowFunction::Rank,
- RankType::Dense =>
protobuf::BuiltInWindowFunction::DenseRank,
- RankType::Percent =>
protobuf::BuiltInWindowFunction::PercentRank,
- }
- } else if built_in_fn_expr.downcast_ref::<CumeDist>().is_some() {
- protobuf::BuiltInWindowFunction::CumeDist
- } else if let Some(ntile_expr) =
built_in_fn_expr.downcast_ref::<Ntile>() {
- args.insert(
- 0,
-
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
- ntile_expr.get_n() as i64,
- )))),
- );
- protobuf::BuiltInWindowFunction::Ntile
- } else if let Some(window_shift_expr) =
- built_in_fn_expr.downcast_ref::<WindowShift>()
- {
- args.insert(
- 1,
-
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
- window_shift_expr.get_shift_offset(),
- )))),
- );
- args.insert(
- 2,
-
Arc::new(Literal::new(window_shift_expr.get_default_value())),
- );
-
- if window_shift_expr.get_shift_offset() >= 0 {
- protobuf::BuiltInWindowFunction::Lag
- } else {
- protobuf::BuiltInWindowFunction::Lead
- }
- } else if let Some(nth_value_expr) =
- built_in_fn_expr.downcast_ref::<NthValue>()
- {
- match nth_value_expr.get_kind() {
- NthValueKind::First =>
protobuf::BuiltInWindowFunction::FirstValue,
- NthValueKind::Last =>
protobuf::BuiltInWindowFunction::LastValue,
- NthValueKind::Nth(n) => {
- args.insert(
- 1,
- Arc::new(Literal::new(
- datafusion_common::ScalarValue::Int64(Some(n)),
- )),
- );
- protobuf::BuiltInWindowFunction::NthValue
- }
- }
+ args.insert(
+ 1,
+
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
+ window_shift_expr.get_shift_offset(),
+ )))),
+ );
+ args.insert(
+ 2,
+ Arc::new(Literal::new(window_shift_expr.get_default_value())),
+ );
+
+ if window_shift_expr.get_shift_offset() >= 0 {
+ protobuf::BuiltInWindowFunction::Lag
} else {
- return not_impl_err!("BuiltIn function not supported:
{expr:?}");
- };
-
-
physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32)
- } else if let Some(plain_aggr_window_expr) =
- expr.downcast_ref::<PlainAggregateWindowExpr>()
- {
- let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
- plain_aggr_window_expr.get_aggregate_expr().as_ref(),
- )?;
-
- if distinct {
- // TODO
- return not_impl_err!(
- "Distinct aggregate functions not supported in window
expressions"
- );
+ protobuf::BuiltInWindowFunction::Lead
}
-
- if !window_frame.start_bound.is_unbounded() {
- return Err(DataFusionError::Internal(format!("Invalid
PlainAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
+ } else if let Some(nth_value_expr) =
built_in_fn_expr.downcast_ref::<NthValue>() {
+ match nth_value_expr.get_kind() {
+ NthValueKind::First =>
protobuf::BuiltInWindowFunction::FirstValue,
+ NthValueKind::Last =>
protobuf::BuiltInWindowFunction::LastValue,
+ NthValueKind::Nth(n) => {
+ args.insert(
+ 1,
+
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(
+ Some(n),
+ ))),
+ );
+ protobuf::BuiltInWindowFunction::NthValue
+ }
}
+ } else {
+ return not_impl_err!("BuiltIn function not supported: {expr:?}");
+ };
- physical_window_expr_node::WindowFunction::AggrFunction(inner as
i32)
- } else if let Some(sliding_aggr_window_expr) =
- expr.downcast_ref::<SlidingAggregateWindowExpr>()
- {
- let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
- sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
- )?;
-
- if distinct {
- // TODO
- return not_impl_err!(
- "Distinct aggregate functions not supported in window
expressions"
- );
- }
+ physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn
as i32)
+ } else if let Some(plain_aggr_window_expr) =
+ expr.downcast_ref::<PlainAggregateWindowExpr>()
+ {
+ let AggrFn { inner, distinct } =
+
aggr_expr_to_aggr_fn(plain_aggr_window_expr.get_aggregate_expr().as_ref())?;
+
+ if distinct {
+ // TODO
+ return not_impl_err!(
+ "Distinct aggregate functions not supported in window
expressions"
+ );
+ }
- if window_frame.start_bound.is_unbounded() {
- return Err(DataFusionError::Internal(format!("Invalid
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
- }
+ if !window_frame.start_bound.is_unbounded() {
+ return Err(DataFusionError::Internal(format!("Invalid
PlainAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
+ }
- physical_window_expr_node::WindowFunction::AggrFunction(inner as
i32)
- } else {
- return not_impl_err!("WindowExpr not supported: {window_expr:?}");
- };
- let codec = DefaultPhysicalExtensionCodec {};
- let args = serialize_physical_exprs(args, &codec)?;
- let partition_by =
- serialize_physical_exprs(window_expr.partition_by().to_vec(),
&codec)?;
+ physical_window_expr_node::WindowFunction::AggrFunction(inner as i32)
+ } else if let Some(sliding_aggr_window_expr) =
+ expr.downcast_ref::<SlidingAggregateWindowExpr>()
+ {
+ let AggrFn { inner, distinct } =
+
aggr_expr_to_aggr_fn(sliding_aggr_window_expr.get_aggregate_expr().as_ref())?;
+
+ if distinct {
+ // TODO
+ return not_impl_err!(
+ "Distinct aggregate functions not supported in window
expressions"
+ );
+ }
- let order_by =
- serialize_physical_sort_exprs(window_expr.order_by().to_vec(),
&codec)?;
+ if window_frame.start_bound.is_unbounded() {
+ return Err(DataFusionError::Internal(format!("Invalid
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
+ }
- let window_frame: protobuf::WindowFrame = window_frame
- .as_ref()
- .try_into()
- .map_err(|e| DataFusionError::Internal(format!("{e}")))?;
-
- let name = window_expr.name().to_string();
-
- Ok(protobuf::PhysicalWindowExprNode {
- args,
- partition_by,
- order_by,
- window_frame: Some(window_frame),
- window_function: Some(window_function),
- name,
- })
- }
+ physical_window_expr_node::WindowFunction::AggrFunction(inner as i32)
+ } else {
+ return not_impl_err!("WindowExpr not supported: {window_expr:?}");
+ };
+
+ let args = serialize_physical_exprs(args, codec)?;
+ let partition_by =
+ serialize_physical_exprs(window_expr.partition_by().to_vec(), codec)?;
+ let order_by =
serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?;
+ let window_frame: protobuf::WindowFrame = window_frame
+ .as_ref()
+ .try_into()
+ .map_err(|e| DataFusionError::Internal(format!("{e}")))?;
+
+ Ok(protobuf::PhysicalWindowExprNode {
+ args,
+ partition_by,
+ order_by,
+ window_frame: Some(window_frame),
+ window_function: Some(window_function),
+ name: window_expr.name().to_string(),
+ })
}
struct AggrFn {
@@ -366,7 +348,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) ->
Result<AggrFn> {
pub fn serialize_physical_sort_exprs<I>(
sort_exprs: I,
codec: &dyn PhysicalExtensionCodec,
-) -> Result<Vec<protobuf::PhysicalSortExprNode>, DataFusionError>
+) -> Result<Vec<PhysicalSortExprNode>>
where
I: IntoIterator<Item = PhysicalSortExpr>,
{
@@ -379,7 +361,7 @@ where
pub fn serialize_physical_sort_expr(
sort_expr: PhysicalSortExpr,
codec: &dyn PhysicalExtensionCodec,
-) -> Result<protobuf::PhysicalSortExprNode, DataFusionError> {
+) -> Result<PhysicalSortExprNode> {
let PhysicalSortExpr { expr, options } = sort_expr;
let expr = serialize_physical_expr(expr, codec)?;
Ok(PhysicalSortExprNode {
@@ -392,7 +374,7 @@ pub fn serialize_physical_sort_expr(
pub fn serialize_physical_exprs<I>(
values: I,
codec: &dyn PhysicalExtensionCodec,
-) -> Result<Vec<protobuf::PhysicalExprNode>, DataFusionError>
+) -> Result<Vec<protobuf::PhysicalExprNode>>
where
I: IntoIterator<Item = Arc<dyn PhysicalExpr>>,
{
@@ -409,7 +391,7 @@ where
pub fn serialize_physical_expr(
value: Arc<dyn PhysicalExpr>,
codec: &dyn PhysicalExtensionCodec,
-) -> Result<protobuf::PhysicalExprNode, DataFusionError> {
+) -> Result<protobuf::PhysicalExprNode> {
let expr = value.as_any();
if let Some(expr) = expr.downcast_ref::<Column>() {
@@ -456,7 +438,7 @@ pub fn serialize_physical_expr(
.when_then_expr()
.iter()
.map(|(when_expr, then_expr)| {
- try_parse_when_then_expr(when_expr,
then_expr, codec)
+ serialize_when_then_expr(when_expr,
then_expr, codec)
})
.collect::<Result<
Vec<protobuf::PhysicalWhenThen>,
@@ -623,7 +605,7 @@ pub fn serialize_physical_expr(
}
}
-fn try_parse_when_then_expr(
+fn serialize_when_then_expr(
when_expr: &Arc<dyn PhysicalExpr>,
then_expr: &Arc<dyn PhysicalExpr>,
codec: &dyn PhysicalExtensionCodec,
@@ -637,7 +619,7 @@ fn try_parse_when_then_expr(
impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile {
type Error = DataFusionError;
- fn try_from(pf: &PartitionedFile) -> Result<Self, Self::Error> {
+ fn try_from(pf: &PartitionedFile) -> Result<Self> {
let last_modified = pf.object_meta.last_modified;
let last_modified_ns =
last_modified.timestamp_nanos_opt().ok_or_else(|| {
DataFusionError::Plan(format!(
@@ -661,7 +643,7 @@ impl TryFrom<&PartitionedFile> for
protobuf::PartitionedFile {
impl TryFrom<&FileRange> for protobuf::FileRange {
type Error = DataFusionError;
- fn try_from(value: &FileRange) -> Result<Self, Self::Error> {
+ fn try_from(value: &FileRange) -> Result<Self> {
Ok(protobuf::FileRange {
start: value.start,
end: value.end,
@@ -746,61 +728,58 @@ impl From<&ColumnStatistics> for protobuf::ColumnStats {
}
}
-impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf {
- type Error = DataFusionError;
- fn try_from(
- conf: &FileScanConfig,
- ) -> Result<protobuf::FileScanExecConf, Self::Error> {
- let codec = DefaultPhysicalExtensionCodec {};
- let file_groups = conf
- .file_groups
- .iter()
- .map(|p| p.as_slice().try_into())
- .collect::<Result<Vec<_>, _>>()?;
-
- let mut output_orderings = vec![];
- for order in &conf.output_ordering {
- let ordering = serialize_physical_sort_exprs(order.to_vec(),
&codec)?;
- output_orderings.push(ordering)
- }
+pub fn serialize_file_scan_config(
+ conf: &FileScanConfig,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::FileScanExecConf> {
+ let file_groups = conf
+ .file_groups
+ .iter()
+ .map(|p| p.as_slice().try_into())
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let mut output_orderings = vec![];
+ for order in &conf.output_ordering {
+ let ordering = serialize_physical_sort_exprs(order.to_vec(), codec)?;
+ output_orderings.push(ordering)
+ }
- // Fields must be added to the schema so that they can persist in the
protobuf
- // and then they are to be removed from the schema in
`parse_protobuf_file_scan_config`
- let mut fields = conf
- .file_schema
- .fields()
+ // Fields must be added to the schema so that they can persist in the
protobuf,
+ // and then they are to be removed from the schema in
`parse_protobuf_file_scan_config`
+ let mut fields = conf
+ .file_schema
+ .fields()
+ .iter()
+ .cloned()
+ .collect::<Vec<_>>();
+ fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new));
+ let schema = Arc::new(arrow::datatypes::Schema::new(fields.clone()));
+
+ Ok(protobuf::FileScanExecConf {
+ file_groups,
+ statistics: Some((&conf.statistics).into()),
+ limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }),
+ projection: conf
+ .projection
+ .as_ref()
+ .unwrap_or(&vec![])
.iter()
- .cloned()
- .collect::<Vec<_>>();
- fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new));
- let schema =
Arc::new(datafusion::arrow::datatypes::Schema::new(fields.clone()));
-
- Ok(protobuf::FileScanExecConf {
- file_groups,
- statistics: Some((&conf.statistics).into()),
- limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }),
- projection: conf
- .projection
- .as_ref()
- .unwrap_or(&vec![])
- .iter()
- .map(|n| *n as u32)
- .collect(),
- schema: Some(schema.as_ref().try_into()?),
- table_partition_cols: conf
- .table_partition_cols
- .iter()
- .map(|x| x.name().clone())
- .collect::<Vec<_>>(),
- object_store_url: conf.object_store_url.to_string(),
- output_ordering: output_orderings
- .into_iter()
- .map(|e| PhysicalSortExprNodeCollection {
- physical_sort_expr_nodes: e,
- })
- .collect::<Vec<_>>(),
- })
- }
+ .map(|n| *n as u32)
+ .collect(),
+ schema: Some(schema.as_ref().try_into()?),
+ table_partition_cols: conf
+ .table_partition_cols
+ .iter()
+ .map(|x| x.name().clone())
+ .collect::<Vec<_>>(),
+ object_store_url: conf.object_store_url.to_string(),
+ output_ordering: output_orderings
+ .into_iter()
+ .map(|e| PhysicalSortExprNodeCollection {
+ physical_sort_expr_nodes: e,
+ })
+ .collect::<Vec<_>>(),
+ })
}
impl From<JoinSide> for protobuf::JoinSide {
@@ -812,46 +791,15 @@ impl From<JoinSide> for protobuf::JoinSide {
}
}
-impl TryFrom<Option<Arc<dyn PhysicalExpr>>> for protobuf::MaybeFilter {
- type Error = DataFusionError;
-
- fn try_from(expr: Option<Arc<dyn PhysicalExpr>>) -> Result<Self,
Self::Error> {
- let codec = DefaultPhysicalExtensionCodec {};
- match expr {
- None => Ok(protobuf::MaybeFilter { expr: None }),
- Some(expr) => Ok(protobuf::MaybeFilter {
- expr: Some(serialize_physical_expr(expr, &codec)?),
- }),
- }
- }
-}
-
-impl TryFrom<Option<Vec<PhysicalSortExpr>>> for
protobuf::MaybePhysicalSortExprs {
- type Error = DataFusionError;
-
- fn try_from(sort_exprs: Option<Vec<PhysicalSortExpr>>) -> Result<Self,
Self::Error> {
- match sort_exprs {
- None => Ok(protobuf::MaybePhysicalSortExprs { sort_expr: vec![] }),
- Some(sort_exprs) => Ok(protobuf::MaybePhysicalSortExprs {
- sort_expr: sort_exprs
- .into_iter()
- .map(|sort_expr| sort_expr.try_into())
- .collect::<Result<Vec<_>>>()?,
- }),
- }
- }
-}
-
-impl TryFrom<PhysicalSortExpr> for protobuf::PhysicalSortExprNode {
- type Error = DataFusionError;
-
- fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result<Self,
Self::Error> {
- let codec = DefaultPhysicalExtensionCodec {};
- Ok(PhysicalSortExprNode {
- expr: Some(Box::new(serialize_physical_expr(sort_expr.expr,
&codec)?)),
- asc: !sort_expr.options.descending,
- nulls_first: sort_expr.options.nulls_first,
- })
+pub fn serialize_maybe_filter(
+ expr: Option<Arc<dyn PhysicalExpr>>,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::MaybeFilter> {
+ match expr {
+ None => Ok(protobuf::MaybeFilter { expr: None }),
+ Some(expr) => Ok(protobuf::MaybeFilter {
+ expr: Some(serialize_physical_expr(expr, codec)?),
+ }),
}
}
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index f97cfea765..642860d639 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -21,6 +21,8 @@ use std::sync::Arc;
use std::vec;
use arrow::csv::WriterBuilder;
+use prost::Message;
+
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::compute::kernels::sort::SortOptions;
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
@@ -35,7 +37,7 @@ use datafusion::datasource::physical_plan::{
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
-use datafusion::physical_expr::expressions::NthValueAgg;
+use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg};
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
use datafusion::physical_plan::aggregates::{
@@ -77,19 +79,18 @@ use datafusion_expr::{
ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature,
SimpleAggregateUDF,
WindowFrame, WindowFrameBound,
};
-use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
-use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
use datafusion_proto::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
use datafusion_proto::protobuf;
-use prost::Message;
/// Perform a serde roundtrip and assert that the string representation of the
before and after plans
/// are identical. Note that this often isn't sufficient to guarantee that no
information is
/// lost during serde because the string representation of a plan often only
shows a subset of state.
fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
- let _ = roundtrip_test_and_return(exec_plan);
+ let ctx = SessionContext::new();
+ let codec = DefaultPhysicalExtensionCodec {};
+ roundtrip_test_and_return(exec_plan, &ctx, &codec)?;
Ok(())
}
@@ -101,15 +102,15 @@ fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) ->
Result<()> {
/// farther in tests.
fn roundtrip_test_and_return(
exec_plan: Arc<dyn ExecutionPlan>,
+ ctx: &SessionContext,
+ codec: &dyn PhysicalExtensionCodec,
) -> Result<Arc<dyn ExecutionPlan>> {
- let ctx = SessionContext::new();
- let codec = DefaultPhysicalExtensionCodec {};
let proto: protobuf::PhysicalPlanNode =
- protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(),
&codec)
+ protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(),
codec)
.expect("to proto");
let runtime = ctx.runtime_env();
let result_exec_plan: Arc<dyn ExecutionPlan> = proto
- .try_into_physical_plan(&ctx, runtime.deref(), &codec)
+ .try_into_physical_plan(ctx, runtime.deref(), codec)
.expect("from proto");
assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}"));
Ok(result_exec_plan)
@@ -123,17 +124,10 @@ fn roundtrip_test_and_return(
/// performing serde on some plans.
fn roundtrip_test_with_context(
exec_plan: Arc<dyn ExecutionPlan>,
- ctx: SessionContext,
+ ctx: &SessionContext,
) -> Result<()> {
let codec = DefaultPhysicalExtensionCodec {};
- let proto: protobuf::PhysicalPlanNode =
- protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(),
&codec)
- .expect("to proto");
- let runtime = ctx.runtime_env();
- let result_exec_plan: Arc<dyn ExecutionPlan> = proto
- .try_into_physical_plan(&ctx, runtime.deref(), &codec)
- .expect("from proto");
- assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}"));
+ roundtrip_test_and_return(exec_plan, ctx, &codec)?;
Ok(())
}
@@ -444,7 +438,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
Arc::new(EmptyExec::new(schema.clone())),
schema,
)?),
- ctx,
+ &ctx,
)
}
@@ -642,11 +636,11 @@ fn roundtrip_scalar_udf() -> Result<()> {
ctx.register_udf(udf);
- roundtrip_test_with_context(Arc::new(project), ctx)
+ roundtrip_test_with_context(Arc::new(project), &ctx)
}
#[test]
-fn roundtrip_scalar_udf_extension_codec() {
+fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
#[derive(Debug)]
struct MyRegexUdf {
signature: Signature,
@@ -657,11 +651,7 @@ fn roundtrip_scalar_udf_extension_codec() {
impl MyRegexUdf {
fn new(pattern: String) -> Self {
Self {
- signature: Signature::uniform(
- 1,
- vec![DataType::Int32],
- Volatility::Immutable,
- ),
+ signature: Signature::exact(vec![DataType::Utf8],
Volatility::Immutable),
pattern,
}
}
@@ -672,18 +662,22 @@ fn roundtrip_scalar_udf_extension_codec() {
fn as_any(&self) -> &dyn Any {
self
}
+
fn name(&self) -> &str {
"regex_udf"
}
+
fn signature(&self) -> &Signature {
&self.signature
}
+
fn return_type(&self, args: &[DataType]) -> Result<DataType> {
if !matches!(args.first(), Some(&DataType::Utf8)) {
return plan_err!("regex_udf only accepts Utf8 arguments");
}
- Ok(DataType::Int32)
+ Ok(DataType::Int64)
}
+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!()
}
@@ -747,32 +741,58 @@ fn roundtrip_scalar_udf_extension_codec() {
}
}
+ let field_text = Field::new("text", DataType::Utf8, true);
+ let field_published = Field::new("published", DataType::Boolean, false);
+ let field_author = Field::new("author", DataType::Utf8, false);
+ let schema = Arc::new(Schema::new(vec![field_text, field_published,
field_author]));
+ let input = Arc::new(EmptyExec::new(schema.clone()));
+
let pattern = ".*";
let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
- let test_expr = ScalarFunctionExpr::new(
+ let udf_expr = Arc::new(ScalarFunctionExpr::new(
udf.name(),
ScalarFunctionDefinition::UDF(Arc::new(udf.clone())),
- vec![],
- DataType::Int32,
+ vec![col("text", &schema)?],
+ DataType::Int64,
None,
false,
- );
- let fmt_expr = format!("{test_expr:?}");
- let ctx = SessionContext::new();
+ ));
- ctx.register_udf(udf.clone());
- let extension_codec = ScalarUDFExtensionCodec {};
- let proto: protobuf::PhysicalExprNode =
- match serialize_physical_expr(Arc::new(test_expr), &extension_codec) {
- Ok(proto) => proto,
- Err(e) => panic!("failed to serialize expr: {e:?}"),
- };
- let field_a = Field::new("a", DataType::Int32, false);
- let schema = Arc::new(Schema::new(vec![field_a]));
- let round_trip =
- parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap();
- assert_eq!(fmt_expr, format!("{round_trip:?}"));
+ let filter = Arc::new(FilterExec::try_new(
+ Arc::new(BinaryExpr::new(
+ col("published", &schema)?,
+ Operator::And,
+ Arc::new(BinaryExpr::new(udf_expr.clone(), Operator::Gt, lit(0))),
+ )),
+ input,
+ )?);
+
+ let window = Arc::new(WindowAggExec::try_new(
+ vec![Arc::new(PlainAggregateWindowExpr::new(
+ Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)),
+ &[col("author", &schema)?],
+ &[],
+ Arc::new(WindowFrame::new(None)),
+ ))],
+ filter,
+ vec![col("author", &schema)?],
+ )?);
+
+ let aggregate = Arc::new(AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::new(vec![], vec![], vec![]),
+ vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))],
+ vec![None],
+ window,
+ schema.clone(),
+ )?);
+
+ let ctx = SessionContext::new();
+ let codec = ScalarUDFExtensionCodec {};
+ roundtrip_test_and_return(aggregate, &ctx, &codec)?;
+ Ok(())
}
+
#[test]
fn roundtrip_distinct_count() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);
@@ -896,12 +916,18 @@ fn roundtrip_csv_sink() -> Result<()> {
}),
)];
- let roundtrip_plan = roundtrip_test_and_return(Arc::new(DataSinkExec::new(
- input,
- data_sink,
- schema.clone(),
- Some(sort_order),
- )))
+ let ctx = SessionContext::new();
+ let codec = DefaultPhysicalExtensionCodec {};
+ let roundtrip_plan = roundtrip_test_and_return(
+ Arc::new(DataSinkExec::new(
+ input,
+ data_sink,
+ schema.clone(),
+ Some(sort_order),
+ )),
+ &ctx,
+ &codec,
+ )
.unwrap();
let roundtrip_plan = roundtrip_plan