thinkharderdev commented on code in PR #9436:
URL: https://github.com/apache/arrow-datafusion/pull/9436#discussion_r1529944851
##########
datafusion/physical-expr/src/scalar_function.rs:
##########
@@ -171,8 +171,17 @@ impl PhysicalExpr for ScalarFunctionExpr {
.collect::<Result<Vec<_>>>()?,
};
+ let fun_implementation = match self.fun {
+ ScalarFunctionDefinition::BuiltIn(ref fun) => create_physical_fun(fun)?,
+ ScalarFunctionDefinition::UDF(ref fun) => fun.fun(),
+ ScalarFunctionDefinition::Name(_) => {
+ return internal_err!(
+ "Name function must be resolved to one of the other
variants prior to physical planning"
+ );
+ }
+ };
// evaluate the function
- let fun = self.fun.as_ref();
+ let fun = fun_implementation.as_ref();
(fun)(&inputs)
Review Comment:
nit: In the case of `UDF` we can skip lifting into an `Arc` function pointer
here:
```
match self.fun {
    ScalarFunctionDefinition::BuiltIn(ref fun) => {
        let fun = create_physical_fun(fun)?;
        (fun)(&inputs)
    }
    ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs),
    ScalarFunctionDefinition::Name(_) => {
        internal_err!(
            "Name function must be resolved to one of the other variants prior to physical planning"
        )
    }
}
```
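For context, the difference between the two UDF paths, using only the calls that
already appear in the hunk and the suggestion above (a sketch; `udf` stands in
for the binding from `ScalarFunctionDefinition::UDF(ref fun)` and `inputs` for
the evaluated arguments):
```
// Current diff: materialize the Arc'd function pointer first, then call it.
let lifted = udf.fun();
let result = (lifted.as_ref())(&inputs);

// Suggested: dispatch straight to the UDF implementation, no intermediate
// Arc'd closure needed.
let result = udf.invoke(&inputs);
```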
##########
datafusion/proto/src/physical_plan/to_proto.rs:
##########
@@ -374,195 +381,215 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
Ok(AggrFn { inner, distinct })
}
-impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::PhysicalExprNode {
- type Error = DataFusionError;
-
- fn try_from(value: Arc<dyn PhysicalExpr>) -> Result<Self, Self::Error> {
- let expr = value.as_any();
+pub fn serialize_expr(
Review Comment:
nit: suggest we make the naming more explicit here and add a doc comment
```suggestion
/// Serialize a `PhysicalExpr` to default protobuf representation.
///
/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle
/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`])
pub fn serialize_physical_expr(
```
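For reference, a caller-side sketch of what that doc comment describes; the
exact parameter list of `serialize_physical_expr` is an assumption based on
this diff, and `ScalarUDFExtensionCodec` is the test codec from
`roundtrip_physical_plan.rs` further down:
```
// Hypothetical usage sketch (signature assumed, not the exact API in this PR):
// serialize a physical expression with a custom PhysicalExtensionCodec so UDFs
// that need specialized serialization go through `try_encode_udf`.
let codec = ScalarUDFExtensionCodec {};
let proto: protobuf::PhysicalExprNode = serialize_physical_expr(expr.clone(), &codec)?;
// PhysicalExprNode is a prost message, so it can be written out as bytes.
let bytes = proto.encode_to_vec();
```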
##########
datafusion/physical-expr/src/functions.rs:
##########
@@ -57,7 +58,7 @@ pub fn create_physical_expr(
fun: &BuiltinScalarFunction,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
- execution_props: &ExecutionProps,
+ _execution_props: &ExecutionProps,
Review Comment:
weird, it seems like this param was never needed
##########
datafusion/proto/tests/cases/roundtrip_physical_plan.rs:
##########
@@ -665,6 +670,133 @@ fn roundtrip_scalar_udf() -> Result<()> {
roundtrip_test_with_context(Arc::new(project), ctx)
}
+#[test]
+fn roundtrip_scalar_udf_extension_codec() {
+ #[derive(Debug)]
+ struct MyRegexUdf {
+ signature: Signature,
+ // regex as original string
+ pattern: String,
+ }
+
+ impl MyRegexUdf {
+ fn new(pattern: String) -> Self {
+ Self {
+ signature: Signature::uniform(
+ 1,
+ vec![DataType::Int32],
+ Volatility::Immutable,
+ ),
+ pattern,
+ }
+ }
+ }
+
+ /// Implement the ScalarUDFImpl trait for MyRegexUdf
+ impl ScalarUDFImpl for MyRegexUdf {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "regex_udf"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+ if !matches!(args.first(), Some(&DataType::Utf8)) {
+ return plan_err!("regex_udf only accepts Utf8 arguments");
+ }
+ Ok(DataType::Int32)
+ }
+ fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ unimplemented!()
+ }
+ }
+
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct MyRegexUdfNode {
+ #[prost(string, tag = "1")]
+ pub pattern: String,
+ }
+
+ #[derive(Debug)]
+ pub struct ScalarUDFExtensionCodec {}
+
+ impl PhysicalExtensionCodec for ScalarUDFExtensionCodec {
+ fn try_decode(
+ &self,
+ _buf: &[u8],
+ _inputs: &[Arc<dyn ExecutionPlan>],
+ _registry: &dyn FunctionRegistry,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ not_impl_err!("No extension codec provided")
+ }
+
+ fn try_encode(
+ &self,
+ _node: Arc<dyn ExecutionPlan>,
+ _buf: &mut Vec<u8>,
+ ) -> Result<()> {
+ not_impl_err!("No extension codec provided")
+ }
+
+ fn try_decode_udf(&self, name: &str, buf: &[u8]) ->
Result<Arc<ScalarUDF>> {
+ if name == "regex_udf" {
+ let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
+ DataFusionError::Internal(format!(
+ "failed to decode regex_udf: {}",
+ err
+ ))
+ })?;
+
+ Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new(
+ proto.pattern,
+ ))))
+ } else {
+ not_impl_err!("unrecognized scalar UDF implementation, cannot
decode")
+ }
+ }
+
+ fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) ->
Result<()> {
+ let binding = node.inner();
+ let udf = binding.as_any().downcast_ref::<MyRegexUdf>().unwrap();
+ let proto = MyRegexUdfNode {
+ pattern: udf.pattern.clone(),
+ };
+ proto.encode(buf).map_err(|e| {
+ DataFusionError::Internal(format!("failed to encode udf:
{e:?}"))
+ })?;
Review Comment:
nit: this works for the test but as an example I think it gives a misleading
impression of how to implement this. We don't want to fail for UDFs which
don't need special handling.
```suggestion
if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>()
{
let proto = MyRegexUdfNode {
pattern: udf.pattern.clone(),
};
proto.encode(buf).map_err(|e| {
DataFusionError::Internal(format!("failed to encode udf:
{e:?}"))
})?;
}
```
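Put together with that suggestion applied, the whole method would read roughly
as below; this is a sketch, and the trailing `Ok(())` is assumed from the
surrounding test code that the hunk above doesn't show:
```
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()> {
    let binding = node.inner();
    // Only MyRegexUdf needs specialized serialization in this test; any other
    // UDF falls through with an empty buffer instead of returning an error.
    if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>() {
        let proto = MyRegexUdfNode {
            pattern: udf.pattern.clone(),
        };
        proto.encode(buf).map_err(|e| {
            DataFusionError::Internal(format!("failed to encode udf: {e:?}"))
        })?;
    }
    Ok(())
}
```
That way a plan can mix UDFs that need custom serialization with ones that
don't, and the same codec handles both.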