This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 32d49fb219 feat: support type cast in SchemaAdapter (#6404)
32d49fb219 is described below

commit 32d49fb219de58f76f0a53da0ab6ddbdb980d6b9
Author: elijah <[email protected]>
AuthorDate: Tue May 23 18:34:02 2023 +0800

    feat: support type cast in SchemaAdapter (#6404)
    
    * feat: support type cast in SchemaAdapter
    
    * make ci happy
    
    * improve the code
    
    * make ci happy
---
 .../core/src/physical_plan/file_format/mod.rs      | 191 +++++++++++++++++++++
 1 file changed, 191 insertions(+)

diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs
index 1cdea092df..7e719e2ee9 100644
--- a/datafusion/core/src/physical_plan/file_format/mod.rs
+++ b/datafusion/core/src/physical_plan/file_format/mod.rs
@@ -53,6 +53,7 @@ use crate::{
     scalar::ScalarValue,
 };
 use arrow::array::new_null_array;
+use arrow::compute::can_cast_types;
 use arrow::record_batch::RecordBatchOptions;
 use datafusion_common::tree_node::{TreeNode, VisitRecursion};
 use datafusion_physical_expr::expressions::Column;
@@ -450,6 +451,77 @@ impl SchemaAdapter {
             &options,
         )?)
     }
+
+    /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema.
+    ///
+    /// If the provided `file_schema` contains columns of a different type to the expected
+    /// `table_schema`, the method will attempt to cast the array data from the file schema
+    /// to the table schema where possible.
+    #[allow(dead_code)]
+    pub fn map_schema(&self, file_schema: &Schema) -> Result<SchemaMapping> {
+        let mut field_mappings = Vec::new();
+
+        for (idx, field) in self.table_schema.fields().iter().enumerate() {
+            match file_schema.field_with_name(field.name()) {
+                Ok(file_field) => {
+                    if can_cast_types(file_field.data_type(), field.data_type()) {
+                        field_mappings.push((idx, field.data_type().clone()))
+                    } else {
+                        return Err(DataFusionError::Plan(format!(
+                            "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}",
+                            field.name(),
+                            file_field.data_type(),
+                            field.data_type()
+                        )));
+                    }
+                }
+                Err(_) => {
+                    return Err(DataFusionError::Plan(format!(
+                        "File schema does not contain expected field {}",
+                        field.name()
+                    )));
+                }
+            }
+        }
+        Ok(SchemaMapping {
+            table_schema: self.table_schema.clone(),
+            field_mappings,
+        })
+    }
+}
+
+/// The SchemaMapping struct holds a mapping from the file schema to the table schema
+/// and any necessary type conversions that need to be applied.
+#[derive(Debug)]
+pub struct SchemaMapping {
+    #[allow(dead_code)]
+    table_schema: SchemaRef,
+    #[allow(dead_code)]
+    field_mappings: Vec<(usize, DataType)>,
+}
+
+impl SchemaMapping {
+    /// Adapts a `RecordBatch` to match the `table_schema` using the stored mapping and conversions.
+    #[allow(dead_code)]
+    fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
+        let mut mapped_cols = Vec::with_capacity(self.field_mappings.len());
+
+        for (idx, data_type) in &self.field_mappings {
+            let array = batch.column(*idx);
+            let casted_array = arrow::compute::cast(array, data_type)?;
+            mapped_cols.push(casted_array);
+        }
+
+        // Necessary to handle empty batches
+        let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+
+        let record_batch = RecordBatch::try_new_with_options(
+            self.table_schema.clone(),
+            mapped_cols,
+            &options,
+        )?;
+        Ok(record_batch)
+    }
 }
 
 /// A helper that projects partition columns into the file record batches.
@@ -805,6 +877,9 @@ fn get_projected_output_ordering(
 
 #[cfg(test)]
 mod tests {
+    use arrow_array::cast::AsArray;
+    use arrow_array::types::{Float64Type, UInt32Type};
+    use arrow_array::{Float32Array, StringArray, UInt64Array};
     use chrono::Utc;
 
     use crate::{
@@ -1124,6 +1199,122 @@ mod tests {
         assert!(mapped.is_err());
     }
 
+    #[test]
+    fn schema_adapter_map_schema() {
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt64, true),
+            Field::new("c3", DataType::Float64, true),
+        ]));
+
+        let adapter = SchemaAdapter::new(table_schema.clone());
+
+        // file schema matches table schema
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt64, true),
+            Field::new("c3", DataType::Float64, true),
+        ]);
+
+        let mapping = adapter.map_schema(&file_schema).unwrap();
+
+        assert_eq!(
+            mapping.field_mappings,
+            vec![
+                (0, DataType::Utf8),
+                (1, DataType::UInt64),
+                (2, DataType::Float64),
+            ]
+        );
+        assert_eq!(mapping.table_schema, table_schema);
+
+        // file schema has columns of a different but castable type
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::Int32, true), // can be casted to UInt64
+            Field::new("c3", DataType::Float32, true), // can be casted to Float64
+        ]);
+
+        let mapping = adapter.map_schema(&file_schema).unwrap();
+
+        assert_eq!(
+            mapping.field_mappings,
+            vec![
+                (0, DataType::Utf8),
+                (1, DataType::UInt64),
+                (2, DataType::Float64),
+            ]
+        );
+        assert_eq!(mapping.table_schema, table_schema);
+
+        // file schema lacks necessary columns
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::Int32, true),
+        ]);
+
+        let err = adapter.map_schema(&file_schema).unwrap_err();
+
+        assert!(err
+            .to_string()
+            .contains("File schema does not contain expected field"));
+
+        // file schema has columns of a different and non-castable type
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::Int32, true),
+            Field::new("c3", DataType::Date64, true), // cannot be casted to Float64
+        ]);
+        let err = adapter.map_schema(&file_schema).unwrap_err();
+
+        assert!(err.to_string().contains("Cannot cast file schema field"));
+    }
+
+    #[test]
+    fn schema_mapping_map_batch() {
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt32, true),
+            Field::new("c3", DataType::Float64, true),
+        ]));
+
+        let adapter = SchemaAdapter::new(table_schema.clone());
+
+        let file_schema = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::UInt64, true),
+            Field::new("c3", DataType::Float32, true),
+        ]);
+
+        let mapping = adapter.map_schema(&file_schema).expect("map schema failed");
+
+        let c1 = StringArray::from(vec!["hello", "world"]);
+        let c2 = UInt64Array::from(vec![9_u64, 5_u64]);
+        let c3 = Float32Array::from(vec![2.0_f32, 7.0_f32]);
+        let batch = RecordBatch::try_new(
+            Arc::new(file_schema),
+            vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)],
+        )
+        .unwrap();
+
+        let mapped_batch = mapping.map_batch(batch).unwrap();
+
+        assert_eq!(mapped_batch.schema(), table_schema);
+        assert_eq!(mapped_batch.num_columns(), 3);
+        assert_eq!(mapped_batch.num_rows(), 2);
+
+        let c1 = mapped_batch.column(0).as_string::<i32>();
+        let c2 = mapped_batch.column(1).as_primitive::<UInt32Type>();
+        let c3 = mapped_batch.column(2).as_primitive::<Float64Type>();
+
+        assert_eq!(c1.value(0), "hello");
+        assert_eq!(c1.value(1), "world");
+        assert_eq!(c2.value(0), 9_u32);
+        assert_eq!(c2.value(1), 5_u32);
+        assert_eq!(c3.value(0), 2.0_f64);
+        assert_eq!(c3.value(1), 7.0_f64);
+    }
+
     // sets default for configs that play no role in projections
     fn config_for_projection(
         file_schema: SchemaRef,

Reply via email to