gstvg commented on code in PR #22362:
URL: https://github.com/apache/datafusion/pull/22362#discussion_r3291606705
##########
datafusion/proto/src/physical_plan/to_proto.rs:
##########
@@ -575,6 +577,41 @@ pub fn serialize_physical_expr_with_converter(
}),
)),
})
+ } else if let Some(expr) = expr.downcast_ref::<HigherOrderFunctionExpr>() {
+ let mut buf = Vec::new();
+ codec.try_encode_higher_order_function(expr.fun(), &mut buf)?;
+ Ok(protobuf::PhysicalExprNode {
+ expr_id,
+ expr_type:
Some(protobuf::physical_expr_node::ExprType::HigherOrderUdf(
+ protobuf::PhysicalHigherOrderUdfNode {
+ name: expr.name().to_string(),
+ args: serialize_physical_exprs(expr.args(), codec,
proto_converter)?,
+ fun_definition: (!buf.is_empty()).then_some(buf),
+ },
Review Comment:
I removed `HigherOrderFunctionExpr::new`, the only constructor that
receives `return_field` as an argument, from the lambda PR because of a
possible type mismatch, so that it can be discussed in its own PR when needed. See
https://github.com/apache/datafusion/pull/21679#pullrequestreview-4135675354
and
> FYI, the return field might not match the function output after the new
arguments, but you don't have the schema here so you cant check that and I see
ScalarFunctionExpr have the same problem
from https://github.com/apache/datafusion/pull/21679/changes#r3109735988
(quoting directly since the GitHub link anchor doesn't work well)
##########
datafusion/proto/tests/cases/serialize.rs:
##########
@@ -290,11 +294,101 @@ fn test_expression_serialization_roundtrip() {
assert_eq!(serialize_name, deserialize_name);
}
+}
+
+/// Extracts the first part of a function name
+/// 'foo(bar)' -> 'foo'
+fn extract_function_name(expr: &Expr) -> String {
+ let name = expr.schema_name().to_string();
+ name.split('(').next().unwrap().to_string()
+}
+
+/// return a `SessionContext` with `MyHigherOrderUDF` registered as a
higher-order UDF
+fn context_with_higher_order_function() -> SessionContext {
+ let mut ctx = SessionContext::new();
+ let hof = Arc::new(MyHigherOrderUDF::new("payload".to_string()));
+ ctx.register_higher_order_function(hof).unwrap();
+ ctx
+}
+
+fn dummy_higher_order_function_call(hof: Arc<dyn HigherOrderUDF>) -> Expr {
+ let list = ScalarValue::List(ScalarValue::new_list_nullable(
+ &[ScalarValue::Int32(Some(1))],
+ &DataType::Int32,
+ ));
+ let lambda_var_with_field = Expr::LambdaVariable(LambdaVariable::new(
+ "x".to_string(),
+ Some(Arc::new(Field::new("x", DataType::Int32, true))),
+ ));
+ let lambda_var_without_field =
+ Expr::LambdaVariable(LambdaVariable::new("x".into(), None));
+ let lambda = lambda(["x"], lambda_var_with_field +
lambda_var_without_field);
+ Expr::HigherOrderFunction(HigherOrderFunction::new(
+ hof,
+ vec![Expr::Literal(list, None), lambda],
+ ))
+}
- /// Extracts the first part of a function name
- /// 'foo(bar)' -> 'foo'
- fn extract_function_name(expr: &Expr) -> String {
- let name = expr.schema_name().to_string();
- name.split('(').next().unwrap().to_string()
+#[test]
+fn hof_roundtrip_with_registry() {
+ let ctx = context_with_higher_order_function();
+ let hof = ctx
+ .higher_order_function("higher_order_udf")
+ .expect("could not find higher order udf");
+
+ let expr = dummy_higher_order_function_call(hof);
+
+ let bytes = expr.to_bytes().unwrap();
+ let deserialized_expr =
+ Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref()).unwrap();
+
+ assert_eq!(expr, deserialized_expr);
+}
+
+#[test]
+#[should_panic(
+ expected = "LogicalExtensionCodec is not provided for higher order
function higher_order_udf"
+)]
+fn hof_roundtrip_without_registry() {
+ let ctx = context_with_higher_order_function();
+ let hof = ctx
+ .higher_order_function("higher_order_udf")
+ .expect("could not find higher order udf");
+
+ let expr = dummy_higher_order_function_call(hof);
+
+ let bytes = expr.to_bytes().unwrap();
+ Expr::from_bytes(&bytes).unwrap();
+}
+
+#[test]
+fn test_higher_order_serialization_roundtrip() {
+ let ctx = SessionContext::new();
+ let list = ScalarValue::List(ScalarValue::new_list_nullable(
+ &[ScalarValue::Int32(Some(1))],
+ &DataType::Int32,
+ ));
+ let lambda_var_with_field = Expr::LambdaVariable(LambdaVariable::new(
+ "x".to_string(),
+ Some(Arc::new(Field::new("x", DataType::Int32, true))),
+ ));
+ let lambda_var_without_field =
+ Expr::LambdaVariable(LambdaVariable::new("x".into(), None));
+ let lambda = lambda(["x"], lambda_var_with_field +
lambda_var_without_field);
+ let args = vec![Expr::Literal(list, None), lambda];
+
+ for function in
datafusion::functions_nested::all_default_higher_order_functions() {
+ let expr =
+ Expr::HigherOrderFunction(HigherOrderFunction::new(function,
args.clone()));
+
+ let extension_codec = DefaultLogicalExtensionCodec {};
+ let proto = serialize_expr(&expr, &extension_codec).unwrap();
+ let deserialize =
+ parse_expr(&proto, ctx.task_ctx().as_ref(),
&extension_codec).unwrap();
+
+ let serialize_name = extract_function_name(&expr);
+ let deserialize_name = extract_function_name(&deserialize);
+
+ assert_eq!(serialize_name, deserialize_name);
Review Comment:
Yes,
https://github.com/apache/datafusion/pull/22362/changes/9a3f30c28c61ec23b1061c95d6dc18b70610cf25
thanks
This was based on `test_expression_serialization_roundtrip`, which checks
only names, but asserting directly on `expr` is definitely better
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]