LiaCastaneda commented on code in PR #22362:
URL: https://github.com/apache/datafusion/pull/22362#discussion_r3280332102
##########
datafusion/proto/src/logical_plan/to_proto.rs:
##########
@@ -634,11 +635,35 @@ pub fn serialize_expr(
.unwrap_or(HashMap::new()),
})),
},
- Expr::HigherOrderFunction(_) | Expr::Lambda(_) |
Expr::LambdaVariable(_) => {
- return Err(Error::General(
- "Proto serialization error: Lambda not
implemented".to_string(),
- ));
+ Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
Review Comment:
Super nit: instead of appending this at the end, can we move it next to the
other function variants (Scalar, Aggregate, etc.)?
##########
datafusion/proto/proto/datafusion.proto:
##########
@@ -430,6 +430,10 @@ message LogicalExprNode {
// Subquery expressions
ScalarSubqueryExprNode scalar_subquery_expr = 36;
+
+ HigherOrderUDFExprNode higher_order_udf_expr = 37;
+ Lambda lambda = 38;
+ LambdaVariable lambda_variable = 39;
Review Comment:
Super nit as well: can we move these new fields next to the other UDF expression nodes?
##########
datafusion/proto/src/physical_plan/to_proto.rs:
##########
@@ -575,6 +577,41 @@ pub fn serialize_physical_expr_with_converter(
}),
)),
})
+ } else if let Some(expr) = expr.downcast_ref::<HigherOrderFunctionExpr>() {
+ let mut buf = Vec::new();
+ codec.try_encode_higher_order_function(expr.fun(), &mut buf)?;
+ Ok(protobuf::PhysicalExprNode {
+ expr_id,
+ expr_type:
Some(protobuf::physical_expr_node::ExprType::HigherOrderUdf(
+ protobuf::PhysicalHigherOrderUdfNode {
+ name: expr.name().to_string(),
+ args: serialize_physical_exprs(expr.args(), codec,
proto_converter)?,
+ fun_definition: (!buf.is_empty()).then_some(buf),
+ },
+ )),
+ })
+ } else if let Some(lambda) = expr.downcast_ref::<LambdaExpr>() {
+ Ok(protobuf::PhysicalExprNode {
+ expr_id,
+ expr_type:
Some(protobuf::physical_expr_node::ExprType::Lambda(Box::new(
+ protobuf::PhysicalLambdaExprNode {
+ params: lambda.params().to_vec(),
+ body: Some(Box::new(
+ proto_converter.physical_expr_to_proto(lambda.body(),
codec)?,
+ )),
+ },
+ ))),
+ })
+ } else if let Some(var) = expr.downcast_ref::<LambdaVariable>() {
+ Ok(protobuf::PhysicalExprNode {
+ expr_id,
+ expr_type:
Some(protobuf::physical_expr_node::ExprType::LambdaVariable(
+ PhysicalLambdaVariableExprNode {
+ index: var.index() as u32,
Review Comment:
For now we don't use `index`, so it is always 0, right?
##########
datafusion/proto/tests/cases/serialize.rs:
##########
@@ -290,11 +294,101 @@ fn test_expression_serialization_roundtrip() {
assert_eq!(serialize_name, deserialize_name);
}
+}
+
+/// Extracts the first part of a function name
+/// 'foo(bar)' -> 'foo'
+fn extract_function_name(expr: &Expr) -> String {
+ let name = expr.schema_name().to_string();
+ name.split('(').next().unwrap().to_string()
+}
+
+/// return a `SessionContext` with `MyHigherOrderUDF` registered as a
higher-order UDF
+fn context_with_higher_order_function() -> SessionContext {
+ let mut ctx = SessionContext::new();
+ let hof = Arc::new(MyHigherOrderUDF::new("payload".to_string()));
+ ctx.register_higher_order_function(hof).unwrap();
+ ctx
+}
+
+fn dummy_higher_order_function_call(hof: Arc<dyn HigherOrderUDF>) -> Expr {
+ let list = ScalarValue::List(ScalarValue::new_list_nullable(
+ &[ScalarValue::Int32(Some(1))],
+ &DataType::Int32,
+ ));
+ let lambda_var_with_field = Expr::LambdaVariable(LambdaVariable::new(
+ "x".to_string(),
+ Some(Arc::new(Field::new("x", DataType::Int32, true))),
+ ));
+ let lambda_var_without_field =
+ Expr::LambdaVariable(LambdaVariable::new("x".into(), None));
+ let lambda = lambda(["x"], lambda_var_with_field +
lambda_var_without_field);
+ Expr::HigherOrderFunction(HigherOrderFunction::new(
+ hof,
+ vec![Expr::Literal(list, None), lambda],
+ ))
+}
- /// Extracts the first part of a function name
- /// 'foo(bar)' -> 'foo'
- fn extract_function_name(expr: &Expr) -> String {
- let name = expr.schema_name().to_string();
- name.split('(').next().unwrap().to_string()
+#[test]
+fn hof_roundtrip_with_registry() {
+ let ctx = context_with_higher_order_function();
+ let hof = ctx
+ .higher_order_function("higher_order_udf")
+ .expect("could not find higher order udf");
+
+ let expr = dummy_higher_order_function_call(hof);
+
+ let bytes = expr.to_bytes().unwrap();
+ let deserialized_expr =
+ Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref()).unwrap();
+
+ assert_eq!(expr, deserialized_expr);
+}
+
+#[test]
+#[should_panic(
+ expected = "LogicalExtensionCodec is not provided for higher order
function higher_order_udf"
+)]
+fn hof_roundtrip_without_registry() {
+ let ctx = context_with_higher_order_function();
+ let hof = ctx
+ .higher_order_function("higher_order_udf")
+ .expect("could not find higher order udf");
+
+ let expr = dummy_higher_order_function_call(hof);
+
+ let bytes = expr.to_bytes().unwrap();
+ Expr::from_bytes(&bytes).unwrap();
+}
+
+#[test]
+fn test_higher_order_serialization_roundtrip() {
+ let ctx = SessionContext::new();
+ let list = ScalarValue::List(ScalarValue::new_list_nullable(
+ &[ScalarValue::Int32(Some(1))],
+ &DataType::Int32,
+ ));
+ let lambda_var_with_field = Expr::LambdaVariable(LambdaVariable::new(
+ "x".to_string(),
+ Some(Arc::new(Field::new("x", DataType::Int32, true))),
+ ));
+ let lambda_var_without_field =
+ Expr::LambdaVariable(LambdaVariable::new("x".into(), None));
+ let lambda = lambda(["x"], lambda_var_with_field +
lambda_var_without_field);
+ let args = vec![Expr::Literal(list, None), lambda];
+
+ for function in
datafusion::functions_nested::all_default_higher_order_functions() {
+ let expr =
+ Expr::HigherOrderFunction(HigherOrderFunction::new(function,
args.clone()));
+
+ let extension_codec = DefaultLogicalExtensionCodec {};
+ let proto = serialize_expr(&expr, &extension_codec).unwrap();
+ let deserialize =
+ parse_expr(&proto, ctx.task_ctx().as_ref(),
&extension_codec).unwrap();
+
+ let serialize_name = extract_function_name(&expr);
+ let deserialize_name = extract_function_name(&deserialize);
+
+ assert_eq!(serialize_name, deserialize_name);
Review Comment:
Should we assert directly on `expr` and `deserialize` instead? Since the
HOFs implement `PartialEq`, I think it should work:
`assert_eq!(expr, deserialize);`
##########
datafusion/proto/src/physical_plan/to_proto.rs:
##########
@@ -575,6 +577,41 @@ pub fn serialize_physical_expr_with_converter(
}),
)),
})
+ } else if let Some(expr) = expr.downcast_ref::<HigherOrderFunctionExpr>() {
+ let mut buf = Vec::new();
+ codec.try_encode_higher_order_function(expr.fun(), &mut buf)?;
+ Ok(protobuf::PhysicalExprNode {
+ expr_id,
+ expr_type:
Some(protobuf::physical_expr_node::ExprType::HigherOrderUdf(
+ protobuf::PhysicalHigherOrderUdfNode {
+ name: expr.name().to_string(),
+ args: serialize_physical_exprs(expr.args(), codec,
proto_converter)?,
+ fun_definition: (!buf.is_empty()).then_some(buf),
+ },
Review Comment:
Shouldn't we encode the `return_field` instead of resolving it from the schema
when decoding (as we do for `ScalarUDF`)?
I'm not sure what the exact disadvantage of keeping it as is would be,
but I wonder whether it would be cleaner to make the return type explicit.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]