This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a6d4798c45 Fixes 3 bugs during serialization and deserialization of physical plans (#16858)
a6d4798c45 is described below
commit a6d4798c45a4e6ca95049a2fcdf61b9490aa8e17
Author: Nga Tran <[email protected]>
AuthorDate: Fri Jul 25 10:32:20 2025 -0400
Fixes 3 bugs during serialization and deserialization of physical plans (#16858)
---
Cargo.lock | 23 +++++++
datafusion/core/src/physical_planner.rs | 3 +
datafusion/ffi/src/udaf/accumulator_args.rs | 1 +
datafusion/physical-plan/src/aggregates/mod.rs | 9 ++-
datafusion/proto/Cargo.toml | 1 +
datafusion/proto/proto/datafusion.proto | 2 +
datafusion/proto/src/generated/pbjson.rs | 36 ++++++++++
datafusion/proto/src/generated/prost.rs | 4 ++
datafusion/proto/src/physical_plan/from_proto.rs | 7 +-
datafusion/proto/src/physical_plan/mod.rs | 1 +
datafusion/proto/src/physical_plan/to_proto.rs | 6 ++
.../proto/tests/cases/roundtrip_physical_plan.rs | 77 ++++++++++++++++++++--
12 files changed, 163 insertions(+), 7 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 45a8333a1e..7fcda2957e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2542,6 +2542,7 @@ dependencies = [
"doc-comment",
"object_store",
"pbjson",
+ "pretty_assertions",
"prost",
"serde",
"serde_json",
@@ -2727,6 +2728,12 @@ dependencies = [
"serde",
]
+[[package]]
+name = "diff"
+version = "0.1.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
+
[[package]]
name = "difflib"
version = "0.4.0"
@@ -4816,6 +4823,16 @@ dependencies = [
"termtree",
]
+[[package]]
+name = "pretty_assertions"
+version = "1.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
+dependencies = [
+ "diff",
+ "yansi",
+]
+
[[package]]
name = "prettyplease"
version = "0.2.32"
@@ -7492,6 +7509,12 @@ dependencies = [
"lzma-sys",
]
+[[package]]
+name = "yansi"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
+
[[package]]
name = "yoke"
version = "0.8.0"
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 5a0ee327cb..e1f4154324 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1358,6 +1358,9 @@ impl DefaultPhysicalPlanner {
physical_name(expr),
))?])),
}
+ } else if group_expr.is_empty() {
+ // No GROUP BY clause - create empty PhysicalGroupBy
+ Ok(PhysicalGroupBy::new(vec![], vec![], vec![]))
} else {
Ok(PhysicalGroupBy::new_single(
group_expr
diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs
b/datafusion/ffi/src/udaf/accumulator_args.rs
index 874a2ac8b8..2cd2fa5f51 100644
--- a/datafusion/ffi/src/udaf/accumulator_args.rs
+++ b/datafusion/ffi/src/udaf/accumulator_args.rs
@@ -75,6 +75,7 @@ impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
ignore_nulls: args.ignore_nulls,
fun_definition: None,
aggregate_function: None,
+ human_display: args.name.to_string(),
};
let physical_expr_def = physical_expr_def.encode_to_vec().into();
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index 66d721fab0..784b7db893 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -332,10 +332,17 @@ impl PhysicalGroupBy {
)
.collect();
let num_exprs = expr.len();
+ let groups = if self.expr.is_empty() {
+ // No GROUP BY expressions - should have no groups
+ vec![]
+ } else {
+ // Has GROUP BY expressions - create a single group
+ vec![vec![false; num_exprs]]
+ };
Self {
expr,
null_expr: vec![],
- groups: vec![vec![false; num_exprs]],
+ groups,
}
}
}
diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml
index a1eeabdf87..c95f392a05 100644
--- a/datafusion/proto/Cargo.toml
+++ b/datafusion/proto/Cargo.toml
@@ -60,4 +60,5 @@ datafusion-functions = { workspace = true, default-features =
true }
datafusion-functions-aggregate = { workspace = true }
datafusion-functions-window-common = { workspace = true }
doc-comment = { workspace = true }
+pretty_assertions = "1.4"
tokio = { workspace = true, features = ["rt-multi-thread"] }
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 64789f5de0..85a565c0b2 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -859,6 +859,7 @@ message PhysicalScalarUdfNode {
optional bytes fun_definition = 3;
datafusion_common.ArrowType return_type = 4;
bool nullable = 5;
+ string return_field_name = 6;
}
message PhysicalAggregateExprNode {
@@ -870,6 +871,7 @@ message PhysicalAggregateExprNode {
bool distinct = 3;
bool ignore_nulls = 6;
optional bytes fun_definition = 7;
+ string human_display = 8;
}
message PhysicalWindowExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 92309ea6a5..83f66ec22c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -13619,6 +13619,9 @@ impl serde::Serialize for PhysicalAggregateExprNode {
if self.fun_definition.is_some() {
len += 1;
}
+ if !self.human_display.is_empty() {
+ len += 1;
+ }
if self.aggregate_function.is_some() {
len += 1;
}
@@ -13640,6 +13643,9 @@ impl serde::Serialize for PhysicalAggregateExprNode {
#[allow(clippy::needless_borrows_for_generic_args)]
struct_ser.serialize_field("funDefinition",
pbjson::private::base64::encode(&v).as_str())?;
}
+ if !self.human_display.is_empty() {
+ struct_ser.serialize_field("humanDisplay", &self.human_display)?;
+ }
if let Some(v) = self.aggregate_function.as_ref() {
match v {
physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => {
@@ -13665,6 +13671,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
"ignoreNulls",
"fun_definition",
"funDefinition",
+ "human_display",
+ "humanDisplay",
"user_defined_aggr_function",
"userDefinedAggrFunction",
];
@@ -13676,6 +13684,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
Distinct,
IgnoreNulls,
FunDefinition,
+ HumanDisplay,
UserDefinedAggrFunction,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
@@ -13703,6 +13712,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
"distinct" => Ok(GeneratedField::Distinct),
"ignoreNulls" | "ignore_nulls" =>
Ok(GeneratedField::IgnoreNulls),
"funDefinition" | "fun_definition" =>
Ok(GeneratedField::FunDefinition),
+ "humanDisplay" | "human_display" =>
Ok(GeneratedField::HumanDisplay),
"userDefinedAggrFunction" |
"user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
@@ -13728,6 +13738,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
let mut distinct__ = None;
let mut ignore_nulls__ = None;
let mut fun_definition__ = None;
+ let mut human_display__ = None;
let mut aggregate_function__ = None;
while let Some(k) = map_.next_key()? {
match k {
@@ -13763,6 +13774,12 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
x.0)
;
}
+ GeneratedField::HumanDisplay => {
+ if human_display__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("humanDisplay"));
+ }
+ human_display__ = Some(map_.next_value()?);
+ }
GeneratedField::UserDefinedAggrFunction => {
if aggregate_function__.is_some() {
return
Err(serde::de::Error::duplicate_field("userDefinedAggrFunction"));
@@ -13777,6 +13794,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
distinct: distinct__.unwrap_or_default(),
ignore_nulls: ignore_nulls__.unwrap_or_default(),
fun_definition: fun_definition__,
+ human_display: human_display__.unwrap_or_default(),
aggregate_function: aggregate_function__,
})
}
@@ -16312,6 +16330,9 @@ impl serde::Serialize for PhysicalScalarUdfNode {
if self.nullable {
len += 1;
}
+ if !self.return_field_name.is_empty() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.PhysicalScalarUdfNode", len)?;
if !self.name.is_empty() {
struct_ser.serialize_field("name", &self.name)?;
@@ -16330,6 +16351,9 @@ impl serde::Serialize for PhysicalScalarUdfNode {
if self.nullable {
struct_ser.serialize_field("nullable", &self.nullable)?;
}
+ if !self.return_field_name.is_empty() {
+ struct_ser.serialize_field("returnFieldName",
&self.return_field_name)?;
+ }
struct_ser.end()
}
}
@@ -16347,6 +16371,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalScalarUdfNode {
"return_type",
"returnType",
"nullable",
+ "return_field_name",
+ "returnFieldName",
];
#[allow(clippy::enum_variant_names)]
@@ -16356,6 +16382,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalScalarUdfNode {
FunDefinition,
ReturnType,
Nullable,
+ ReturnFieldName,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -16382,6 +16409,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalScalarUdfNode {
"funDefinition" | "fun_definition" =>
Ok(GeneratedField::FunDefinition),
"returnType" | "return_type" =>
Ok(GeneratedField::ReturnType),
"nullable" => Ok(GeneratedField::Nullable),
+ "returnFieldName" | "return_field_name" =>
Ok(GeneratedField::ReturnFieldName),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -16406,6 +16434,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalScalarUdfNode {
let mut fun_definition__ = None;
let mut return_type__ = None;
let mut nullable__ = None;
+ let mut return_field_name__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Name => {
@@ -16440,6 +16469,12 @@ impl<'de> serde::Deserialize<'de> for
PhysicalScalarUdfNode {
}
nullable__ = Some(map_.next_value()?);
}
+ GeneratedField::ReturnFieldName => {
+ if return_field_name__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("returnFieldName"));
+ }
+ return_field_name__ = Some(map_.next_value()?);
+ }
}
}
Ok(PhysicalScalarUdfNode {
@@ -16448,6 +16483,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalScalarUdfNode {
fun_definition: fun_definition__,
return_type: return_type__,
nullable: nullable__.unwrap_or_default(),
+ return_field_name: return_field_name__.unwrap_or_default(),
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index b0fc0ce604..cbf6b3b2d4 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1305,6 +1305,8 @@ pub struct PhysicalScalarUdfNode {
pub return_type:
::core::option::Option<super::datafusion_common::ArrowType>,
#[prost(bool, tag = "5")]
pub nullable: bool,
+ #[prost(string, tag = "6")]
+ pub return_field_name: ::prost::alloc::string::String,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PhysicalAggregateExprNode {
@@ -1318,6 +1320,8 @@ pub struct PhysicalAggregateExprNode {
pub ignore_nulls: bool,
#[prost(bytes = "vec", optional, tag = "7")]
pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
+ #[prost(string, tag = "8")]
+ pub human_display: ::prost::alloc::string::String,
#[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags =
"4")]
pub aggregate_function: ::core::option::Option<
physical_aggregate_expr_node::AggregateFunction,
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index 1c60470b22..a01b121af6 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -368,7 +368,12 @@ pub fn parse_physical_expr(
e.name.as_str(),
scalar_fun_def,
args,
- Field::new("f", convert_required!(e.return_type)?,
true).into(),
+ Field::new(
+ &e.return_field_name,
+ convert_required!(e.return_type)?,
+ true,
+ )
+ .into(),
)
.with_nullable(e.nullable),
)
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 52e0b20db2..1d34ab8d14 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1108,6 +1108,7 @@ impl protobuf::PhysicalPlanNode {
AggregateExprBuilder::new(agg_udf,
input_phy_expr)
.schema(Arc::clone(&physical_schema))
.alias(name)
+
.human_display(agg_node.human_display.clone())
.with_ignore_nulls(agg_node.ignore_nulls)
.with_distinct(agg_node.distinct)
.order_by(order_bys)
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index d22a0b5451..85ced4933a 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -17,6 +17,7 @@
use std::sync::Arc;
+use arrow::datatypes::Schema;
#[cfg(feature = "parquet")]
use datafusion::datasource::file_format::parquet::ParquetSink;
use datafusion::datasource::physical_plan::FileSink;
@@ -69,6 +70,7 @@ pub fn serialize_physical_aggr_expr(
distinct: aggr_expr.is_distinct(),
ignore_nulls: aggr_expr.ignore_nulls(),
fun_definition: (!buf.is_empty()).then_some(buf),
+ human_display: aggr_expr.human_display().to_string(),
},
)),
})
@@ -351,6 +353,10 @@ pub fn serialize_physical_expr(
fun_definition: (!buf.is_empty()).then_some(buf),
return_type: Some(expr.return_type().try_into()?),
nullable: expr.nullable(),
+ return_field_name: expr
+ .return_field(&Schema::empty())?
+ .name()
+ .to_string(),
},
)),
})
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 2d27a21447..f8fa1020bc 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -141,7 +141,11 @@ fn roundtrip_test_and_return(
let result_exec_plan: Arc<dyn ExecutionPlan> = proto
.try_into_physical_plan(ctx, runtime.deref(), codec)
.expect("from proto");
- assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}"));
+
+ pretty_assertions::assert_eq!(
+ format!("{exec_plan:?}"),
+ format!("{result_exec_plan:?}")
+ );
Ok(result_exec_plan)
}
@@ -1819,10 +1823,7 @@ async fn test_serialize_deserialize_tpch_queries() ->
Result<()> {
Ok(())
}
-// bug: https://github.com/apache/datafusion/issues/16772
-// Only 4 queries pass: q3, q5, q10, q12
-// Ignore the test until the bug is fixed
-#[ignore]
+// Bugs: https://github.com/apache/datafusion/issues/16772
#[tokio::test]
async fn test_round_trip_tpch_queries() -> Result<()> {
// Create context with TPC-H tables
@@ -1839,3 +1840,69 @@ async fn test_round_trip_tpch_queries() -> Result<()> {
Ok(())
}
+
+// Bug 1 of https://github.com/apache/datafusion/issues/16772
+/// Test that AggregateFunctionExpr human_display field is correctly preserved
+/// during serialization/deserialization roundtrip.
+///
+/// Test for issue where the human_display field (used for EXPLAIN output)
+/// was not being serialized to protobuf, causing it to be lost during roundtrip
+/// and resulting in empty or incorrect display strings in query plans.
+#[tokio::test]
+async fn test_round_trip_human_display() -> Result<()> {
+ // Create context with TPC-H tables
+ let ctx = tpch_context().await?;
+
+ let sql = "select r_name, count(1) from region group by r_name";
+ roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+ let sql = "select r_name, count(*) from region group by r_name";
+ roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+ let sql = "select r_name, count(r_name) from region group by r_name";
+ roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+ Ok(())
+}
+
+// Bug 2 of https://github.com/apache/datafusion/issues/16772
+/// Test that PhysicalGroupBy groups field is correctly serialized/deserialized
+/// for simple aggregates (no GROUP BY clause).
+///
+/// Test for issue where simple aggregates like "SELECT SUM(col1 * col2) FROM table"
+/// would incorrectly serialize groups as [[]] instead of [] during roundtrip serialization.
+/// The groups field should be empty ([]) when there are no GROUP BY expressions.
+#[tokio::test]
+async fn test_round_trip_groups_display() -> Result<()> {
+ // Create context with TPC-H tables
+ let ctx = tpch_context().await?;
+
+ let sql = "select sum(l_extendedprice * l_discount) as revenue from
lineitem;";
+ roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+ let sql = "select sum(l_extendedprice) as revenue from lineitem;";
+ roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+ Ok(())
+}
+
+// Bug 3 of https://github.com/apache/datafusion/issues/16772
+/// Test that ScalarFunctionExpr return_field name is correctly preserved
+/// during serialization/deserialization roundtrip.
+///
+/// Test for issue where the return_field.name for scalar functions
+/// was not being serialized to protobuf, causing it to be lost during roundtrip
+/// and defaulting to a generic name like "f" instead of the proper function name.
+#[tokio::test]
+async fn test_round_trip_date_part_display() -> Result<()> {
+ // Create context with TPC-H tables
+ let ctx = tpch_context().await?;
+
+ let sql = "select extract(year from l_shipdate) as l_year from lineitem ";
+ roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+ let sql = "select extract(month from l_shipdate) as l_year from lineitem ";
+ roundtrip_test_sql_with_context(sql, &ctx).await?;
+
+ Ok(())
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]