This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d8bd2cb67 Implement physical plan serialization for csv COPY plans, 
add `as_any`, `Debug` to `FileFormatFactory` (#11588)
6d8bd2cb67 is described below

commit 6d8bd2cb670ec003929871b619aadc3967457ac1
Author: Lordworms <[email protected]>
AuthorDate: Tue Jul 23 10:57:41 2024 -0700

    Implement physical plan serialization for csv COPY plans, add `as_any`, 
`Debug` to `FileFormatFactory` (#11588)
    
    * Implement physical plan serialization for COPY plans 
CsvLogicalExtensionCodec
    
    * fix check
    
    * optimize code
    
    * optimize code
---
 datafusion-examples/examples/custom_file_format.rs |   6 +-
 .../core/src/datasource/file_format/arrow.rs       |   6 +-
 datafusion/core/src/datasource/file_format/avro.rs |  11 ++
 datafusion/core/src/datasource/file_format/csv.rs  |  15 +-
 datafusion/core/src/datasource/file_format/json.rs |  12 ++
 datafusion/core/src/datasource/file_format/mod.rs  |  16 ++-
 .../core/src/datasource/file_format/parquet.rs     |  11 ++
 datafusion/proto/src/logical_plan/file_formats.rs  | 154 +++++++++++++++++++--
 datafusion/proto/src/logical_plan/mod.rs           |   7 +-
 .../proto/tests/cases/roundtrip_logical_plan.rs    |  43 ++++--
 10 files changed, 251 insertions(+), 30 deletions(-)

diff --git a/datafusion-examples/examples/custom_file_format.rs 
b/datafusion-examples/examples/custom_file_format.rs
index bdb702375c..8612a1cc44 100644
--- a/datafusion-examples/examples/custom_file_format.rs
+++ b/datafusion-examples/examples/custom_file_format.rs
@@ -131,7 +131,7 @@ impl FileFormat for TSVFileFormat {
     }
 }
 
-#[derive(Default)]
+#[derive(Default, Debug)]
 /// Factory for creating TSV file formats
 ///
 /// This factory is a wrapper around the CSV file format factory
@@ -166,6 +166,10 @@ impl FileFormatFactory for TSVFileFactory {
     fn default(&self) -> std::sync::Arc<dyn FileFormat> {
         todo!()
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for TSVFileFactory {
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs 
b/datafusion/core/src/datasource/file_format/arrow.rs
index 6bcbd43476..8b6a880011 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -66,7 +66,7 @@ const INITIAL_BUFFER_BYTES: usize = 1048576;
 /// If the buffered Arrow data exceeds this size, it is flushed to object store
 const BUFFER_FLUSH_BYTES: usize = 1024000;
 
-#[derive(Default)]
+#[derive(Default, Debug)]
 /// Factory struct used to create [ArrowFormat]
 pub struct ArrowFormatFactory;
 
@@ -89,6 +89,10 @@ impl FileFormatFactory for ArrowFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(ArrowFormat)
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for ArrowFormatFactory {
diff --git a/datafusion/core/src/datasource/file_format/avro.rs 
b/datafusion/core/src/datasource/file_format/avro.rs
index f4f9adcba7..5190bdbe15 100644
--- a/datafusion/core/src/datasource/file_format/avro.rs
+++ b/datafusion/core/src/datasource/file_format/avro.rs
@@ -19,6 +19,7 @@
 
 use std::any::Any;
 use std::collections::HashMap;
+use std::fmt;
 use std::sync::Arc;
 
 use arrow::datatypes::Schema;
@@ -64,6 +65,16 @@ impl FileFormatFactory for AvroFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(AvroFormat)
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
+
+impl fmt::Debug for AvroFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("AvroFormatFactory").finish()
+    }
 }
 
 impl GetExt for AvroFormatFactory {
diff --git a/datafusion/core/src/datasource/file_format/csv.rs 
b/datafusion/core/src/datasource/file_format/csv.rs
index 958d2694aa..e1b6daac09 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -58,7 +58,8 @@ use object_store::{delimited::newline_delimited_stream, 
ObjectMeta, ObjectStore}
 #[derive(Default)]
 /// Factory struct used to create [CsvFormatFactory]
 pub struct CsvFormatFactory {
-    options: Option<CsvOptions>,
+    /// the options for csv file read
+    pub options: Option<CsvOptions>,
 }
 
 impl CsvFormatFactory {
@@ -75,6 +76,14 @@ impl CsvFormatFactory {
     }
 }
 
+impl fmt::Debug for CsvFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("CsvFormatFactory")
+            .field("options", &self.options)
+            .finish()
+    }
+}
+
 impl FileFormatFactory for CsvFormatFactory {
     fn create(
         &self,
@@ -103,6 +112,10 @@ impl FileFormatFactory for CsvFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(CsvFormat::default())
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for CsvFormatFactory {
diff --git a/datafusion/core/src/datasource/file_format/json.rs 
b/datafusion/core/src/datasource/file_format/json.rs
index 007b084f50..9de9c3d7d8 100644
--- a/datafusion/core/src/datasource/file_format/json.rs
+++ b/datafusion/core/src/datasource/file_format/json.rs
@@ -102,6 +102,10 @@ impl FileFormatFactory for JsonFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(JsonFormat::default())
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for JsonFormatFactory {
@@ -111,6 +115,14 @@ impl GetExt for JsonFormatFactory {
     }
 }
 
+impl fmt::Debug for JsonFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("JsonFormatFactory")
+            .field("options", &self.options)
+            .finish()
+    }
+}
+
 /// New line delimited JSON `FileFormat` implementation.
 #[derive(Debug, Default)]
 pub struct JsonFormat {
diff --git a/datafusion/core/src/datasource/file_format/mod.rs 
b/datafusion/core/src/datasource/file_format/mod.rs
index 1aa93a106a..500f20af47 100644
--- a/datafusion/core/src/datasource/file_format/mod.rs
+++ b/datafusion/core/src/datasource/file_format/mod.rs
@@ -49,11 +49,11 @@ use datafusion_physical_expr::{PhysicalExpr, 
PhysicalSortRequirement};
 use async_trait::async_trait;
 use file_compression_type::FileCompressionType;
 use object_store::{ObjectMeta, ObjectStore};
-
+use std::fmt::Debug;
 /// Factory for creating [`FileFormat`] instances based on session and command 
level options
 ///
 /// Users can provide their own `FileFormatFactory` to support arbitrary file 
formats
-pub trait FileFormatFactory: Sync + Send + GetExt {
+pub trait FileFormatFactory: Sync + Send + GetExt + Debug {
     /// Initialize a [FileFormat] and configure based on session and command 
level options
     fn create(
         &self,
@@ -63,6 +63,10 @@ pub trait FileFormatFactory: Sync + Send + GetExt {
 
     /// Initialize a [FileFormat] with all options set to default values
     fn default(&self) -> Arc<dyn FileFormat>;
+
+    /// Returns the table source as [`Any`] so that it can be
+    /// downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any;
 }
 
 /// This trait abstracts all the file format specific implementations
@@ -138,6 +142,7 @@ pub trait FileFormat: Send + Sync + fmt::Debug {
 /// The former trait is a superset of the latter trait, which includes 
execution time
 /// relevant methods. [FileType] is only used in logical planning and only 
implements
 /// the subset of methods required during logical planning.
+#[derive(Debug)]
 pub struct DefaultFileType {
     file_format_factory: Arc<dyn FileFormatFactory>,
 }
@@ -149,6 +154,11 @@ impl DefaultFileType {
             file_format_factory,
         }
     }
+
+    /// get a reference to the inner [FileFormatFactory] struct
+    pub fn as_format_factory(&self) -> &Arc<dyn FileFormatFactory> {
+        &self.file_format_factory
+    }
 }
 
 impl FileType for DefaultFileType {
@@ -159,7 +169,7 @@ impl FileType for DefaultFileType {
 
 impl Display for DefaultFileType {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        self.file_format_factory.default().fmt(f)
+        write!(f, "{:?}", self.file_format_factory)
     }
 }
 
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs 
b/datafusion/core/src/datasource/file_format/parquet.rs
index d4e77b911c..3250b59fa1 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -140,6 +140,10 @@ impl FileFormatFactory for ParquetFormatFactory {
     fn default(&self) -> Arc<dyn FileFormat> {
         Arc::new(ParquetFormat::default())
     }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 }
 
 impl GetExt for ParquetFormatFactory {
@@ -149,6 +153,13 @@ impl GetExt for ParquetFormatFactory {
     }
 }
 
+impl fmt::Debug for ParquetFormatFactory {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ParquetFormatFactory")
+            .field("ParquetFormatFactory", &self.options)
+            .finish()
+    }
+}
 /// The Apache Parquet `FileFormat` implementation
 #[derive(Debug, Default)]
 pub struct ParquetFormat {
diff --git a/datafusion/proto/src/logical_plan/file_formats.rs 
b/datafusion/proto/src/logical_plan/file_formats.rs
index 09e36a650b..2c4085b888 100644
--- a/datafusion/proto/src/logical_plan/file_formats.rs
+++ b/datafusion/proto/src/logical_plan/file_formats.rs
@@ -18,19 +18,129 @@
 use std::sync::Arc;
 
 use datafusion::{
+    config::CsvOptions,
     datasource::file_format::{
         arrow::ArrowFormatFactory, csv::CsvFormatFactory, 
json::JsonFormatFactory,
         parquet::ParquetFormatFactory, FileFormatFactory,
     },
     prelude::SessionContext,
 };
-use datafusion_common::{not_impl_err, TableReference};
+use datafusion_common::{
+    exec_err, not_impl_err, parsers::CompressionTypeVariant, DataFusionError,
+    TableReference,
+};
+use prost::Message;
+
+use crate::protobuf::CsvOptions as CsvOptionsProto;
 
 use super::LogicalExtensionCodec;
 
 #[derive(Debug)]
 pub struct CsvLogicalExtensionCodec;
 
+impl CsvOptionsProto {
+    fn from_factory(factory: &CsvFormatFactory) -> Self {
+        if let Some(options) = &factory.options {
+            CsvOptionsProto {
+                has_header: options.has_header.map_or(vec![], |v| vec![v as 
u8]),
+                delimiter: vec![options.delimiter],
+                quote: vec![options.quote],
+                escape: options.escape.map_or(vec![], |v| vec![v]),
+                double_quote: options.double_quote.map_or(vec![], |v| vec![v 
as u8]),
+                compression: options.compression as i32,
+                schema_infer_max_rec: options.schema_infer_max_rec as u64,
+                date_format: options.date_format.clone().unwrap_or_default(),
+                datetime_format: 
options.datetime_format.clone().unwrap_or_default(),
+                timestamp_format: 
options.timestamp_format.clone().unwrap_or_default(),
+                timestamp_tz_format: options
+                    .timestamp_tz_format
+                    .clone()
+                    .unwrap_or_default(),
+                time_format: options.time_format.clone().unwrap_or_default(),
+                null_value: options.null_value.clone().unwrap_or_default(),
+                comment: options.comment.map_or(vec![], |v| vec![v]),
+                newlines_in_values: options
+                    .newlines_in_values
+                    .map_or(vec![], |v| vec![v as u8]),
+            }
+        } else {
+            CsvOptionsProto::default()
+        }
+    }
+}
+
+impl From<&CsvOptionsProto> for CsvOptions {
+    fn from(proto: &CsvOptionsProto) -> Self {
+        CsvOptions {
+            has_header: if !proto.has_header.is_empty() {
+                Some(proto.has_header[0] != 0)
+            } else {
+                None
+            },
+            delimiter: proto.delimiter.first().copied().unwrap_or(b','),
+            quote: proto.quote.first().copied().unwrap_or(b'"'),
+            escape: if !proto.escape.is_empty() {
+                Some(proto.escape[0])
+            } else {
+                None
+            },
+            double_quote: if !proto.double_quote.is_empty() {
+                Some(proto.double_quote[0] != 0)
+            } else {
+                None
+            },
+            compression: match proto.compression {
+                0 => CompressionTypeVariant::GZIP,
+                1 => CompressionTypeVariant::BZIP2,
+                2 => CompressionTypeVariant::XZ,
+                3 => CompressionTypeVariant::ZSTD,
+                _ => CompressionTypeVariant::UNCOMPRESSED,
+            },
+            schema_infer_max_rec: proto.schema_infer_max_rec as usize,
+            date_format: if proto.date_format.is_empty() {
+                None
+            } else {
+                Some(proto.date_format.clone())
+            },
+            datetime_format: if proto.datetime_format.is_empty() {
+                None
+            } else {
+                Some(proto.datetime_format.clone())
+            },
+            timestamp_format: if proto.timestamp_format.is_empty() {
+                None
+            } else {
+                Some(proto.timestamp_format.clone())
+            },
+            timestamp_tz_format: if proto.timestamp_tz_format.is_empty() {
+                None
+            } else {
+                Some(proto.timestamp_tz_format.clone())
+            },
+            time_format: if proto.time_format.is_empty() {
+                None
+            } else {
+                Some(proto.time_format.clone())
+            },
+            null_value: if proto.null_value.is_empty() {
+                None
+            } else {
+                Some(proto.null_value.clone())
+            },
+            comment: if !proto.comment.is_empty() {
+                Some(proto.comment[0])
+            } else {
+                None
+            },
+            newlines_in_values: if proto.newlines_in_values.is_empty() {
+                None
+            } else {
+                Some(proto.newlines_in_values[0] != 0)
+            },
+        }
+    }
+}
+
 // TODO! This is a placeholder for now and needs to be implemented for real.
 impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
     fn try_decode(
@@ -73,17 +183,41 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
 
     fn try_decode_file_format(
         &self,
-        __buf: &[u8],
-        __ctx: &SessionContext,
+        buf: &[u8],
+        _ctx: &SessionContext,
     ) -> datafusion_common::Result<Arc<dyn FileFormatFactory>> {
-        Ok(Arc::new(CsvFormatFactory::new()))
+        let proto = CsvOptionsProto::decode(buf).map_err(|e| {
+            DataFusionError::Execution(format!(
+                "Failed to decode CsvOptionsProto: {:?}",
+                e
+            ))
+        })?;
+        let options: CsvOptions = (&proto).into();
+        Ok(Arc::new(CsvFormatFactory {
+            options: Some(options),
+        }))
     }
 
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
-        __node: Arc<dyn FileFormatFactory>,
+        buf: &mut Vec<u8>,
+        node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
+        let options =
+            if let Some(csv_factory) = 
node.as_any().downcast_ref::<CsvFormatFactory>() {
+                csv_factory.options.clone().unwrap_or_default()
+            } else {
+                return exec_err!("{}", "Unsupported FileFormatFactory 
type".to_string());
+            };
+
+        let proto = CsvOptionsProto::from_factory(&CsvFormatFactory {
+            options: Some(options),
+        });
+
+        proto.encode(buf).map_err(|e| {
+            DataFusionError::Execution(format!("Failed to encode CsvOptions: 
{:?}", e))
+        })?;
+
         Ok(())
     }
 }
@@ -141,7 +275,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
 
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
@@ -201,7 +335,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec 
{
 
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
@@ -261,7 +395,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
 
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
@@ -321,7 +455,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
 
     fn try_encode_file_format(
         &self,
-        __buf: &[u8],
+        __buf: &mut Vec<u8>,
         __node: Arc<dyn FileFormatFactory>,
     ) -> datafusion_common::Result<()> {
         Ok(())
diff --git a/datafusion/proto/src/logical_plan/mod.rs 
b/datafusion/proto/src/logical_plan/mod.rs
index 2a963fb13c..5427f34e8e 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -131,7 +131,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
 
     fn try_encode_file_format(
         &self,
-        _buf: &[u8],
+        _buf: &mut Vec<u8>,
         _node: Arc<dyn FileFormatFactory>,
     ) -> Result<()> {
         Ok(())
@@ -1666,10 +1666,9 @@ impl AsLogicalPlan for LogicalPlanNode {
                     input,
                     extension_codec,
                 )?;
-
-                let buf = Vec::new();
+                let mut buf = Vec::new();
                 extension_codec
-                    .try_encode_file_format(&buf, 
file_type_to_format(file_type)?)?;
+                    .try_encode_file_format(&mut buf, 
file_type_to_format(file_type)?)?;
 
                 Ok(protobuf::LogicalPlanNode {
                     logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new(
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index f6557c7b2d..e17515086e 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -15,12 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
-use std::collections::HashMap;
-use std::fmt::{self, Debug, Formatter};
-use std::sync::Arc;
-use std::vec;
-
 use arrow::array::{
     ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, 
StringBuilder,
 };
@@ -30,11 +24,16 @@ use arrow::datatypes::{
     DECIMAL256_MAX_PRECISION,
 };
 use prost::Message;
+use std::any::Any;
+use std::collections::HashMap;
+use std::fmt::{self, Debug, Formatter};
+use std::sync::Arc;
+use std::vec;
 
 use datafusion::datasource::file_format::arrow::ArrowFormatFactory;
 use datafusion::datasource::file_format::csv::CsvFormatFactory;
-use datafusion::datasource::file_format::format_as_file_type;
 use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
+use datafusion::datasource::file_format::{format_as_file_type, 
DefaultFileType};
 use datafusion::datasource::provider::TableProviderFactory;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::session_state::SessionStateBuilder;
@@ -380,7 +379,9 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> 
Result<()> {
     parquet_format.global.dictionary_page_size_limit = 444;
     parquet_format.global.max_row_group_size = 555;
 
-    let file_type = format_as_file_type(Arc::new(ParquetFormatFactory::new()));
+    let file_type = format_as_file_type(Arc::new(
+        ParquetFormatFactory::new_with_options(parquet_format),
+    ));
 
     let plan = LogicalPlan::Copy(CopyTo {
         input: Arc::new(input),
@@ -395,7 +396,6 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> 
Result<()> {
     let logical_round_trip =
         logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?;
     assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));
-
     match logical_round_trip {
         LogicalPlan::Copy(copy_to) => {
             assert_eq!("test.parquet", copy_to.output_url);
@@ -458,7 +458,9 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> 
{
     csv_format.time_format = Some("HH:mm:ss".to_string());
     csv_format.null_value = Some("NIL".to_string());
 
-    let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new()));
+    let file_type = 
format_as_file_type(Arc::new(CsvFormatFactory::new_with_options(
+        csv_format.clone(),
+    )));
 
     let plan = LogicalPlan::Copy(CopyTo {
         input: Arc::new(input),
@@ -479,6 +481,27 @@ async fn roundtrip_logical_plan_copy_to_csv() -> 
Result<()> {
             assert_eq!("test.csv", copy_to.output_url);
             assert_eq!("csv".to_string(), copy_to.file_type.get_ext());
             assert_eq!(vec!["a", "b", "c"], copy_to.partition_by);
+
+            let file_type = copy_to
+                .file_type
+                .as_ref()
+                .as_any()
+                .downcast_ref::<DefaultFileType>()
+                .unwrap();
+
+            let format_factory = file_type.as_format_factory();
+            let csv_factory = format_factory
+                .as_ref()
+                .as_any()
+                .downcast_ref::<CsvFormatFactory>()
+                .unwrap();
+            let csv_config = csv_factory.options.as_ref().unwrap();
+            assert_eq!(csv_format.delimiter, csv_config.delimiter);
+            assert_eq!(csv_format.date_format, csv_config.date_format);
+            assert_eq!(csv_format.datetime_format, csv_config.datetime_format);
+            assert_eq!(csv_format.timestamp_format, 
csv_config.timestamp_format);
+            assert_eq!(csv_format.time_format, csv_config.time_format);
+            assert_eq!(csv_format.null_value, csv_config.null_value)
         }
         _ => panic!(),
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to