This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 36f6e0facc Use PhysicalExtensionCodec consistently (#10075)
36f6e0facc is described below

commit 36f6e0facc8cdbded3e595cbcb441190d055e56d
Author: Georgi Krastev <[email protected]>
AuthorDate: Mon Apr 15 13:27:03 2024 +0300

    Use PhysicalExtensionCodec consistently (#10075)
    
    * Use PhysicalExtensionCodec consistently
    
    * Use PhysicalExtensionCodec consistently also when serializing
    
    * Add a test for window aggregation with UDF codec
    
    * Commit binary incompatible changes
---
 datafusion/proto/src/physical_plan/from_proto.rs   |  71 +--
 datafusion/proto/src/physical_plan/mod.rs          |  93 ++--
 datafusion/proto/src/physical_plan/to_proto.rs     | 504 +++++++++------------
 .../proto/tests/cases/roundtrip_physical_plan.rs   | 128 +++---
 4 files changed, 402 insertions(+), 394 deletions(-)

diff --git a/datafusion/proto/src/physical_plan/from_proto.rs 
b/datafusion/proto/src/physical_plan/from_proto.rs
index aaca4dc482..81e4c92ffc 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -21,13 +21,11 @@ use std::collections::HashMap;
 use std::convert::{TryFrom, TryInto};
 use std::sync::Arc;
 
-use crate::common::proto_error;
-use crate::convert_required;
-use crate::logical_plan::{self, csv_writer_options_from_proto};
-use crate::protobuf::physical_expr_node::ExprType;
-use crate::protobuf::{self, copy_to_node};
-
 use arrow::compute::SortOptions;
+use chrono::{TimeZone, Utc};
+use object_store::path::Path;
+use object_store::ObjectMeta;
+
 use datafusion::arrow::datatypes::Schema;
 use datafusion::datasource::file_format::csv::CsvSink;
 use datafusion::datasource::file_format::json::JsonSink;
@@ -57,13 +55,15 @@ use 
datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, 
ScalarValue};
-
-use chrono::{TimeZone, Utc};
 use datafusion_expr::ScalarFunctionDefinition;
-use object_store::path::Path;
-use object_store::ObjectMeta;
 
-use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
+use crate::common::proto_error;
+use crate::convert_required;
+use crate::logical_plan::{self, csv_writer_options_from_proto};
+use crate::protobuf::physical_expr_node::ExprType;
+use crate::protobuf::{self, copy_to_node};
+
+use super::PhysicalExtensionCodec;
 
 impl From<&protobuf::PhysicalColumn> for Column {
     fn from(c: &protobuf::PhysicalColumn) -> Column {
@@ -76,9 +76,10 @@ impl From<&protobuf::PhysicalColumn> for Column {
 /// # Arguments
 ///
 /// * `proto` - Input proto with physical sort expression node
-/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function names
 /// * `input_schema` - The Arrow schema for the input, used for determining 
expression data types
 ///                    when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
 pub fn parse_physical_sort_expr(
     proto: &protobuf::PhysicalSortExprNode,
     registry: &dyn FunctionRegistry,
@@ -102,9 +103,10 @@ pub fn parse_physical_sort_expr(
 /// # Arguments
 ///
 /// * `proto` - Input proto with vector of physical sort expression node
-/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function names
 /// * `input_schema` - The Arrow schema for the input, used for determining 
expression data types
 ///                    when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
 pub fn parse_physical_sort_exprs(
     proto: &[protobuf::PhysicalSortExprNode],
     registry: &dyn FunctionRegistry,
@@ -123,25 +125,26 @@ pub fn parse_physical_sort_exprs(
 ///
 /// # Arguments
 ///
-/// * `proto` - Input proto with physical window exprression node.
+/// * `proto` - Input proto with physical window expression node.
 /// * `name` - Name of the window expression.
-/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function names
 /// * `input_schema` - The Arrow schema for the input, used for determining 
expression data types
 ///                    when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
 pub fn parse_physical_window_expr(
     proto: &protobuf::PhysicalWindowExprNode,
     registry: &dyn FunctionRegistry,
     input_schema: &Schema,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<Arc<dyn WindowExpr>> {
-    let codec = DefaultPhysicalExtensionCodec {};
     let window_node_expr =
-        parse_physical_exprs(&proto.args, registry, input_schema, &codec)?;
+        parse_physical_exprs(&proto.args, registry, input_schema, codec)?;
 
     let partition_by =
-        parse_physical_exprs(&proto.partition_by, registry, input_schema, 
&codec)?;
+        parse_physical_exprs(&proto.partition_by, registry, input_schema, 
codec)?;
 
     let order_by =
-        parse_physical_sort_exprs(&proto.order_by, registry, input_schema, 
&codec)?;
+        parse_physical_sort_exprs(&proto.order_by, registry, input_schema, 
codec)?;
 
     let window_frame = proto
         .window_frame
@@ -187,9 +190,10 @@ where
 /// # Arguments
 ///
 /// * `proto` - Input proto with physical expression node
-/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function' names
+/// * `registry` - A registry knows how to build logical expressions out of 
user-defined function names
 /// * `input_schema` - The Arrow schema for the input, used for determining 
expression data types
 ///                    when performing type coercion.
+/// * `codec` - An extension codec used to decode custom UDFs.
 pub fn parse_physical_expr(
     proto: &protobuf::PhysicalExprNode,
     registry: &dyn FunctionRegistry,
@@ -213,6 +217,7 @@ pub fn parse_physical_expr(
                 registry,
                 "left",
                 input_schema,
+                codec,
             )?,
             logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?,
             parse_required_physical_expr(
@@ -220,6 +225,7 @@ pub fn parse_physical_expr(
                 registry,
                 "right",
                 input_schema,
+                codec,
             )?,
         )),
         ExprType::AggregateExpr(_) => {
@@ -241,6 +247,7 @@ pub fn parse_physical_expr(
                 registry,
                 "expr",
                 input_schema,
+                codec,
             )?))
         }
         ExprType::IsNotNullExpr(e) => {
@@ -249,6 +256,7 @@ pub fn parse_physical_expr(
                 registry,
                 "expr",
                 input_schema,
+                codec,
             )?))
         }
         ExprType::NotExpr(e) => 
Arc::new(NotExpr::new(parse_required_physical_expr(
@@ -256,6 +264,7 @@ pub fn parse_physical_expr(
             registry,
             "expr",
             input_schema,
+            codec,
         )?)),
         ExprType::Negative(e) => {
             Arc::new(NegativeExpr::new(parse_required_physical_expr(
@@ -263,6 +272,7 @@ pub fn parse_physical_expr(
                 registry,
                 "expr",
                 input_schema,
+                codec,
             )?))
         }
         ExprType::InList(e) => in_list(
@@ -271,6 +281,7 @@ pub fn parse_physical_expr(
                 registry,
                 "expr",
                 input_schema,
+                codec,
             )?,
             parse_physical_exprs(&e.list, registry, input_schema, codec)?,
             &e.negated,
@@ -290,12 +301,14 @@ pub fn parse_physical_expr(
                             registry,
                             "when_expr",
                             input_schema,
+                            codec,
                         )?,
                         parse_required_physical_expr(
                             e.then_expr.as_ref(),
                             registry,
                             "then_expr",
                             input_schema,
+                            codec,
                         )?,
                     ))
                 })
@@ -311,6 +324,7 @@ pub fn parse_physical_expr(
                 registry,
                 "expr",
                 input_schema,
+                codec,
             )?,
             convert_required!(e.arrow_type)?,
             None,
@@ -321,6 +335,7 @@ pub fn parse_physical_expr(
                 registry,
                 "expr",
                 input_schema,
+                codec,
             )?,
             convert_required!(e.arrow_type)?,
         )),
@@ -371,12 +386,14 @@ pub fn parse_physical_expr(
                 registry,
                 "expr",
                 input_schema,
+                codec,
             )?,
             parse_required_physical_expr(
                 like_expr.pattern.as_deref(),
                 registry,
                 "pattern",
                 input_schema,
+                codec,
             )?,
         )),
     };
@@ -389,9 +406,9 @@ fn parse_required_physical_expr(
     registry: &dyn FunctionRegistry,
     field: &str,
     input_schema: &Schema,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    let codec = DefaultPhysicalExtensionCodec {};
-    expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec))
+    expr.map(|e| parse_physical_expr(e, registry, input_schema, codec))
         .transpose()?
         .ok_or_else(|| {
             DataFusionError::Internal(format!("Missing required field 
{field:?}"))
@@ -433,15 +450,15 @@ pub fn parse_protobuf_hash_partitioning(
     partitioning: Option<&protobuf::PhysicalHashRepartition>,
     registry: &dyn FunctionRegistry,
     input_schema: &Schema,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<Option<Partitioning>> {
     match partitioning {
         Some(hash_part) => {
-            let codec = DefaultPhysicalExtensionCodec {};
             let expr = parse_physical_exprs(
                 &hash_part.hash_expr,
                 registry,
                 input_schema,
-                &codec,
+                codec,
             )?;
 
             Ok(Some(Partitioning::Hash(
@@ -456,6 +473,7 @@ pub fn parse_protobuf_hash_partitioning(
 pub fn parse_protobuf_file_scan_config(
     proto: &protobuf::FileScanExecConf,
     registry: &dyn FunctionRegistry,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<FileScanConfig> {
     let schema: Arc<Schema> = Arc::new(convert_required!(proto.schema)?);
     let projection = proto
@@ -489,7 +507,7 @@ pub fn parse_protobuf_file_scan_config(
         .collect::<Result<Vec<_>>>()?;
 
     // Remove partition columns from the schema after recreating 
table_partition_cols
-    // because the partition columns are not in the file. They are present to 
allow the
+    // because the partition columns are not in the file. They are present to 
allow
     // the partition column types to be reconstructed after serde.
     let file_schema = Arc::new(Schema::new(
         schema
@@ -502,12 +520,11 @@ pub fn parse_protobuf_file_scan_config(
 
     let mut output_ordering = vec![];
     for node_collection in &proto.output_ordering {
-        let codec = DefaultPhysicalExtensionCodec {};
         let sort_expr = parse_physical_sort_exprs(
             &node_collection.physical_sort_expr_nodes,
             registry,
             &schema,
-            &codec,
+            codec,
         )?;
         output_ordering.push(sort_expr);
     }
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 4d95c847bf..a481e7090f 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -19,22 +19,8 @@ use std::convert::TryInto;
 use std::fmt::Debug;
 use std::sync::Arc;
 
-use self::from_proto::parse_physical_window_expr;
-use self::to_proto::serialize_physical_expr;
-
-use crate::common::{byte_to_string, proto_error, str_to_byte};
-use crate::convert_required;
-use crate::physical_plan::from_proto::{
-    parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
-    parse_protobuf_file_scan_config,
-};
-use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
-use crate::protobuf::physical_expr_node::ExprType;
-use crate::protobuf::physical_plan_node::PhysicalPlanType;
-use crate::protobuf::repartition_exec_node::PartitionMethod;
-use crate::protobuf::{
-    self, window_agg_exec_node, PhysicalPlanNode, 
PhysicalSortExprNodeCollection,
-};
+use prost::bytes::BufMut;
+use prost::Message;
 
 use datafusion::arrow::compute::SortOptions;
 use datafusion::arrow::datatypes::SchemaRef;
@@ -79,13 +65,28 @@ use datafusion::physical_plan::{
 use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
 use datafusion_expr::ScalarUDF;
 
-use prost::bytes::BufMut;
-use prost::Message;
+use crate::common::{byte_to_string, proto_error, str_to_byte};
+use crate::convert_required;
+use crate::physical_plan::from_proto::{
+    parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
+    parse_physical_window_expr, parse_protobuf_file_scan_config,
+};
+use crate::physical_plan::to_proto::{
+    serialize_file_scan_config, serialize_maybe_filter, 
serialize_physical_aggr_expr,
+    serialize_physical_window_expr,
+};
+use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
+use crate::protobuf::physical_expr_node::ExprType;
+use crate::protobuf::physical_plan_node::PhysicalPlanType;
+use crate::protobuf::repartition_exec_node::PartitionMethod;
+use crate::protobuf::{self, window_agg_exec_node};
+
+use self::to_proto::serialize_physical_expr;
 
 pub mod from_proto;
 pub mod to_proto;
 
-impl AsExecutionPlan for PhysicalPlanNode {
+impl AsExecutionPlan for protobuf::PhysicalPlanNode {
     fn try_decode(buf: &[u8]) -> Result<Self>
     where
         Self: Sized,
@@ -191,6 +192,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 parse_protobuf_file_scan_config(
                     scan.base_conf.as_ref().unwrap(),
                     registry,
+                    extension_codec,
                 )?,
                 scan.has_header,
                 str_to_byte(&scan.delimiter, "delimiter")?,
@@ -210,6 +212,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 let base_config = parse_protobuf_file_scan_config(
                     scan.base_conf.as_ref().unwrap(),
                     registry,
+                    extension_codec,
                 )?;
                 let predicate = scan
                     .predicate
@@ -234,6 +237,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config(
                     scan.base_conf.as_ref().unwrap(),
                     registry,
+                    extension_codec,
                 )?)))
             }
             PhysicalPlanType::CoalesceBatches(coalesce_batches) => {
@@ -338,6 +342,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                             window_expr,
                             registry,
                             input_schema.as_ref(),
+                            extension_codec,
                         )
                     })
                     .collect::<Result<Vec<_>, _>>()?;
@@ -1449,14 +1454,17 @@ impl AsExecutionPlan for PhysicalPlanNode {
             let filter = exec
                 .filter_expr()
                 .iter()
-                .map(|expr| expr.to_owned().try_into())
+                .map(|expr| serialize_maybe_filter(expr.to_owned(), 
extension_codec))
                 .collect::<Result<Vec<_>>>()?;
 
             let agg = exec
                 .aggr_expr()
                 .iter()
-                .map(|expr| expr.to_owned().try_into())
+                .map(|expr| {
+                    serialize_physical_aggr_expr(expr.to_owned(), 
extension_codec)
+                })
                 .collect::<Result<Vec<_>>>()?;
+
             let agg_names = exec
                 .aggr_expr()
                 .iter()
@@ -1556,7 +1564,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::CsvScan(
                     protobuf::CsvScanExecNode {
-                        base_conf: Some(exec.base_config().try_into()?),
+                        base_conf: Some(serialize_file_scan_config(
+                            exec.base_config(),
+                            extension_codec,
+                        )?),
                         has_header: exec.has_header(),
                         delimiter: byte_to_string(exec.delimiter(), 
"delimiter")?,
                         quote: byte_to_string(exec.quote(), "quote")?,
@@ -1581,7 +1592,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::ParquetScan(
                     protobuf::ParquetScanExecNode {
-                        base_conf: Some(exec.base_config().try_into()?),
+                        base_conf: Some(serialize_file_scan_config(
+                            exec.base_config(),
+                            extension_codec,
+                        )?),
                         predicate,
                     },
                 )),
@@ -1592,7 +1606,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::AvroScan(
                     protobuf::AvroScanExecNode {
-                        base_conf: Some(exec.base_config().try_into()?),
+                        base_conf: Some(serialize_file_scan_config(
+                            exec.base_config(),
+                            extension_codec,
+                        )?),
                     },
                 )),
             });
@@ -1688,7 +1705,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
         }
 
         if let Some(union) = plan.downcast_ref::<UnionExec>() {
-            let mut inputs: Vec<PhysicalPlanNode> = vec![];
+            let mut inputs: Vec<protobuf::PhysicalPlanNode> = vec![];
             for input in union.inputs() {
                 inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan(
                     input.to_owned(),
@@ -1703,7 +1720,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
         }
 
         if let Some(interleave) = plan.downcast_ref::<InterleaveExec>() {
-            let mut inputs: Vec<PhysicalPlanNode> = vec![];
+            let mut inputs: Vec<protobuf::PhysicalPlanNode> = vec![];
             for input in interleave.inputs() {
                 inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan(
                     input.to_owned(),
@@ -1809,11 +1826,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 extension_codec,
             )?;
 
-            let window_expr =
-                exec.window_expr()
-                    .iter()
-                    .map(|e| e.clone().try_into())
-                    
.collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
+            let window_expr = exec
+                .window_expr()
+                .iter()
+                .map(|e| serialize_physical_window_expr(e.clone(), 
extension_codec))
+                .collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
 
             let partition_keys = exec
                 .partition_keys
@@ -1839,11 +1856,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 extension_codec,
             )?;
 
-            let window_expr =
-                exec.window_expr()
-                    .iter()
-                    .map(|e| e.clone().try_into())
-                    
.collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
+            let window_expr = exec
+                .window_expr()
+                .iter()
+                .map(|e| serialize_physical_window_expr(e.clone(), 
extension_codec))
+                .collect::<Result<Vec<protobuf::PhysicalWindowExprNode>>>()?;
 
             let partition_keys = exec
                 .partition_keys
@@ -1901,7 +1918,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                             Ok(sort_expr)
                         })
                         .collect::<Result<Vec<_>>>()?;
-                    Some(PhysicalSortExprNodeCollection {
+                    Some(protobuf::PhysicalSortExprNodeCollection {
                         physical_sort_expr_nodes: expr,
                     })
                 }
@@ -2044,7 +2061,7 @@ impl PhysicalExtensionCodec for 
DefaultPhysicalExtensionCodec {
 }
 
 fn into_physical_plan(
-    node: &Option<Box<PhysicalPlanNode>>,
+    node: &Option<Box<protobuf::PhysicalPlanNode>>,
     registry: &dyn FunctionRegistry,
     runtime: &RuntimeEnv,
     extension_codec: &dyn PhysicalExtensionCodec,
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs 
b/datafusion/proto/src/physical_plan/to_proto.rs
index e1574f48fb..b4c23e4d0c 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -22,18 +22,8 @@ use std::{
     sync::Arc,
 };
 
-use crate::protobuf::{
-    self, copy_to_node, physical_aggregate_expr_node, 
physical_window_expr_node,
-    scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode,
-    PhysicalSortExprNodeCollection, ScalarValue,
-};
-
 #[cfg(feature = "parquet")]
 use datafusion::datasource::file_format::parquet::ParquetSink;
-
-use datafusion_expr::ScalarFunctionDefinition;
-
-use crate::logical_plan::csv_writer_options_to_proto;
 use datafusion::logical_expr::BuiltinScalarFunction;
 use datafusion::physical_expr::window::{NthValueKind, 
SlidingAggregateWindowExpr};
 use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
@@ -71,195 +61,187 @@ use datafusion_common::{
     stats::Precision,
     DataFusionError, JoinSide, Result,
 };
+use datafusion_expr::ScalarFunctionDefinition;
 
-use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec};
-
-impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
-    type Error = DataFusionError;
-
-    fn try_from(a: Arc<dyn AggregateExpr>) -> Result<Self, Self::Error> {
-        let codec = DefaultPhysicalExtensionCodec {};
-        let expressions = serialize_physical_exprs(a.expressions(), &codec)?;
-
-        let ordering_req = a.order_bys().unwrap_or(&[]).to_vec();
-        let ordering_req = serialize_physical_sort_exprs(ordering_req, 
&codec)?;
-
-        if let Some(a) = a.as_any().downcast_ref::<AggregateFunctionExpr>() {
-            let name = a.fun().name().to_string();
-            return Ok(protobuf::PhysicalExprNode {
-                    expr_type: 
Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
-                        protobuf::PhysicalAggregateExprNode {
-                            aggregate_function: 
Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
-                            expr: expressions,
-                            ordering_req,
-                            distinct: false,
-                        },
-                    )),
-                });
-        }
+use crate::logical_plan::csv_writer_options_to_proto;
+use crate::protobuf::{
+    self, copy_to_node, physical_aggregate_expr_node, 
physical_window_expr_node,
+    scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode,
+    PhysicalSortExprNodeCollection, ScalarValue,
+};
 
-        let AggrFn {
-            inner: aggr_function,
-            distinct,
-        } = aggr_expr_to_aggr_fn(a.as_ref())?;
+use super::PhysicalExtensionCodec;
 
-        Ok(protobuf::PhysicalExprNode {
+pub fn serialize_physical_aggr_expr(
+    aggr_expr: Arc<dyn AggregateExpr>,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::PhysicalExprNode> {
+    let expressions = serialize_physical_exprs(aggr_expr.expressions(), 
codec)?;
+    let ordering_req = aggr_expr.order_bys().unwrap_or(&[]).to_vec();
+    let ordering_req = serialize_physical_sort_exprs(ordering_req, codec)?;
+
+    if let Some(a) = 
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+        let name = a.fun().name().to_string();
+        return Ok(protobuf::PhysicalExprNode {
             expr_type: 
Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
                 protobuf::PhysicalAggregateExprNode {
-                    aggregate_function: Some(
-                        
physical_aggregate_expr_node::AggregateFunction::AggrFunction(
-                            aggr_function as i32,
-                        ),
-                    ),
+                    aggregate_function: 
Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
                     expr: expressions,
                     ordering_req,
-                    distinct,
+                    distinct: false,
                 },
             )),
-        })
+        });
     }
-}
-
-impl TryFrom<Arc<dyn WindowExpr>> for protobuf::PhysicalWindowExprNode {
-    type Error = DataFusionError;
 
-    fn try_from(
-        window_expr: Arc<dyn WindowExpr>,
-    ) -> std::result::Result<Self, Self::Error> {
-        let expr = window_expr.as_any();
+    let AggrFn {
+        inner: aggr_function,
+        distinct,
+    } = aggr_expr_to_aggr_fn(aggr_expr.as_ref())?;
+
+    Ok(protobuf::PhysicalExprNode {
+        expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
+            protobuf::PhysicalAggregateExprNode {
+                aggregate_function: Some(
+                    
physical_aggregate_expr_node::AggregateFunction::AggrFunction(
+                        aggr_function as i32,
+                    ),
+                ),
+                expr: expressions,
+                ordering_req,
+                distinct,
+            },
+        )),
+    })
+}
 
-        let mut args = window_expr.expressions().to_vec();
-        let window_frame = window_expr.get_window_frame();
+pub fn serialize_physical_window_expr(
+    window_expr: Arc<dyn WindowExpr>,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::PhysicalWindowExprNode> {
+    let expr = window_expr.as_any();
+    let mut args = window_expr.expressions().to_vec();
+    let window_frame = window_expr.get_window_frame();
 
-        let window_function = if let Some(built_in_window_expr) =
-            expr.downcast_ref::<BuiltInWindowExpr>()
+    let window_function = if let Some(built_in_window_expr) =
+        expr.downcast_ref::<BuiltInWindowExpr>()
+    {
+        let expr = built_in_window_expr.get_built_in_func_expr();
+        let built_in_fn_expr = expr.as_any();
+
+        let builtin_fn = if 
built_in_fn_expr.downcast_ref::<RowNumber>().is_some() {
+            protobuf::BuiltInWindowFunction::RowNumber
+        } else if let Some(rank_expr) = 
built_in_fn_expr.downcast_ref::<Rank>() {
+            match rank_expr.get_type() {
+                RankType::Basic => protobuf::BuiltInWindowFunction::Rank,
+                RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank,
+                RankType::Percent => 
protobuf::BuiltInWindowFunction::PercentRank,
+            }
+        } else if built_in_fn_expr.downcast_ref::<CumeDist>().is_some() {
+            protobuf::BuiltInWindowFunction::CumeDist
+        } else if let Some(ntile_expr) = 
built_in_fn_expr.downcast_ref::<Ntile>() {
+            args.insert(
+                0,
+                
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
+                    ntile_expr.get_n() as i64,
+                )))),
+            );
+            protobuf::BuiltInWindowFunction::Ntile
+        } else if let Some(window_shift_expr) =
+            built_in_fn_expr.downcast_ref::<WindowShift>()
         {
-            let expr = built_in_window_expr.get_built_in_func_expr();
-            let built_in_fn_expr = expr.as_any();
-
-            let builtin_fn = if 
built_in_fn_expr.downcast_ref::<RowNumber>().is_some() {
-                protobuf::BuiltInWindowFunction::RowNumber
-            } else if let Some(rank_expr) = 
built_in_fn_expr.downcast_ref::<Rank>() {
-                match rank_expr.get_type() {
-                    RankType::Basic => protobuf::BuiltInWindowFunction::Rank,
-                    RankType::Dense => 
protobuf::BuiltInWindowFunction::DenseRank,
-                    RankType::Percent => 
protobuf::BuiltInWindowFunction::PercentRank,
-                }
-            } else if built_in_fn_expr.downcast_ref::<CumeDist>().is_some() {
-                protobuf::BuiltInWindowFunction::CumeDist
-            } else if let Some(ntile_expr) = 
built_in_fn_expr.downcast_ref::<Ntile>() {
-                args.insert(
-                    0,
-                    
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
-                        ntile_expr.get_n() as i64,
-                    )))),
-                );
-                protobuf::BuiltInWindowFunction::Ntile
-            } else if let Some(window_shift_expr) =
-                built_in_fn_expr.downcast_ref::<WindowShift>()
-            {
-                args.insert(
-                    1,
-                    
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
-                        window_shift_expr.get_shift_offset(),
-                    )))),
-                );
-                args.insert(
-                    2,
-                    
Arc::new(Literal::new(window_shift_expr.get_default_value())),
-                );
-
-                if window_shift_expr.get_shift_offset() >= 0 {
-                    protobuf::BuiltInWindowFunction::Lag
-                } else {
-                    protobuf::BuiltInWindowFunction::Lead
-                }
-            } else if let Some(nth_value_expr) =
-                built_in_fn_expr.downcast_ref::<NthValue>()
-            {
-                match nth_value_expr.get_kind() {
-                    NthValueKind::First => 
protobuf::BuiltInWindowFunction::FirstValue,
-                    NthValueKind::Last => 
protobuf::BuiltInWindowFunction::LastValue,
-                    NthValueKind::Nth(n) => {
-                        args.insert(
-                            1,
-                            Arc::new(Literal::new(
-                                datafusion_common::ScalarValue::Int64(Some(n)),
-                            )),
-                        );
-                        protobuf::BuiltInWindowFunction::NthValue
-                    }
-                }
+            args.insert(
+                1,
+                
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some(
+                    window_shift_expr.get_shift_offset(),
+                )))),
+            );
+            args.insert(
+                2,
+                Arc::new(Literal::new(window_shift_expr.get_default_value())),
+            );
+
+            if window_shift_expr.get_shift_offset() >= 0 {
+                protobuf::BuiltInWindowFunction::Lag
             } else {
-                return not_impl_err!("BuiltIn function not supported: 
{expr:?}");
-            };
-
-            
physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32)
-        } else if let Some(plain_aggr_window_expr) =
-            expr.downcast_ref::<PlainAggregateWindowExpr>()
-        {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                plain_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                // TODO
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window 
expressions"
-                );
+                protobuf::BuiltInWindowFunction::Lead
             }
-
-            if !window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid 
PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = 
{window_frame:?}")));
+        } else if let Some(nth_value_expr) = 
built_in_fn_expr.downcast_ref::<NthValue>() {
+            match nth_value_expr.get_kind() {
+                NthValueKind::First => 
protobuf::BuiltInWindowFunction::FirstValue,
+                NthValueKind::Last => 
protobuf::BuiltInWindowFunction::LastValue,
+                NthValueKind::Nth(n) => {
+                    args.insert(
+                        1,
+                        
Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(
+                            Some(n),
+                        ))),
+                    );
+                    protobuf::BuiltInWindowFunction::NthValue
+                }
             }
+        } else {
+            return not_impl_err!("BuiltIn function not supported: {expr:?}");
+        };
 
-            physical_window_expr_node::WindowFunction::AggrFunction(inner as 
i32)
-        } else if let Some(sliding_aggr_window_expr) =
-            expr.downcast_ref::<SlidingAggregateWindowExpr>()
-        {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                // TODO
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window 
expressions"
-                );
-            }
+        physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn 
as i32)
+    } else if let Some(plain_aggr_window_expr) =
+        expr.downcast_ref::<PlainAggregateWindowExpr>()
+    {
+        let AggrFn { inner, distinct } =
+            
aggr_expr_to_aggr_fn(plain_aggr_window_expr.get_aggregate_expr().as_ref())?;
+
+        if distinct {
+            // TODO
+            return not_impl_err!(
+                "Distinct aggregate functions not supported in window 
expressions"
+            );
+        }
 
-            if window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid 
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = 
{window_frame:?}")));
-            }
+        if !window_frame.start_bound.is_unbounded() {
+            return Err(DataFusionError::Internal(format!("Invalid 
PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = 
{window_frame:?}")));
+        }
 
-            physical_window_expr_node::WindowFunction::AggrFunction(inner as 
i32)
-        } else {
-            return not_impl_err!("WindowExpr not supported: {window_expr:?}");
-        };
-        let codec = DefaultPhysicalExtensionCodec {};
-        let args = serialize_physical_exprs(args, &codec)?;
-        let partition_by =
-            serialize_physical_exprs(window_expr.partition_by().to_vec(), 
&codec)?;
+        physical_window_expr_node::WindowFunction::AggrFunction(inner as i32)
+    } else if let Some(sliding_aggr_window_expr) =
+        expr.downcast_ref::<SlidingAggregateWindowExpr>()
+    {
+        let AggrFn { inner, distinct } =
+            
aggr_expr_to_aggr_fn(sliding_aggr_window_expr.get_aggregate_expr().as_ref())?;
+
+        if distinct {
+            // TODO
+            return not_impl_err!(
+                "Distinct aggregate functions not supported in window 
expressions"
+            );
+        }
 
-        let order_by =
-            serialize_physical_sort_exprs(window_expr.order_by().to_vec(), 
&codec)?;
+        if window_frame.start_bound.is_unbounded() {
+            return Err(DataFusionError::Internal(format!("Invalid 
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = 
{window_frame:?}")));
+        }
 
-        let window_frame: protobuf::WindowFrame = window_frame
-            .as_ref()
-            .try_into()
-            .map_err(|e| DataFusionError::Internal(format!("{e}")))?;
-
-        let name = window_expr.name().to_string();
-
-        Ok(protobuf::PhysicalWindowExprNode {
-            args,
-            partition_by,
-            order_by,
-            window_frame: Some(window_frame),
-            window_function: Some(window_function),
-            name,
-        })
-    }
+        physical_window_expr_node::WindowFunction::AggrFunction(inner as i32)
+    } else {
+        return not_impl_err!("WindowExpr not supported: {window_expr:?}");
+    };
+
+    let args = serialize_physical_exprs(args, codec)?;
+    let partition_by =
+        serialize_physical_exprs(window_expr.partition_by().to_vec(), codec)?;
+    let order_by = 
serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?;
+    let window_frame: protobuf::WindowFrame = window_frame
+        .as_ref()
+        .try_into()
+        .map_err(|e| DataFusionError::Internal(format!("{e}")))?;
+
+    Ok(protobuf::PhysicalWindowExprNode {
+        args,
+        partition_by,
+        order_by,
+        window_frame: Some(window_frame),
+        window_function: Some(window_function),
+        name: window_expr.name().to_string(),
+    })
 }
 
 struct AggrFn {
@@ -366,7 +348,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> 
Result<AggrFn> {
 pub fn serialize_physical_sort_exprs<I>(
     sort_exprs: I,
     codec: &dyn PhysicalExtensionCodec,
-) -> Result<Vec<protobuf::PhysicalSortExprNode>, DataFusionError>
+) -> Result<Vec<PhysicalSortExprNode>>
 where
     I: IntoIterator<Item = PhysicalSortExpr>,
 {
@@ -379,7 +361,7 @@ where
 pub fn serialize_physical_sort_expr(
     sort_expr: PhysicalSortExpr,
     codec: &dyn PhysicalExtensionCodec,
-) -> Result<protobuf::PhysicalSortExprNode, DataFusionError> {
+) -> Result<PhysicalSortExprNode> {
     let PhysicalSortExpr { expr, options } = sort_expr;
     let expr = serialize_physical_expr(expr, codec)?;
     Ok(PhysicalSortExprNode {
@@ -392,7 +374,7 @@ pub fn serialize_physical_sort_expr(
 pub fn serialize_physical_exprs<I>(
     values: I,
     codec: &dyn PhysicalExtensionCodec,
-) -> Result<Vec<protobuf::PhysicalExprNode>, DataFusionError>
+) -> Result<Vec<protobuf::PhysicalExprNode>>
 where
     I: IntoIterator<Item = Arc<dyn PhysicalExpr>>,
 {
@@ -409,7 +391,7 @@ where
 pub fn serialize_physical_expr(
     value: Arc<dyn PhysicalExpr>,
     codec: &dyn PhysicalExtensionCodec,
-) -> Result<protobuf::PhysicalExprNode, DataFusionError> {
+) -> Result<protobuf::PhysicalExprNode> {
     let expr = value.as_any();
 
     if let Some(expr) = expr.downcast_ref::<Column>() {
@@ -456,7 +438,7 @@ pub fn serialize_physical_expr(
                                 .when_then_expr()
                                 .iter()
                                 .map(|(when_expr, then_expr)| {
-                                    try_parse_when_then_expr(when_expr, 
then_expr, codec)
+                                    serialize_when_then_expr(when_expr, 
then_expr, codec)
                                 })
                                 .collect::<Result<
                                     Vec<protobuf::PhysicalWhenThen>,
@@ -623,7 +605,7 @@ pub fn serialize_physical_expr(
     }
 }
 
-fn try_parse_when_then_expr(
+fn serialize_when_then_expr(
     when_expr: &Arc<dyn PhysicalExpr>,
     then_expr: &Arc<dyn PhysicalExpr>,
     codec: &dyn PhysicalExtensionCodec,
@@ -637,7 +619,7 @@ fn try_parse_when_then_expr(
 impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile {
     type Error = DataFusionError;
 
-    fn try_from(pf: &PartitionedFile) -> Result<Self, Self::Error> {
+    fn try_from(pf: &PartitionedFile) -> Result<Self> {
         let last_modified = pf.object_meta.last_modified;
         let last_modified_ns = 
last_modified.timestamp_nanos_opt().ok_or_else(|| {
             DataFusionError::Plan(format!(
@@ -661,7 +643,7 @@ impl TryFrom<&PartitionedFile> for 
protobuf::PartitionedFile {
 impl TryFrom<&FileRange> for protobuf::FileRange {
     type Error = DataFusionError;
 
-    fn try_from(value: &FileRange) -> Result<Self, Self::Error> {
+    fn try_from(value: &FileRange) -> Result<Self> {
         Ok(protobuf::FileRange {
             start: value.start,
             end: value.end,
@@ -746,61 +728,58 @@ impl From<&ColumnStatistics> for protobuf::ColumnStats {
     }
 }
 
-impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf {
-    type Error = DataFusionError;
-    fn try_from(
-        conf: &FileScanConfig,
-    ) -> Result<protobuf::FileScanExecConf, Self::Error> {
-        let codec = DefaultPhysicalExtensionCodec {};
-        let file_groups = conf
-            .file_groups
-            .iter()
-            .map(|p| p.as_slice().try_into())
-            .collect::<Result<Vec<_>, _>>()?;
-
-        let mut output_orderings = vec![];
-        for order in &conf.output_ordering {
-            let ordering = serialize_physical_sort_exprs(order.to_vec(), 
&codec)?;
-            output_orderings.push(ordering)
-        }
+pub fn serialize_file_scan_config(
+    conf: &FileScanConfig,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::FileScanExecConf> {
+    let file_groups = conf
+        .file_groups
+        .iter()
+        .map(|p| p.as_slice().try_into())
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let mut output_orderings = vec![];
+    for order in &conf.output_ordering {
+        let ordering = serialize_physical_sort_exprs(order.to_vec(), codec)?;
+        output_orderings.push(ordering)
+    }
 
-        // Fields must be added to the schema so that they can persist in the 
protobuf
-        // and then they are to be removed from the schema in 
`parse_protobuf_file_scan_config`
-        let mut fields = conf
-            .file_schema
-            .fields()
+    // Fields must be added to the schema so that they can persist in the 
protobuf,
+    // and then they are to be removed from the schema in 
`parse_protobuf_file_scan_config`
+    let mut fields = conf
+        .file_schema
+        .fields()
+        .iter()
+        .cloned()
+        .collect::<Vec<_>>();
+    fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new));
+    let schema = Arc::new(arrow::datatypes::Schema::new(fields.clone()));
+
+    Ok(protobuf::FileScanExecConf {
+        file_groups,
+        statistics: Some((&conf.statistics).into()),
+        limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }),
+        projection: conf
+            .projection
+            .as_ref()
+            .unwrap_or(&vec![])
             .iter()
-            .cloned()
-            .collect::<Vec<_>>();
-        fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new));
-        let schema = 
Arc::new(datafusion::arrow::datatypes::Schema::new(fields.clone()));
-
-        Ok(protobuf::FileScanExecConf {
-            file_groups,
-            statistics: Some((&conf.statistics).into()),
-            limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }),
-            projection: conf
-                .projection
-                .as_ref()
-                .unwrap_or(&vec![])
-                .iter()
-                .map(|n| *n as u32)
-                .collect(),
-            schema: Some(schema.as_ref().try_into()?),
-            table_partition_cols: conf
-                .table_partition_cols
-                .iter()
-                .map(|x| x.name().clone())
-                .collect::<Vec<_>>(),
-            object_store_url: conf.object_store_url.to_string(),
-            output_ordering: output_orderings
-                .into_iter()
-                .map(|e| PhysicalSortExprNodeCollection {
-                    physical_sort_expr_nodes: e,
-                })
-                .collect::<Vec<_>>(),
-        })
-    }
+            .map(|n| *n as u32)
+            .collect(),
+        schema: Some(schema.as_ref().try_into()?),
+        table_partition_cols: conf
+            .table_partition_cols
+            .iter()
+            .map(|x| x.name().clone())
+            .collect::<Vec<_>>(),
+        object_store_url: conf.object_store_url.to_string(),
+        output_ordering: output_orderings
+            .into_iter()
+            .map(|e| PhysicalSortExprNodeCollection {
+                physical_sort_expr_nodes: e,
+            })
+            .collect::<Vec<_>>(),
+    })
 }
 
 impl From<JoinSide> for protobuf::JoinSide {
@@ -812,46 +791,15 @@ impl From<JoinSide> for protobuf::JoinSide {
     }
 }
 
-impl TryFrom<Option<Arc<dyn PhysicalExpr>>> for protobuf::MaybeFilter {
-    type Error = DataFusionError;
-
-    fn try_from(expr: Option<Arc<dyn PhysicalExpr>>) -> Result<Self, 
Self::Error> {
-        let codec = DefaultPhysicalExtensionCodec {};
-        match expr {
-            None => Ok(protobuf::MaybeFilter { expr: None }),
-            Some(expr) => Ok(protobuf::MaybeFilter {
-                expr: Some(serialize_physical_expr(expr, &codec)?),
-            }),
-        }
-    }
-}
-
-impl TryFrom<Option<Vec<PhysicalSortExpr>>> for 
protobuf::MaybePhysicalSortExprs {
-    type Error = DataFusionError;
-
-    fn try_from(sort_exprs: Option<Vec<PhysicalSortExpr>>) -> Result<Self, 
Self::Error> {
-        match sort_exprs {
-            None => Ok(protobuf::MaybePhysicalSortExprs { sort_expr: vec![] }),
-            Some(sort_exprs) => Ok(protobuf::MaybePhysicalSortExprs {
-                sort_expr: sort_exprs
-                    .into_iter()
-                    .map(|sort_expr| sort_expr.try_into())
-                    .collect::<Result<Vec<_>>>()?,
-            }),
-        }
-    }
-}
-
-impl TryFrom<PhysicalSortExpr> for protobuf::PhysicalSortExprNode {
-    type Error = DataFusionError;
-
-    fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result<Self, 
Self::Error> {
-        let codec = DefaultPhysicalExtensionCodec {};
-        Ok(PhysicalSortExprNode {
-            expr: Some(Box::new(serialize_physical_expr(sort_expr.expr, 
&codec)?)),
-            asc: !sort_expr.options.descending,
-            nulls_first: sort_expr.options.nulls_first,
-        })
+pub fn serialize_maybe_filter(
+    expr: Option<Arc<dyn PhysicalExpr>>,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::MaybeFilter> {
+    match expr {
+        None => Ok(protobuf::MaybeFilter { expr: None }),
+        Some(expr) => Ok(protobuf::MaybeFilter {
+            expr: Some(serialize_physical_expr(expr, codec)?),
+        }),
     }
 }
 
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index f97cfea765..642860d639 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -21,6 +21,8 @@ use std::sync::Arc;
 use std::vec;
 
 use arrow::csv::WriterBuilder;
+use prost::Message;
+
 use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::compute::kernels::sort::SortOptions;
 use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
@@ -35,7 +37,7 @@ use datafusion::datasource::physical_plan::{
 };
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
-use datafusion::physical_expr::expressions::NthValueAgg;
+use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg};
 use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
 use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
 use datafusion::physical_plan::aggregates::{
@@ -77,19 +79,18 @@ use datafusion_expr::{
     ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, 
SimpleAggregateUDF,
     WindowFrame, WindowFrameBound,
 };
-use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
-use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
 use datafusion_proto::physical_plan::{
     AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
 };
 use datafusion_proto::protobuf;
-use prost::Message;
 
 /// Perform a serde roundtrip and assert that the string representation of the 
before and after plans
 /// are identical. Note that this often isn't sufficient to guarantee that no 
information is
 /// lost during serde because the string representation of a plan often only 
shows a subset of state.
 fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
-    let _ = roundtrip_test_and_return(exec_plan);
+    let ctx = SessionContext::new();
+    let codec = DefaultPhysicalExtensionCodec {};
+    roundtrip_test_and_return(exec_plan, &ctx, &codec)?;
     Ok(())
 }
 
@@ -101,15 +102,15 @@ fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> 
Result<()> {
 /// farther in tests.
 fn roundtrip_test_and_return(
     exec_plan: Arc<dyn ExecutionPlan>,
+    ctx: &SessionContext,
+    codec: &dyn PhysicalExtensionCodec,
 ) -> Result<Arc<dyn ExecutionPlan>> {
-    let ctx = SessionContext::new();
-    let codec = DefaultPhysicalExtensionCodec {};
     let proto: protobuf::PhysicalPlanNode =
-        protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), 
&codec)
+        protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), 
codec)
             .expect("to proto");
     let runtime = ctx.runtime_env();
     let result_exec_plan: Arc<dyn ExecutionPlan> = proto
-        .try_into_physical_plan(&ctx, runtime.deref(), &codec)
+        .try_into_physical_plan(ctx, runtime.deref(), codec)
         .expect("from proto");
     assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}"));
     Ok(result_exec_plan)
@@ -123,17 +124,10 @@ fn roundtrip_test_and_return(
 /// performing serde on some plans.
 fn roundtrip_test_with_context(
     exec_plan: Arc<dyn ExecutionPlan>,
-    ctx: SessionContext,
+    ctx: &SessionContext,
 ) -> Result<()> {
     let codec = DefaultPhysicalExtensionCodec {};
-    let proto: protobuf::PhysicalPlanNode =
-        protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), 
&codec)
-            .expect("to proto");
-    let runtime = ctx.runtime_env();
-    let result_exec_plan: Arc<dyn ExecutionPlan> = proto
-        .try_into_physical_plan(&ctx, runtime.deref(), &codec)
-        .expect("from proto");
-    assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}"));
+    roundtrip_test_and_return(exec_plan, ctx, &codec)?;
     Ok(())
 }
 
@@ -444,7 +438,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
             Arc::new(EmptyExec::new(schema.clone())),
             schema,
         )?),
-        ctx,
+        &ctx,
     )
 }
 
@@ -642,11 +636,11 @@ fn roundtrip_scalar_udf() -> Result<()> {
 
     ctx.register_udf(udf);
 
-    roundtrip_test_with_context(Arc::new(project), ctx)
+    roundtrip_test_with_context(Arc::new(project), &ctx)
 }
 
 #[test]
-fn roundtrip_scalar_udf_extension_codec() {
+fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
     #[derive(Debug)]
     struct MyRegexUdf {
         signature: Signature,
@@ -657,11 +651,7 @@ fn roundtrip_scalar_udf_extension_codec() {
     impl MyRegexUdf {
         fn new(pattern: String) -> Self {
             Self {
-                signature: Signature::uniform(
-                    1,
-                    vec![DataType::Int32],
-                    Volatility::Immutable,
-                ),
+                signature: Signature::exact(vec![DataType::Utf8], 
Volatility::Immutable),
                 pattern,
             }
         }
@@ -672,18 +662,22 @@ fn roundtrip_scalar_udf_extension_codec() {
         fn as_any(&self) -> &dyn Any {
             self
         }
+
         fn name(&self) -> &str {
             "regex_udf"
         }
+
         fn signature(&self) -> &Signature {
             &self.signature
         }
+
         fn return_type(&self, args: &[DataType]) -> Result<DataType> {
             if !matches!(args.first(), Some(&DataType::Utf8)) {
                 return plan_err!("regex_udf only accepts Utf8 arguments");
             }
-            Ok(DataType::Int32)
+            Ok(DataType::Int64)
         }
+
         fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
             unimplemented!()
         }
@@ -747,32 +741,58 @@ fn roundtrip_scalar_udf_extension_codec() {
         }
     }
 
+    let field_text = Field::new("text", DataType::Utf8, true);
+    let field_published = Field::new("published", DataType::Boolean, false);
+    let field_author = Field::new("author", DataType::Utf8, false);
+    let schema = Arc::new(Schema::new(vec![field_text, field_published, 
field_author]));
+    let input = Arc::new(EmptyExec::new(schema.clone()));
+
     let pattern = ".*";
     let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
-    let test_expr = ScalarFunctionExpr::new(
+    let udf_expr = Arc::new(ScalarFunctionExpr::new(
         udf.name(),
         ScalarFunctionDefinition::UDF(Arc::new(udf.clone())),
-        vec![],
-        DataType::Int32,
+        vec![col("text", &schema)?],
+        DataType::Int64,
         None,
         false,
-    );
-    let fmt_expr = format!("{test_expr:?}");
-    let ctx = SessionContext::new();
+    ));
 
-    ctx.register_udf(udf.clone());
-    let extension_codec = ScalarUDFExtensionCodec {};
-    let proto: protobuf::PhysicalExprNode =
-        match serialize_physical_expr(Arc::new(test_expr), &extension_codec) {
-            Ok(proto) => proto,
-            Err(e) => panic!("failed to serialize expr: {e:?}"),
-        };
-    let field_a = Field::new("a", DataType::Int32, false);
-    let schema = Arc::new(Schema::new(vec![field_a]));
-    let round_trip =
-        parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap();
-    assert_eq!(fmt_expr, format!("{round_trip:?}"));
+    let filter = Arc::new(FilterExec::try_new(
+        Arc::new(BinaryExpr::new(
+            col("published", &schema)?,
+            Operator::And,
+            Arc::new(BinaryExpr::new(udf_expr.clone(), Operator::Gt, lit(0))),
+        )),
+        input,
+    )?);
+
+    let window = Arc::new(WindowAggExec::try_new(
+        vec![Arc::new(PlainAggregateWindowExpr::new(
+            Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)),
+            &[col("author", &schema)?],
+            &[],
+            Arc::new(WindowFrame::new(None)),
+        ))],
+        filter,
+        vec![col("author", &schema)?],
+    )?);
+
+    let aggregate = Arc::new(AggregateExec::try_new(
+        AggregateMode::Final,
+        PhysicalGroupBy::new(vec![], vec![], vec![]),
+        vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))],
+        vec![None],
+        window,
+        schema.clone(),
+    )?);
+
+    let ctx = SessionContext::new();
+    let codec = ScalarUDFExtensionCodec {};
+    roundtrip_test_and_return(aggregate, &ctx, &codec)?;
+    Ok(())
 }
+
 #[test]
 fn roundtrip_distinct_count() -> Result<()> {
     let field_a = Field::new("a", DataType::Int64, false);
@@ -896,12 +916,18 @@ fn roundtrip_csv_sink() -> Result<()> {
         }),
     )];
 
-    let roundtrip_plan = roundtrip_test_and_return(Arc::new(DataSinkExec::new(
-        input,
-        data_sink,
-        schema.clone(),
-        Some(sort_order),
-    )))
+    let ctx = SessionContext::new();
+    let codec = DefaultPhysicalExtensionCodec {};
+    let roundtrip_plan = roundtrip_test_and_return(
+        Arc::new(DataSinkExec::new(
+            input,
+            data_sink,
+            schema.clone(),
+            Some(sort_order),
+        )),
+        &ctx,
+        &codec,
+    )
     .unwrap();
 
     let roundtrip_plan = roundtrip_plan

Reply via email to