This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 32d49fb219 feat: support type cast in SchemaAdapter (#6404)
32d49fb219 is described below
commit 32d49fb219de58f76f0a53da0ab6ddbdb980d6b9
Author: elijah <[email protected]>
AuthorDate: Tue May 23 18:34:02 2023 +0800
feat: support type cast in SchemaAdapter (#6404)
* feat: support type cast in SchemaAdapter
* make ci happy
* improve the code
* make ci happy
---
.../core/src/physical_plan/file_format/mod.rs | 191 +++++++++++++++++++++
1 file changed, 191 insertions(+)
diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs
b/datafusion/core/src/physical_plan/file_format/mod.rs
index 1cdea092df..7e719e2ee9 100644
--- a/datafusion/core/src/physical_plan/file_format/mod.rs
+++ b/datafusion/core/src/physical_plan/file_format/mod.rs
@@ -53,6 +53,7 @@ use crate::{
scalar::ScalarValue,
};
use arrow::array::new_null_array;
+use arrow::compute::can_cast_types;
use arrow::record_batch::RecordBatchOptions;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_physical_expr::expressions::Column;
@@ -450,6 +451,77 @@ impl SchemaAdapter {
&options,
)?)
}
+
+ /// Creates a `SchemaMapping` that can be used to cast or map the columns
from the file schema to the table schema.
+ ///
+ /// If the provided `file_schema` contains columns of a different type to
the expected
+ /// `table_schema`, the method will attempt to cast the array data from
the file schema
+ /// to the table schema where possible.
+ #[allow(dead_code)]
+ pub fn map_schema(&self, file_schema: &Schema) -> Result<SchemaMapping> {
+ let mut field_mappings = Vec::new();
+
+ for (idx, field) in self.table_schema.fields().iter().enumerate() {
+ match file_schema.field_with_name(field.name()) {
+ Ok(file_field) => {
+ if can_cast_types(file_field.data_type(),
field.data_type()) {
+ field_mappings.push((idx, field.data_type().clone()))
+ } else {
+ return Err(DataFusionError::Plan(format!(
+ "Cannot cast file schema field {} of type {:?} to
table schema field of type {:?}",
+ field.name(),
+ file_field.data_type(),
+ field.data_type()
+ )));
+ }
+ }
+ Err(_) => {
+ return Err(DataFusionError::Plan(format!(
+ "File schema does not contain expected field {}",
+ field.name()
+ )));
+ }
+ }
+ }
+ Ok(SchemaMapping {
+ table_schema: self.table_schema.clone(),
+ field_mappings,
+ })
+ }
+}
+
+/// The SchemaMapping struct holds a mapping from the file schema to the table
schema
+/// and any necessary type conversions that need to be applied.
+#[derive(Debug)]
+pub struct SchemaMapping {
+ #[allow(dead_code)]
+ table_schema: SchemaRef,
+ #[allow(dead_code)]
+ field_mappings: Vec<(usize, DataType)>,
+}
+
+impl SchemaMapping {
+ /// Adapts a `RecordBatch` to match the `table_schema` using the stored
mapping and conversions.
+ #[allow(dead_code)]
+ fn map_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
+ let mut mapped_cols = Vec::with_capacity(self.field_mappings.len());
+
+ for (idx, data_type) in &self.field_mappings {
+ let array = batch.column(*idx);
+ let casted_array = arrow::compute::cast(array, data_type)?;
+ mapped_cols.push(casted_array);
+ }
+
+ // Necessary to handle empty batches
+ let options =
RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+
+ let record_batch = RecordBatch::try_new_with_options(
+ self.table_schema.clone(),
+ mapped_cols,
+ &options,
+ )?;
+ Ok(record_batch)
+ }
}
/// A helper that projects partition columns into the file record batches.
@@ -805,6 +877,9 @@ fn get_projected_output_ordering(
#[cfg(test)]
mod tests {
+ use arrow_array::cast::AsArray;
+ use arrow_array::types::{Float64Type, UInt32Type};
+ use arrow_array::{Float32Array, StringArray, UInt64Array};
use chrono::Utc;
use crate::{
@@ -1124,6 +1199,122 @@ mod tests {
assert!(mapped.is_err());
}
+ #[test]
+ fn schema_adapter_map_schema() {
+ let table_schema = Arc::new(Schema::new(vec![
+ Field::new("c1", DataType::Utf8, true),
+ Field::new("c2", DataType::UInt64, true),
+ Field::new("c3", DataType::Float64, true),
+ ]));
+
+ let adapter = SchemaAdapter::new(table_schema.clone());
+
+ // file schema matches table schema
+ let file_schema = Schema::new(vec![
+ Field::new("c1", DataType::Utf8, true),
+ Field::new("c2", DataType::UInt64, true),
+ Field::new("c3", DataType::Float64, true),
+ ]);
+
+ let mapping = adapter.map_schema(&file_schema).unwrap();
+
+ assert_eq!(
+ mapping.field_mappings,
+ vec![
+ (0, DataType::Utf8),
+ (1, DataType::UInt64),
+ (2, DataType::Float64),
+ ]
+ );
+ assert_eq!(mapping.table_schema, table_schema);
+
+ // file schema has columns of a different but castable type
+ let file_schema = Schema::new(vec![
+ Field::new("c1", DataType::Utf8, true),
+            Field::new("c2", DataType::Int32, true), // can be cast to UInt64
+            Field::new("c3", DataType::Float32, true), // can be cast to
Float64
+ ]);
+
+ let mapping = adapter.map_schema(&file_schema).unwrap();
+
+ assert_eq!(
+ mapping.field_mappings,
+ vec![
+ (0, DataType::Utf8),
+ (1, DataType::UInt64),
+ (2, DataType::Float64),
+ ]
+ );
+ assert_eq!(mapping.table_schema, table_schema);
+
+ // file schema lacks necessary columns
+ let file_schema = Schema::new(vec![
+ Field::new("c1", DataType::Utf8, true),
+ Field::new("c2", DataType::Int32, true),
+ ]);
+
+ let err = adapter.map_schema(&file_schema).unwrap_err();
+
+ assert!(err
+ .to_string()
+ .contains("File schema does not contain expected field"));
+
+ // file schema has columns of a different and non-castable type
+ let file_schema = Schema::new(vec![
+ Field::new("c1", DataType::Utf8, true),
+ Field::new("c2", DataType::Int32, true),
+            Field::new("c3", DataType::Date64, true), // cannot be cast to
Float64
+ ]);
+ let err = adapter.map_schema(&file_schema).unwrap_err();
+
+ assert!(err.to_string().contains("Cannot cast file schema field"));
+ }
+
+ #[test]
+ fn schema_mapping_map_batch() {
+ let table_schema = Arc::new(Schema::new(vec![
+ Field::new("c1", DataType::Utf8, true),
+ Field::new("c2", DataType::UInt32, true),
+ Field::new("c3", DataType::Float64, true),
+ ]));
+
+ let adapter = SchemaAdapter::new(table_schema.clone());
+
+ let file_schema = Schema::new(vec![
+ Field::new("c1", DataType::Utf8, true),
+ Field::new("c2", DataType::UInt64, true),
+ Field::new("c3", DataType::Float32, true),
+ ]);
+
+ let mapping = adapter.map_schema(&file_schema).expect("map schema
failed");
+
+ let c1 = StringArray::from(vec!["hello", "world"]);
+ let c2 = UInt64Array::from(vec![9_u64, 5_u64]);
+ let c3 = Float32Array::from(vec![2.0_f32, 7.0_f32]);
+ let batch = RecordBatch::try_new(
+ Arc::new(file_schema),
+ vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)],
+ )
+ .unwrap();
+
+ let mapped_batch = mapping.map_batch(batch).unwrap();
+
+ assert_eq!(mapped_batch.schema(), table_schema);
+ assert_eq!(mapped_batch.num_columns(), 3);
+ assert_eq!(mapped_batch.num_rows(), 2);
+
+ let c1 = mapped_batch.column(0).as_string::<i32>();
+ let c2 = mapped_batch.column(1).as_primitive::<UInt32Type>();
+ let c3 = mapped_batch.column(2).as_primitive::<Float64Type>();
+
+ assert_eq!(c1.value(0), "hello");
+ assert_eq!(c1.value(1), "world");
+ assert_eq!(c2.value(0), 9_u32);
+ assert_eq!(c2.value(1), 5_u32);
+ assert_eq!(c3.value(0), 2.0_f64);
+ assert_eq!(c3.value(1), 7.0_f64);
+ }
+
// sets default for configs that play no role in projections
fn config_for_projection(
file_schema: SchemaRef,