This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a6d4798c45 Fixes 3 bugs during serialization and deserialization of physical plans (#16858)
a6d4798c45 is described below

commit a6d4798c45a4e6ca95049a2fcdf61b9490aa8e17
Author: Nga Tran <[email protected]>
AuthorDate: Fri Jul 25 10:32:20 2025 -0400

    Fixes 3 bugs during serialization and deserialization of physical plans (#16858)
---
 Cargo.lock                                         | 23 +++++++
 datafusion/core/src/physical_planner.rs            |  3 +
 datafusion/ffi/src/udaf/accumulator_args.rs        |  1 +
 datafusion/physical-plan/src/aggregates/mod.rs     |  9 ++-
 datafusion/proto/Cargo.toml                        |  1 +
 datafusion/proto/proto/datafusion.proto            |  2 +
 datafusion/proto/src/generated/pbjson.rs           | 36 ++++++++++
 datafusion/proto/src/generated/prost.rs            |  4 ++
 datafusion/proto/src/physical_plan/from_proto.rs   |  7 +-
 datafusion/proto/src/physical_plan/mod.rs          |  1 +
 datafusion/proto/src/physical_plan/to_proto.rs     |  6 ++
 .../proto/tests/cases/roundtrip_physical_plan.rs   | 77 ++++++++++++++++++++--
 12 files changed, 163 insertions(+), 7 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 45a8333a1e..7fcda2957e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2542,6 +2542,7 @@ dependencies = [
  "doc-comment",
  "object_store",
  "pbjson",
+ "pretty_assertions",
  "prost",
  "serde",
  "serde_json",
@@ -2727,6 +2728,12 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "diff"
+version = "0.1.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
+
 [[package]]
 name = "difflib"
 version = "0.4.0"
@@ -4816,6 +4823,16 @@ dependencies = [
  "termtree",
 ]
 
+[[package]]
+name = "pretty_assertions"
+version = "1.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
+dependencies = [
+ "diff",
+ "yansi",
+]
+
 [[package]]
 name = "prettyplease"
 version = "0.2.32"
@@ -7492,6 +7509,12 @@ dependencies = [
  "lzma-sys",
 ]
 
+[[package]]
+name = "yansi"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
+
 [[package]]
 name = "yoke"
 version = "0.8.0"
diff --git a/datafusion/core/src/physical_planner.rs 
b/datafusion/core/src/physical_planner.rs
index 5a0ee327cb..e1f4154324 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1358,6 +1358,9 @@ impl DefaultPhysicalPlanner {
                     physical_name(expr),
                 ))?])),
             }
+        } else if group_expr.is_empty() {
+            // No GROUP BY clause - create empty PhysicalGroupBy
+            Ok(PhysicalGroupBy::new(vec![], vec![], vec![]))
         } else {
             Ok(PhysicalGroupBy::new_single(
                 group_expr
diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs 
b/datafusion/ffi/src/udaf/accumulator_args.rs
index 874a2ac8b8..2cd2fa5f51 100644
--- a/datafusion/ffi/src/udaf/accumulator_args.rs
+++ b/datafusion/ffi/src/udaf/accumulator_args.rs
@@ -75,6 +75,7 @@ impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
             ignore_nulls: args.ignore_nulls,
             fun_definition: None,
             aggregate_function: None,
+            human_display: args.name.to_string(),
         };
         let physical_expr_def = physical_expr_def.encode_to_vec().into();
 
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs
index 66d721fab0..784b7db893 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -332,10 +332,17 @@ impl PhysicalGroupBy {
                 )
                 .collect();
         let num_exprs = expr.len();
+        let groups = if self.expr.is_empty() {
+            // No GROUP BY expressions - should have no groups
+            vec![]
+        } else {
+            // Has GROUP BY expressions - create a single group
+            vec![vec![false; num_exprs]]
+        };
         Self {
             expr,
             null_expr: vec![],
-            groups: vec![vec![false; num_exprs]],
+            groups,
         }
     }
 }
diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml
index a1eeabdf87..c95f392a05 100644
--- a/datafusion/proto/Cargo.toml
+++ b/datafusion/proto/Cargo.toml
@@ -60,4 +60,5 @@ datafusion-functions = { workspace = true, default-features = 
true }
 datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-window-common = { workspace = true }
 doc-comment = { workspace = true }
+pretty_assertions = "1.4"
 tokio = { workspace = true, features = ["rt-multi-thread"] }
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 64789f5de0..85a565c0b2 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -859,6 +859,7 @@ message PhysicalScalarUdfNode {
   optional bytes fun_definition = 3;
   datafusion_common.ArrowType return_type = 4;
   bool nullable = 5;
+  string return_field_name = 6;
 }
 
 message PhysicalAggregateExprNode {
@@ -870,6 +871,7 @@ message PhysicalAggregateExprNode {
   bool distinct = 3;
   bool ignore_nulls = 6;
   optional bytes fun_definition = 7;
+  string human_display = 8;
 }
 
 message PhysicalWindowExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 92309ea6a5..83f66ec22c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -13619,6 +13619,9 @@ impl serde::Serialize for PhysicalAggregateExprNode {
         if self.fun_definition.is_some() {
             len += 1;
         }
+        if !self.human_display.is_empty() {
+            len += 1;
+        }
         if self.aggregate_function.is_some() {
             len += 1;
         }
@@ -13640,6 +13643,9 @@ impl serde::Serialize for PhysicalAggregateExprNode {
             #[allow(clippy::needless_borrows_for_generic_args)]
             struct_ser.serialize_field("funDefinition", 
pbjson::private::base64::encode(&v).as_str())?;
         }
+        if !self.human_display.is_empty() {
+            struct_ser.serialize_field("humanDisplay", &self.human_display)?;
+        }
         if let Some(v) = self.aggregate_function.as_ref() {
             match v {
                 
physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => {
@@ -13665,6 +13671,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
             "ignoreNulls",
             "fun_definition",
             "funDefinition",
+            "human_display",
+            "humanDisplay",
             "user_defined_aggr_function",
             "userDefinedAggrFunction",
         ];
@@ -13676,6 +13684,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
             Distinct,
             IgnoreNulls,
             FunDefinition,
+            HumanDisplay,
             UserDefinedAggrFunction,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
@@ -13703,6 +13712,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                             "distinct" => Ok(GeneratedField::Distinct),
                             "ignoreNulls" | "ignore_nulls" => 
Ok(GeneratedField::IgnoreNulls),
                             "funDefinition" | "fun_definition" => 
Ok(GeneratedField::FunDefinition),
+                            "humanDisplay" | "human_display" => 
Ok(GeneratedField::HumanDisplay),
                             "userDefinedAggrFunction" | 
"user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
@@ -13728,6 +13738,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                 let mut distinct__ = None;
                 let mut ignore_nulls__ = None;
                 let mut fun_definition__ = None;
+                let mut human_display__ = None;
                 let mut aggregate_function__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
@@ -13763,6 +13774,12 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                                 
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
 x.0)
                             ;
                         }
+                        GeneratedField::HumanDisplay => {
+                            if human_display__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("humanDisplay"));
+                            }
+                            human_display__ = Some(map_.next_value()?);
+                        }
                         GeneratedField::UserDefinedAggrFunction => {
                             if aggregate_function__.is_some() {
                                 return 
Err(serde::de::Error::duplicate_field("userDefinedAggrFunction"));
@@ -13777,6 +13794,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                     distinct: distinct__.unwrap_or_default(),
                     ignore_nulls: ignore_nulls__.unwrap_or_default(),
                     fun_definition: fun_definition__,
+                    human_display: human_display__.unwrap_or_default(),
                     aggregate_function: aggregate_function__,
                 })
             }
@@ -16312,6 +16330,9 @@ impl serde::Serialize for PhysicalScalarUdfNode {
         if self.nullable {
             len += 1;
         }
+        if !self.return_field_name.is_empty() {
+            len += 1;
+        }
         let mut struct_ser = 
serializer.serialize_struct("datafusion.PhysicalScalarUdfNode", len)?;
         if !self.name.is_empty() {
             struct_ser.serialize_field("name", &self.name)?;
@@ -16330,6 +16351,9 @@ impl serde::Serialize for PhysicalScalarUdfNode {
         if self.nullable {
             struct_ser.serialize_field("nullable", &self.nullable)?;
         }
+        if !self.return_field_name.is_empty() {
+            struct_ser.serialize_field("returnFieldName", 
&self.return_field_name)?;
+        }
         struct_ser.end()
     }
 }
@@ -16347,6 +16371,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalScalarUdfNode {
             "return_type",
             "returnType",
             "nullable",
+            "return_field_name",
+            "returnFieldName",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -16356,6 +16382,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalScalarUdfNode {
             FunDefinition,
             ReturnType,
             Nullable,
+            ReturnFieldName,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -16382,6 +16409,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalScalarUdfNode {
                             "funDefinition" | "fun_definition" => 
Ok(GeneratedField::FunDefinition),
                             "returnType" | "return_type" => 
Ok(GeneratedField::ReturnType),
                             "nullable" => Ok(GeneratedField::Nullable),
+                            "returnFieldName" | "return_field_name" => 
Ok(GeneratedField::ReturnFieldName),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -16406,6 +16434,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalScalarUdfNode {
                 let mut fun_definition__ = None;
                 let mut return_type__ = None;
                 let mut nullable__ = None;
+                let mut return_field_name__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::Name => {
@@ -16440,6 +16469,12 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalScalarUdfNode {
                             }
                             nullable__ = Some(map_.next_value()?);
                         }
+                        GeneratedField::ReturnFieldName => {
+                            if return_field_name__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("returnFieldName"));
+                            }
+                            return_field_name__ = Some(map_.next_value()?);
+                        }
                     }
                 }
                 Ok(PhysicalScalarUdfNode {
@@ -16448,6 +16483,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalScalarUdfNode {
                     fun_definition: fun_definition__,
                     return_type: return_type__,
                     nullable: nullable__.unwrap_or_default(),
+                    return_field_name: return_field_name__.unwrap_or_default(),
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index b0fc0ce604..cbf6b3b2d4 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1305,6 +1305,8 @@ pub struct PhysicalScalarUdfNode {
     pub return_type: 
::core::option::Option<super::datafusion_common::ArrowType>,
     #[prost(bool, tag = "5")]
     pub nullable: bool,
+    #[prost(string, tag = "6")]
+    pub return_field_name: ::prost::alloc::string::String,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct PhysicalAggregateExprNode {
@@ -1318,6 +1320,8 @@ pub struct PhysicalAggregateExprNode {
     pub ignore_nulls: bool,
     #[prost(bytes = "vec", optional, tag = "7")]
     pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
+    #[prost(string, tag = "8")]
+    pub human_display: ::prost::alloc::string::String,
     #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = 
"4")]
     pub aggregate_function: ::core::option::Option<
         physical_aggregate_expr_node::AggregateFunction,
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs 
b/datafusion/proto/src/physical_plan/from_proto.rs
index 1c60470b22..a01b121af6 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -368,7 +368,12 @@ pub fn parse_physical_expr(
                     e.name.as_str(),
                     scalar_fun_def,
                     args,
-                    Field::new("f", convert_required!(e.return_type)?, 
true).into(),
+                    Field::new(
+                        &e.return_field_name,
+                        convert_required!(e.return_type)?,
+                        true,
+                    )
+                    .into(),
                 )
                 .with_nullable(e.nullable),
             )
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 52e0b20db2..1d34ab8d14 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1108,6 +1108,7 @@ impl protobuf::PhysicalPlanNode {
                                     AggregateExprBuilder::new(agg_udf, 
input_phy_expr)
                                         .schema(Arc::clone(&physical_schema))
                                         .alias(name)
+                                        
.human_display(agg_node.human_display.clone())
                                         
.with_ignore_nulls(agg_node.ignore_nulls)
                                         .with_distinct(agg_node.distinct)
                                         .order_by(order_bys)
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs 
b/datafusion/proto/src/physical_plan/to_proto.rs
index d22a0b5451..85ced4933a 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -17,6 +17,7 @@
 
 use std::sync::Arc;
 
+use arrow::datatypes::Schema;
 #[cfg(feature = "parquet")]
 use datafusion::datasource::file_format::parquet::ParquetSink;
 use datafusion::datasource::physical_plan::FileSink;
@@ -69,6 +70,7 @@ pub fn serialize_physical_aggr_expr(
                 distinct: aggr_expr.is_distinct(),
                 ignore_nulls: aggr_expr.ignore_nulls(),
                 fun_definition: (!buf.is_empty()).then_some(buf),
+                human_display: aggr_expr.human_display().to_string(),
             },
         )),
     })
@@ -351,6 +353,10 @@ pub fn serialize_physical_expr(
                     fun_definition: (!buf.is_empty()).then_some(buf),
                     return_type: Some(expr.return_type().try_into()?),
                     nullable: expr.nullable(),
+                    return_field_name: expr
+                        .return_field(&Schema::empty())?
+                        .name()
+                        .to_string(),
                 },
             )),
         })
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 2d27a21447..f8fa1020bc 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -141,7 +141,11 @@ fn roundtrip_test_and_return(
     let result_exec_plan: Arc<dyn ExecutionPlan> = proto
         .try_into_physical_plan(ctx, runtime.deref(), codec)
         .expect("from proto");
-    assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}"));
+
+    pretty_assertions::assert_eq!(
+        format!("{exec_plan:?}"),
+        format!("{result_exec_plan:?}")
+    );
     Ok(result_exec_plan)
 }
 
@@ -1819,10 +1823,7 @@ async fn test_serialize_deserialize_tpch_queries() -> 
Result<()> {
     Ok(())
 }
 
-// bug: https://github.com/apache/datafusion/issues/16772
-// Only 4 queries pass: q3, q5, q10, q12
-// Ignore the test until the bug is fixed
-#[ignore]
+// Bugs: https://github.com/apache/datafusion/issues/16772
 #[tokio::test]
 async fn test_round_trip_tpch_queries() -> Result<()> {
     // Create context with TPC-H tables
@@ -1839,3 +1840,69 @@ async fn test_round_trip_tpch_queries() -> Result<()> {
 
     Ok(())
 }
+
+// Bug 1 of https://github.com/apache/datafusion/issues/16772
+/// Test that AggregateFunctionExpr human_display field is correctly preserved
+/// during serialization/deserialization roundtrip.
+///
+/// Test for issue where the human_display field (used for EXPLAIN output)
+/// was not being serialized to protobuf, causing it to be lost during roundtrip
+/// and resulting in empty or incorrect display strings in query plans.
+#[tokio::test]
+async fn test_round_trip_human_display() -> Result<()> {
+    // Create context with TPC-H tables
+    let ctx = tpch_context().await?;
+
+    let sql = "select r_name, count(1) from region group by r_name";
+    roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+    let sql = "select r_name, count(*) from region group by r_name";
+    roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+    let sql = "select r_name, count(r_name) from region group by r_name";
+    roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+    Ok(())
+}
+
+// Bug 2 of https://github.com/apache/datafusion/issues/16772
+/// Test that PhysicalGroupBy groups field is correctly serialized/deserialized
+/// for simple aggregates (no GROUP BY clause).
+///
+/// Test for issue where simple aggregates like "SELECT SUM(col1 * col2) FROM table"
+/// would incorrectly serialize groups as [[]] instead of [] during roundtrip serialization.
+/// The groups field should be empty ([]) when there are no GROUP BY expressions.
+#[tokio::test]
+async fn test_round_trip_groups_display() -> Result<()> {
+    // Create context with TPC-H tables
+    let ctx = tpch_context().await?;
+
+    let sql = "select sum(l_extendedprice * l_discount) as revenue from lineitem;";
+    roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+    let sql = "select sum(l_extendedprice) as revenue from lineitem;";
+    roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+    Ok(())
+}
+
+// Bug 3 of https://github.com/apache/datafusion/issues/16772
+/// Test that ScalarFunctionExpr return_field name is correctly preserved
+/// during serialization/deserialization roundtrip.
+///
+/// Test for issue where the return_field.name for scalar functions
+/// was not being serialized to protobuf, causing it to be lost during roundtrip
+/// and defaulting to a generic name like "f" instead of the proper function name.
+#[tokio::test]
+async fn test_round_trip_date_part_display() -> Result<()> {
+    // Create context with TPC-H tables
+    let ctx = tpch_context().await?;
+
+    let sql = "select extract(year from l_shipdate) as l_year from lineitem ";
+    roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+    let sql = "select extract(month from l_shipdate) as l_year from lineitem ";
+    roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+    Ok(())
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to