This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 366e6100f fix: check recursion limit in `Expr::to_bytes` (#3970)
366e6100f is described below

commit 366e6100f217893e96151032686cd015c20f9e23
Author: Marco Neumann <[email protected]>
AuthorDate: Fri Oct 28 14:18:17 2022 +0000

    fix: check recursion limit in `Expr::to_bytes` (#3970)
    
    Install a DF-specific workaround until
    https://github.com/tokio-rs/prost/issues/736 is implemented.
    
    Fixes #3968.
---
 datafusion/proto/src/bytes/mod.rs | 78 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 75 insertions(+), 3 deletions(-)

diff --git a/datafusion/proto/src/bytes/mod.rs 
b/datafusion/proto/src/bytes/mod.rs
index 7c8b94e5b..3677ea8af 100644
--- a/datafusion/proto/src/bytes/mod.rs
+++ b/datafusion/proto/src/bytes/mod.rs
@@ -20,8 +20,11 @@ use crate::logical_plan::{AsLogicalPlan, 
LogicalExtensionCodec};
 use crate::{from_proto::parse_expr, protobuf};
 use arrow::datatypes::SchemaRef;
 use datafusion::datasource::TableProvider;
+use datafusion::physical_plan::functions::make_scalar_function;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Expr, Extension, LogicalPlan};
+use datafusion_expr::{
+    create_udaf, create_udf, Expr, Extension, LogicalPlan, Volatility,
+};
 use prost::{
     bytes::{Bytes, BytesMut},
     Message,
@@ -83,7 +86,45 @@ impl Serializeable for Expr {
             DataFusionError::Plan(format!("Error encoding protobuf as bytes: 
{}", e))
         })?;
 
-        Ok(buffer.into())
+        let bytes: Bytes = buffer.into();
+
+        // the produced byte stream may lead to "recursion limit" errors, see
+        // https://github.com/apache/arrow-datafusion/issues/3968
+        // Until the underlying prost issue ( 
https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to
+        // deserialize the data here and check for errors.
+        //
+        // Need to provide some placeholder registry because the stream may 
contain UDFs
+        struct PlaceHolderRegistry;
+
+        impl FunctionRegistry for PlaceHolderRegistry {
+            fn udfs(&self) -> std::collections::HashSet<String> {
+                std::collections::HashSet::default()
+            }
+
+            fn udf(&self, name: &str) -> 
Result<Arc<datafusion_expr::ScalarUDF>> {
+                Ok(Arc::new(create_udf(
+                    name,
+                    vec![],
+                    Arc::new(arrow::datatypes::DataType::Null),
+                    Volatility::Immutable,
+                    make_scalar_function(|_| unimplemented!()),
+                )))
+            }
+
+            fn udaf(&self, name: &str) -> 
Result<Arc<datafusion_expr::AggregateUDF>> {
+                Ok(Arc::new(create_udaf(
+                    name,
+                    arrow::datatypes::DataType::Null,
+                    Arc::new(arrow::datatypes::DataType::Null),
+                    Volatility::Immutable,
+                    Arc::new(|_| unimplemented!()),
+                    Arc::new(vec![]),
+                )))
+            }
+        }
+        Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?;
+
+        Ok(bytes)
     }
 
     fn from_bytes_with_registry(
@@ -212,7 +253,7 @@ mod test {
     use arrow::{array::ArrayRef, datatypes::DataType};
     use datafusion::physical_plan::functions::make_scalar_function;
     use datafusion::prelude::SessionContext;
-    use datafusion_expr::{create_udf, lit, Volatility};
+    use datafusion_expr::{col, create_udf, lit, Volatility};
     use std::sync::Arc;
 
     #[test]
@@ -280,6 +321,37 @@ mod test {
         Expr::from_bytes(&bytes).unwrap();
     }
 
+    #[test]
+    fn roundtrip_deeply_nested() {
+        // we need more stack space so this doesn't overflow in dev builds
+        std::thread::Builder::new().stack_size(10_000_000).spawn(|| {
+            // don't know what "too much" is, so let's slowly try to increase 
complexity
+            let n_max = 100;
+
+            for n in 1..n_max {
+                println!("testing: {n}");
+
+                let expr_base = col("a").lt(lit(5i32));
+                let expr = (0..n).fold(expr_base.clone(), |expr, _| 
expr.and(expr_base.clone()));
+
+                // Convert it to an opaque form
+                let bytes = match expr.to_bytes() {
+                    Ok(bytes) => bytes,
+                    Err(_) => {
+                        // found expression that is too deeply nested
+                        return;
+                    }
+                };
+
+                // Decode bytes from somewhere (over network, etc.
+                let decoded_expr = 
Expr::from_bytes(&bytes).expect("serialization worked, so deserialization 
should work as well");
+                assert_eq!(expr, decoded_expr);
+            }
+
+            panic!("did not find a 'too deeply nested' expression, tested up 
to a depth of {n_max}")
+        }).expect("spawning thread").join().expect("joining thread");
+    }
+
     /// return a `SessionContext` with a `dummy` function registered as a UDF
     fn context_with_udf() -> SessionContext {
         let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as 
ArrayRef);

Reply via email to the mailing list.