This is an automated email from the ASF dual-hosted git repository.
liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git
The following commit(s) were added to refs/heads/main by this push:
new 0a06c3e24 refactor(arrow,datafusion): Reuse PartitionValueCalculator
in RecordBatchPartitionSplitter (#1781)
0a06c3e24 is described below
commit 0a06c3e241a297998ebdcdf7c20f22d03c4e23f2
Author: Shawn Chang <[email protected]>
AuthorDate: Tue Oct 28 02:28:38 2025 -0700
refactor(arrow,datafusion): Reuse PartitionValueCalculator in
RecordBatchPartitionSplitter (#1781)
## Which issue does this PR close?
- Closes #1786
- Covers some of the changes from the previous draft: #1769
## What changes are included in this PR?
- Move PartitionValueCalculator from the DataFusion integration into the core
iceberg crate (crates/iceberg/src/arrow) so it can be reused by
RecordBatchPartitionSplitter
- Allow the partition splitter to skip partition value calculation for
projected batches that already carry a pre-computed `_partition` column
- Return <PartitionKey, RecordBatch> rather than <Struct, RecordBatch>
pairs in RecordBatchPartitionSplitter::split
## Are these changes tested?
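Yes. Added unit tests for PartitionValueCalculator and for splitting batches
with a pre-computed `_partition` column.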
---
crates/iceberg/src/arrow/mod.rs | 7 +-
.../src/arrow/partition_value_calculator.rs | 254 ++++++++++++++
.../src/arrow/record_batch_partition_splitter.rs | 377 +++++++++++++++------
.../datafusion/src/physical_plan/project.rs | 190 +++--------
4 files changed, 580 insertions(+), 248 deletions(-)
diff --git a/crates/iceberg/src/arrow/mod.rs b/crates/iceberg/src/arrow/mod.rs
index 28116a4b5..c091c4517 100644
--- a/crates/iceberg/src/arrow/mod.rs
+++ b/crates/iceberg/src/arrow/mod.rs
@@ -35,4 +35,9 @@ mod value;
pub use reader::*;
pub use value::*;
-pub(crate) mod record_batch_partition_splitter;
+/// Partition value calculator for computing partition values
+pub mod partition_value_calculator;
+pub use partition_value_calculator::*;
+/// Record batch partition splitter for partitioned tables
+pub mod record_batch_partition_splitter;
+pub use record_batch_partition_splitter::*;
diff --git a/crates/iceberg/src/arrow/partition_value_calculator.rs
b/crates/iceberg/src/arrow/partition_value_calculator.rs
new file mode 100644
index 000000000..140950345
--- /dev/null
+++ b/crates/iceberg/src/arrow/partition_value_calculator.rs
@@ -0,0 +1,254 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Partition value calculation for Iceberg tables.
+//!
+//! This module provides utilities for calculating partition values from
record batches
+//! based on a partition specification.
+
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, RecordBatch, StructArray};
+use arrow_schema::DataType;
+
+use super::record_batch_projector::RecordBatchProjector;
+use super::type_to_arrow_type;
+use crate::spec::{PartitionSpec, Schema, StructType, Type};
+use crate::transform::{BoxedTransformFunction, create_transform_function};
+use crate::{Error, ErrorKind, Result};
+
+/// Calculator for partition values in Iceberg tables.
+///
+/// This struct handles the projection of source columns and application of
+/// partition transforms to compute partition values for a given record batch.
+#[derive(Debug)]
+pub struct PartitionValueCalculator {
+ projector: RecordBatchProjector,
+ transform_functions: Vec<BoxedTransformFunction>,
+ partition_type: StructType,
+ partition_arrow_type: DataType,
+}
+
+impl PartitionValueCalculator {
+ /// Create a new PartitionValueCalculator.
+ ///
+ /// # Arguments
+ ///
+ /// * `partition_spec` - The partition specification
+ /// * `table_schema` - The Iceberg table schema
+ ///
+ /// # Returns
+ ///
+ /// Returns a new `PartitionValueCalculator` instance or an error if
initialization fails.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if:
+ /// - The partition spec is unpartitioned
+ /// - Transform function creation fails
+ /// - Projector initialization fails
+ pub fn try_new(partition_spec: &PartitionSpec, table_schema: &Schema) ->
Result<Self> {
+ if partition_spec.is_unpartitioned() {
+ return Err(Error::new(
+ ErrorKind::DataInvalid,
+ "Cannot create partition calculator for unpartitioned table",
+ ));
+ }
+
+ // Create transform functions for each partition field
+ let transform_functions: Vec<BoxedTransformFunction> = partition_spec
+ .fields()
+ .iter()
+ .map(|pf| create_transform_function(&pf.transform))
+ .collect::<Result<Vec<_>>>()?;
+
+ // Extract source field IDs for projection
+ let source_field_ids: Vec<i32> = partition_spec
+ .fields()
+ .iter()
+ .map(|pf| pf.source_id)
+ .collect();
+
+ // Create projector for extracting source columns
+ let projector = RecordBatchProjector::from_iceberg_schema(
+ Arc::new(table_schema.clone()),
+ &source_field_ids,
+ )?;
+
+ // Get partition type information
+ let partition_type = partition_spec.partition_type(table_schema)?;
+ let partition_arrow_type =
type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
+
+ Ok(Self {
+ projector,
+ transform_functions,
+ partition_type,
+ partition_arrow_type,
+ })
+ }
+
+ /// Get the partition type as an Iceberg StructType.
+ pub fn partition_type(&self) -> &StructType {
+ &self.partition_type
+ }
+
+ /// Get the partition type as an Arrow DataType.
+ pub fn partition_arrow_type(&self) -> &DataType {
+ &self.partition_arrow_type
+ }
+
+ /// Calculate partition values for a record batch.
+ ///
+ /// This method:
+ /// 1. Projects the source columns from the batch
+ /// 2. Applies partition transforms to each source column
+ /// 3. Constructs a StructArray containing the partition values
+ ///
+ /// # Arguments
+ ///
+ /// * `batch` - The record batch to calculate partition values for
+ ///
+ /// # Returns
+ ///
+ /// Returns an ArrayRef containing a StructArray of partition values, or
an error if calculation fails.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if:
+ /// - Column projection fails
+ /// - Transform application fails
+ /// - StructArray construction fails
+ pub fn calculate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+ // Project source columns from the batch
+ let source_columns = self.projector.project_column(batch.columns())?;
+
+ // Get expected struct fields for the result
+ let expected_struct_fields = match &self.partition_arrow_type {
+ DataType::Struct(fields) => fields.clone(),
+ _ => {
+ return Err(Error::new(
+ ErrorKind::DataInvalid,
+ "Expected partition type must be a struct",
+ ));
+ }
+ };
+
+ // Apply transforms to each source column
+ let mut partition_values =
Vec::with_capacity(self.transform_functions.len());
+ for (source_column, transform_fn) in
source_columns.iter().zip(&self.transform_functions) {
+ let partition_value =
transform_fn.transform(source_column.clone())?;
+ partition_values.push(partition_value);
+ }
+
+ // Construct the StructArray
+ let struct_array = StructArray::try_new(expected_struct_fields,
partition_values, None)
+ .map_err(|e| {
+ Error::new(
+ ErrorKind::DataInvalid,
+ format!("Failed to create partition struct array: {}", e),
+ )
+ })?;
+
+ Ok(Arc::new(struct_array))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Arc;
+
+ use arrow_array::{Int32Array, RecordBatch, StringArray};
+ use arrow_schema::{Field, Schema as ArrowSchema};
+
+ use super::*;
+ use crate::spec::{NestedField, PartitionSpecBuilder, PrimitiveType,
Transform};
+
+ #[test]
+ fn test_partition_calculator_identity_transform() {
+ let table_schema = Schema::builder()
+ .with_schema_id(0)
+ .with_fields(vec![
+ NestedField::required(1, "id",
Type::Primitive(PrimitiveType::Int)).into(),
+ NestedField::required(2, "name",
Type::Primitive(PrimitiveType::String)).into(),
+ ])
+ .build()
+ .unwrap();
+
+ let partition_spec =
PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
+ .add_partition_field("id", "id_partition", Transform::Identity)
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let calculator = PartitionValueCalculator::try_new(&partition_spec,
&table_schema).unwrap();
+
+ // Verify partition type
+ assert_eq!(calculator.partition_type().fields().len(), 1);
+ assert_eq!(calculator.partition_type().fields()[0].name,
"id_partition");
+
+ // Create test batch
+ let arrow_schema = Arc::new(ArrowSchema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("name", DataType::Utf8, false),
+ ]));
+
+ let batch = RecordBatch::try_new(arrow_schema, vec![
+ Arc::new(Int32Array::from(vec![10, 20, 30])),
+ Arc::new(StringArray::from(vec!["a", "b", "c"])),
+ ])
+ .unwrap();
+
+ // Calculate partition values
+ let result = calculator.calculate(&batch).unwrap();
+ let struct_array =
result.as_any().downcast_ref::<StructArray>().unwrap();
+
+ let id_partition = struct_array
+ .column_by_name("id_partition")
+ .unwrap()
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .unwrap();
+
+ assert_eq!(id_partition.value(0), 10);
+ assert_eq!(id_partition.value(1), 20);
+ assert_eq!(id_partition.value(2), 30);
+ }
+
+ #[test]
+ fn test_partition_calculator_unpartitioned_error() {
+ let table_schema = Schema::builder()
+ .with_schema_id(0)
+ .with_fields(vec![
+ NestedField::required(1, "id",
Type::Primitive(PrimitiveType::Int)).into(),
+ ])
+ .build()
+ .unwrap();
+
+ let partition_spec =
PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
+ .build()
+ .unwrap();
+
+ let result = PartitionValueCalculator::try_new(&partition_spec,
&table_schema);
+ assert!(result.is_err());
+ assert!(
+ result
+ .unwrap_err()
+ .to_string()
+ .contains("unpartitioned table")
+ );
+ }
+}
diff --git a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
index 704a4e9c1..66371fac1 100644
--- a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
+++ b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
@@ -19,137 +19,169 @@ use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StructArray};
-use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
use arrow_select::filter::filter_record_batch;
-use itertools::Itertools;
-use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
use super::arrow_struct_to_literal;
-use super::record_batch_projector::RecordBatchProjector;
-use crate::arrow::type_to_arrow_type;
-use crate::spec::{Literal, PartitionSpecRef, SchemaRef, Struct, StructType,
Type};
-use crate::transform::{BoxedTransformFunction, create_transform_function};
+use super::partition_value_calculator::PartitionValueCalculator;
+use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef,
StructType};
use crate::{Error, ErrorKind, Result};
+/// Column name for the projected partition values struct
+pub const PROJECTED_PARTITION_VALUE_COLUMN: &str = "_partition";
+
/// The splitter used to split the record batch into multiple record batches
by the partition spec.
/// 1. It will project and transform the input record batch based on the
partition spec, get the partitioned record batch.
/// 2. Split the input record batch into multiple record batches based on the
partitioned record batch.
+///
+/// # Partition Value Modes
+///
+/// The splitter supports two modes for obtaining partition values:
+/// - **Computed mode** (`calculator` is `Some`): Computes partition values
from source columns using transforms
+/// - **Pre-computed mode** (`calculator` is `None`): Expects a `_partition`
column in the input batch
// # TODO
// Remove this after partition writer supported.
#[allow(dead_code)]
pub struct RecordBatchPartitionSplitter {
schema: SchemaRef,
partition_spec: PartitionSpecRef,
- projector: RecordBatchProjector,
- transform_functions: Vec<BoxedTransformFunction>,
-
+ calculator: Option<PartitionValueCalculator>,
partition_type: StructType,
- partition_arrow_type: DataType,
}
// # TODO
// Remove this after partition writer supported.
#[allow(dead_code)]
impl RecordBatchPartitionSplitter {
+ /// Create a new RecordBatchPartitionSplitter.
+ ///
+ /// # Arguments
+ ///
+ /// * `iceberg_schema` - The Iceberg schema reference
+ /// * `partition_spec` - The partition specification reference
+ /// * `calculator` - Optional calculator for computing partition values
from source columns.
+ /// - `Some(calculator)`: Compute partition values from source columns
using transforms
+ /// - `None`: Expect a pre-computed `_partition` column in the input
batch
+ ///
+ /// # Returns
+ ///
+ /// Returns a new `RecordBatchPartitionSplitter` instance or an error if
initialization fails.
pub fn new(
- input_schema: ArrowSchemaRef,
iceberg_schema: SchemaRef,
partition_spec: PartitionSpecRef,
+ calculator: Option<PartitionValueCalculator>,
) -> Result<Self> {
- let projector = RecordBatchProjector::new(
- input_schema,
- &partition_spec
- .fields()
- .iter()
- .map(|field| field.source_id)
- .collect::<Vec<_>>(),
- // The source columns, selected by ids, must be a primitive type
and cannot be contained in a map or list, but may be nested in a struct.
- // ref: https://iceberg.apache.org/spec/#partitioning
- |field| {
- if !field.data_type().is_primitive() {
- return Ok(None);
- }
- field
- .metadata()
- .get(PARQUET_FIELD_ID_META_KEY)
- .map(|s| {
- s.parse::<i64>()
- .map_err(|e| Error::new(ErrorKind::Unexpected,
e.to_string()))
- })
- .transpose()
- },
- |_| true,
- )?;
- let transform_functions = partition_spec
- .fields()
- .iter()
- .map(|field| create_transform_function(&field.transform))
- .collect::<Result<Vec<_>>>()?;
-
let partition_type = partition_spec.partition_type(&iceberg_schema)?;
- let partition_arrow_type =
type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
Ok(Self {
schema: iceberg_schema,
partition_spec,
- projector,
- transform_functions,
+ calculator,
partition_type,
- partition_arrow_type,
})
}
- fn partition_columns_to_struct(&self, partition_columns: Vec<ArrayRef>) ->
Result<Vec<Struct>> {
- let arrow_struct_array = {
- let partition_arrow_fields = {
- let DataType::Struct(fields) = &self.partition_arrow_type else
{
- return Err(Error::new(
- ErrorKind::DataInvalid,
- "The partition arrow type is not a struct type",
- ));
- };
- fields.clone()
- };
- Arc::new(StructArray::try_new(
- partition_arrow_fields,
- partition_columns,
- None,
- )?) as ArrayRef
- };
- let struct_array = {
- let struct_array = arrow_struct_to_literal(&arrow_struct_array,
&self.partition_type)?;
+ /// Create a new RecordBatchPartitionSplitter with computed partition
values.
+ ///
+ /// This is a convenience method that creates a calculator and initializes
the splitter
+ /// to compute partition values from source columns.
+ ///
+ /// # Arguments
+ ///
+ /// * `iceberg_schema` - The Iceberg schema reference
+ /// * `partition_spec` - The partition specification reference
+ ///
+ /// # Returns
+ ///
+ /// Returns a new `RecordBatchPartitionSplitter` instance or an error if
initialization fails.
+ pub fn new_with_computed_values(
+ iceberg_schema: SchemaRef,
+ partition_spec: PartitionSpecRef,
+ ) -> Result<Self> {
+ let calculator = PartitionValueCalculator::try_new(&partition_spec,
&iceberg_schema)?;
+ Self::new(iceberg_schema, partition_spec, Some(calculator))
+ }
+
+ /// Create a new RecordBatchPartitionSplitter expecting pre-computed
partition values.
+ ///
+ /// This is a convenience method that initializes the splitter to expect a
`_partition`
+ /// column in the input batches.
+ ///
+ /// # Arguments
+ ///
+ /// * `iceberg_schema` - The Iceberg schema reference
+ /// * `partition_spec` - The partition specification reference
+ ///
+ /// # Returns
+ ///
+ /// Returns a new `RecordBatchPartitionSplitter` instance or an error if
initialization fails.
+ pub fn new_with_precomputed_values(
+ iceberg_schema: SchemaRef,
+ partition_spec: PartitionSpecRef,
+ ) -> Result<Self> {
+ Self::new(iceberg_schema, partition_spec, None)
+ }
+
+ /// Split the record batch into multiple record batches based on the
partition spec.
+ pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(PartitionKey,
RecordBatch)>> {
+ let partition_structs = if let Some(calculator) = &self.calculator {
+ // Compute partition values from source columns using calculator
+ let partition_array = calculator.calculate(batch)?;
+ let struct_array = arrow_struct_to_literal(&partition_array,
&self.partition_type)?;
+
struct_array
.into_iter()
.map(|s| {
- if let Some(s) = s {
- if let Literal::Struct(s) = s {
- Ok(s)
- } else {
- Err(Error::new(
- ErrorKind::DataInvalid,
- "The struct is not a struct literal",
- ))
- }
+ if let Some(Literal::Struct(s)) = s {
+ Ok(s)
} else {
- Err(Error::new(ErrorKind::DataInvalid, "The struct is
null"))
+ Err(Error::new(
+ ErrorKind::DataInvalid,
+ "Partition value is not a struct literal or is
null",
+ ))
}
})
.collect::<Result<Vec<_>>>()?
- };
+ } else {
+ // Extract partition values from pre-computed partition column
+ let partition_column = batch
+ .column_by_name(PROJECTED_PARTITION_VALUE_COLUMN)
+ .ok_or_else(|| {
+ Error::new(
+ ErrorKind::DataInvalid,
+ format!(
+ "Partition column '{}' not found in batch",
+ PROJECTED_PARTITION_VALUE_COLUMN
+ ),
+ )
+ })?;
- Ok(struct_array)
- }
+ let partition_struct_array = partition_column
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .ok_or_else(|| {
+ Error::new(
+ ErrorKind::DataInvalid,
+ "Partition column is not a StructArray",
+ )
+ })?;
- /// Split the record batch into multiple record batches based on the
partition spec.
- pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(Struct,
RecordBatch)>> {
- let source_columns = self.projector.project_column(batch.columns())?;
- let partition_columns = source_columns
- .into_iter()
- .zip_eq(self.transform_functions.iter())
- .map(|(source_column, transform_function)|
transform_function.transform(source_column))
- .collect::<Result<Vec<_>>>()?;
+ let arrow_struct_array = Arc::new(partition_struct_array.clone())
as ArrayRef;
+ let struct_array = arrow_struct_to_literal(&arrow_struct_array,
&self.partition_type)?;
- let partition_structs =
self.partition_columns_to_struct(partition_columns)?;
+ struct_array
+ .into_iter()
+ .map(|s| {
+ if let Some(Literal::Struct(s)) = s {
+ Ok(s)
+ } else {
+ Err(Error::new(
+ ErrorKind::DataInvalid,
+ "Partition value is not a struct literal or is
null",
+ ))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
// Group the batch by row value.
let mut group_ids = HashMap::new();
@@ -172,8 +204,15 @@ impl RecordBatchPartitionSplitter {
filter.into()
};
+ // Create PartitionKey from the partition struct
+ let partition_key = PartitionKey::new(
+ self.partition_spec.as_ref().clone(),
+ self.schema.clone(),
+ row,
+ );
+
// filter the RecordBatch
- partition_batches.push((row, filter_record_batch(batch,
&filter_array)?));
+ partition_batches.push((partition_key, filter_record_batch(batch,
&filter_array)?));
}
Ok(partition_batches)
@@ -185,11 +224,13 @@ mod tests {
use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch, StringArray};
+ use arrow_schema::DataType;
+ use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
use super::*;
use crate::arrow::schema_to_arrow_schema;
use crate::spec::{
- NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Transform,
+ NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Struct,
Transform, Type,
UnboundPartitionField,
};
@@ -227,14 +268,14 @@ mod tests {
.build()
.unwrap(),
);
- let input_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
let partition_splitter =
- RecordBatchPartitionSplitter::new(input_schema.clone(),
schema.clone(), partition_spec)
+
RecordBatchPartitionSplitter::new_with_computed_values(schema.clone(),
partition_spec)
.expect("Failed to create splitter");
+ let arrow_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f",
"g"]);
- let batch = RecordBatch::try_new(input_schema.clone(), vec![
+ let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
Arc::new(id_array),
Arc::new(data_array),
])
@@ -243,8 +284,8 @@ mod tests {
let mut partitioned_batches = partition_splitter
.split(&batch)
.expect("Failed to split RecordBatch");
- partitioned_batches.sort_by_key(|(row, _)| {
- if let PrimitiveLiteral::Int(i) = row.fields()[0]
+ partitioned_batches.sort_by_key(|(partition_key, _)| {
+ if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
.as_ref()
.unwrap()
.as_primitive_literal()
@@ -260,7 +301,7 @@ mod tests {
// check the first partition
let expected_id_array = Int32Array::from(vec![1, 1, 1]);
let expected_data_array = StringArray::from(vec!["a", "c", "g"]);
- let expected_batch = RecordBatch::try_new(input_schema.clone(),
vec![
+ let expected_batch = RecordBatch::try_new(arrow_schema.clone(),
vec![
Arc::new(expected_id_array),
Arc::new(expected_data_array),
])
@@ -271,7 +312,7 @@ mod tests {
// check the second partition
let expected_id_array = Int32Array::from(vec![2, 2]);
let expected_data_array = StringArray::from(vec!["b", "e"]);
- let expected_batch = RecordBatch::try_new(input_schema.clone(),
vec![
+ let expected_batch = RecordBatch::try_new(arrow_schema.clone(),
vec![
Arc::new(expected_id_array),
Arc::new(expected_data_array),
])
@@ -282,7 +323,7 @@ mod tests {
// check the third partition
let expected_id_array = Int32Array::from(vec![3, 3]);
let expected_data_array = StringArray::from(vec!["d", "f"]);
- let expected_batch = RecordBatch::try_new(input_schema.clone(),
vec![
+ let expected_batch = RecordBatch::try_new(arrow_schema.clone(),
vec![
Arc::new(expected_id_array),
Arc::new(expected_data_array),
])
@@ -292,7 +333,7 @@ mod tests {
let partition_values = partitioned_batches
.iter()
- .map(|(row, _)| row.clone())
+ .map(|(partition_key, _)| partition_key.data().clone())
.collect::<Vec<_>>();
// check partition value is struct(1), struct(2), struct(3)
assert_eq!(partition_values, vec![
@@ -301,4 +342,144 @@ mod tests {
Struct::from_iter(vec![Some(Literal::int(3))]),
]);
}
+
+ #[test]
+ fn test_record_batch_partition_split_with_partition_column() {
+ use arrow_array::StructArray;
+ use arrow_schema::{Field, Schema as ArrowSchema};
+
+ let schema = Arc::new(
+ Schema::builder()
+ .with_fields(vec![
+ NestedField::required(
+ 1,
+ "id",
+ Type::Primitive(crate::spec::PrimitiveType::Int),
+ )
+ .into(),
+ NestedField::required(
+ 2,
+ "name",
+ Type::Primitive(crate::spec::PrimitiveType::String),
+ )
+ .into(),
+ ])
+ .build()
+ .unwrap(),
+ );
+ let partition_spec = Arc::new(
+ PartitionSpecBuilder::new(schema.clone())
+ .with_spec_id(1)
+ .add_unbound_field(UnboundPartitionField {
+ source_id: 1,
+ field_id: None,
+ name: "id_bucket".to_string(),
+ transform: Transform::Identity,
+ })
+ .unwrap()
+ .build()
+ .unwrap(),
+ );
+
+ // Create input schema with _partition column
+ // Note: partition field IDs start from 1000 by default
+ let partition_field = Field::new("id_bucket", DataType::Int32,
false).with_metadata(
+ HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(),
"1000".to_string())]),
+ );
+ let partition_struct_field = Field::new(
+ PROJECTED_PARTITION_VALUE_COLUMN,
+ DataType::Struct(vec![partition_field.clone()].into()),
+ false,
+ );
+
+ let input_schema = Arc::new(ArrowSchema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("name", DataType::Utf8, false),
+ partition_struct_field,
+ ]));
+
+ // Create splitter expecting pre-computed partition column
+ let partition_splitter =
RecordBatchPartitionSplitter::new_with_precomputed_values(
+ schema.clone(),
+ partition_spec,
+ )
+ .expect("Failed to create splitter");
+
+ // Create test data with pre-computed partition column
+ let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
+ let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f",
"g"]);
+
+ // Create partition column (same values as id for Identity transform)
+ let partition_values = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
+ let partition_struct = StructArray::from(vec![(
+ Arc::new(partition_field),
+ Arc::new(partition_values) as ArrayRef,
+ )]);
+
+ let batch = RecordBatch::try_new(input_schema.clone(), vec![
+ Arc::new(id_array),
+ Arc::new(data_array),
+ Arc::new(partition_struct),
+ ])
+ .expect("Failed to create RecordBatch");
+
+ // Split using the pre-computed partition column
+ let mut partitioned_batches = partition_splitter
+ .split(&batch)
+ .expect("Failed to split RecordBatch");
+
+ partitioned_batches.sort_by_key(|(partition_key, _)| {
+ if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
+ .as_ref()
+ .unwrap()
+ .as_primitive_literal()
+ .unwrap()
+ {
+ i
+ } else {
+ panic!("The partition value is not a int");
+ }
+ });
+
+ assert_eq!(partitioned_batches.len(), 3);
+
+ // Helper to extract id and name values from a batch
+ let extract_values = |batch: &RecordBatch| -> (Vec<i32>, Vec<String>) {
+ let id_col = batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .unwrap();
+ let name_col = batch
+ .column(1)
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .unwrap();
+ (
+ id_col.values().to_vec(),
+ name_col.iter().map(|s| s.unwrap().to_string()).collect(),
+ )
+ };
+
+ // Verify partition 1: id=1, names=["a", "c", "g"]
+ let (key, batch) = &partitioned_batches[0];
+ assert_eq!(key.data(),
&Struct::from_iter(vec![Some(Literal::int(1))]));
+ let (ids, names) = extract_values(batch);
+ assert_eq!(ids, vec![1, 1, 1]);
+ assert_eq!(names, vec!["a", "c", "g"]);
+
+ // Verify partition 2: id=2, names=["b", "e"]
+ let (key, batch) = &partitioned_batches[1];
+ assert_eq!(key.data(),
&Struct::from_iter(vec![Some(Literal::int(2))]));
+ let (ids, names) = extract_values(batch);
+ assert_eq!(ids, vec![2, 2]);
+ assert_eq!(names, vec!["b", "e"]);
+
+ // Verify partition 3: id=3, names=["d", "f"]
+ let (key, batch) = &partitioned_batches[2];
+ assert_eq!(key.data(),
&Struct::from_iter(vec![Some(Literal::int(3))]));
+ let (ids, names) = extract_values(batch);
+ assert_eq!(ids, vec![3, 3]);
+ assert_eq!(names, vec!["d", "f"]);
+ }
}
diff --git a/crates/integrations/datafusion/src/physical_plan/project.rs
b/crates/integrations/datafusion/src/physical_plan/project.rs
index 4bfe8192b..17492176a 100644
--- a/crates/integrations/datafusion/src/physical_plan/project.rs
+++ b/crates/integrations/datafusion/src/physical_plan/project.rs
@@ -19,24 +19,19 @@
use std::sync::Arc;
-use datafusion::arrow::array::{ArrayRef, RecordBatch, StructArray};
+use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
use datafusion::common::Result as DFResult;
-use datafusion::error::DataFusionError;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
-use iceberg::arrow::record_batch_projector::RecordBatchProjector;
-use iceberg::spec::{PartitionSpec, Schema};
+use iceberg::arrow::{PROJECTED_PARTITION_VALUE_COLUMN,
PartitionValueCalculator};
+use iceberg::spec::PartitionSpec;
use iceberg::table::Table;
-use iceberg::transform::BoxedTransformFunction;
use crate::to_datafusion_error;
-/// Column name for the combined partition values struct
-const PARTITION_VALUES_COLUMN: &str = "_partition";
-
/// Extends an ExecutionPlan with partition value calculations for Iceberg
tables.
///
/// This function takes an input ExecutionPlan and extends it with an
additional column
@@ -65,12 +60,9 @@ pub fn project_with_partition(
let input_schema = input.schema();
// TODO: Validate that input_schema matches the Iceberg table schema.
// See: https://github.com/apache/iceberg-rust/issues/1752
- let partition_type = build_partition_type(partition_spec,
table_schema.as_ref())?;
- let calculator = PartitionValueCalculator::new(
- partition_spec.as_ref().clone(),
- table_schema.as_ref().clone(),
- partition_type,
- )?;
+ let calculator =
+ PartitionValueCalculator::try_new(partition_spec.as_ref(),
table_schema.as_ref())
+ .map_err(to_datafusion_error)?;
let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
Vec::with_capacity(input_schema.fields().len() + 1);
@@ -80,8 +72,8 @@ pub fn project_with_partition(
projection_exprs.push((column_expr, field.name().clone()));
}
- let partition_expr = Arc::new(PartitionExpr::new(calculator));
- projection_exprs.push((partition_expr,
PARTITION_VALUES_COLUMN.to_string()));
+ let partition_expr = Arc::new(PartitionExpr::new(calculator,
partition_spec.clone()));
+ projection_exprs.push((partition_expr,
PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
let projection = ProjectionExec::try_new(projection_exprs, input)?;
Ok(Arc::new(projection))
@@ -91,21 +83,24 @@ pub fn project_with_partition(
#[derive(Debug, Clone)]
struct PartitionExpr {
calculator: Arc<PartitionValueCalculator>,
+ partition_spec: Arc<PartitionSpec>,
}
impl PartitionExpr {
- fn new(calculator: PartitionValueCalculator) -> Self {
+ fn new(calculator: PartitionValueCalculator, partition_spec:
Arc<PartitionSpec>) -> Self {
Self {
calculator: Arc::new(calculator),
+ partition_spec,
}
}
}
// Manual PartialEq/Eq implementations for pointer-based equality
-// (two PartitionExpr are equal if they share the same calculator instance)
+// (two PartitionExpr are equal if they share the same calculator and
partition_spec instances)
impl PartialEq for PartitionExpr {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.calculator, &other.calculator)
+ && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
}
}
@@ -117,7 +112,7 @@ impl PhysicalExpr for PartitionExpr {
}
fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
- Ok(self.calculator.partition_type.clone())
+ Ok(self.calculator.partition_arrow_type().clone())
}
fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
@@ -125,7 +120,10 @@ impl PhysicalExpr for PartitionExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
- let array = self.calculator.calculate(batch)?;
+ let array = self
+ .calculator
+ .calculate(batch)
+ .map_err(to_datafusion_error)?;
Ok(ColumnarValue::Array(array))
}
@@ -142,7 +140,6 @@ impl PhysicalExpr for PartitionExpr {
fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let field_names: Vec<String> = self
- .calculator
.partition_spec
.fields()
.iter()
@@ -155,7 +152,6 @@ impl PhysicalExpr for PartitionExpr {
impl std::fmt::Display for PartitionExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let field_names: Vec<&str> = self
- .calculator
.partition_spec
.fields()
.iter()
@@ -167,110 +163,18 @@ impl std::fmt::Display for PartitionExpr {
impl std::hash::Hash for PartitionExpr {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- // Two PartitionExpr are equal if they share the same calculator Arc
+ // Two PartitionExpr are equal if they share the same calculator and partition_spec Arcs
Arc::as_ptr(&self.calculator).hash(state);
+ Arc::as_ptr(&self.partition_spec).hash(state);
}
}
-/// Calculator for partition values in Iceberg tables
-#[derive(Debug)]
-struct PartitionValueCalculator {
- partition_spec: PartitionSpec,
- partition_type: DataType,
- projector: RecordBatchProjector,
- transform_functions: Vec<BoxedTransformFunction>,
-}
-
-impl PartitionValueCalculator {
- fn new(
- partition_spec: PartitionSpec,
- table_schema: Schema,
- partition_type: DataType,
- ) -> DFResult<Self> {
- if partition_spec.is_unpartitioned() {
- return Err(DataFusionError::Internal(
- "Cannot create partition calculator for unpartitioned
table".to_string(),
- ));
- }
-
- let transform_functions: Result<Vec<BoxedTransformFunction>, _> =
partition_spec
- .fields()
- .iter()
- .map(|pf|
iceberg::transform::create_transform_function(&pf.transform))
- .collect();
-
- let transform_functions =
transform_functions.map_err(to_datafusion_error)?;
-
- let source_field_ids: Vec<i32> = partition_spec
- .fields()
- .iter()
- .map(|pf| pf.source_id)
- .collect();
-
- let projector = RecordBatchProjector::from_iceberg_schema(
- Arc::new(table_schema.clone()),
- &source_field_ids,
- )
- .map_err(to_datafusion_error)?;
-
- Ok(Self {
- partition_spec,
- partition_type,
- projector,
- transform_functions,
- })
- }
-
- fn calculate(&self, batch: &RecordBatch) -> DFResult<ArrayRef> {
- let source_columns = self
- .projector
- .project_column(batch.columns())
- .map_err(to_datafusion_error)?;
-
- let expected_struct_fields = match &self.partition_type {
- DataType::Struct(fields) => fields.clone(),
- _ => {
- return Err(DataFusionError::Internal(
- "Expected partition type must be a struct".to_string(),
- ));
- }
- };
-
- let mut partition_values =
Vec::with_capacity(self.partition_spec.fields().len());
-
- for (source_column, transform_fn) in
source_columns.iter().zip(&self.transform_functions) {
- let partition_value = transform_fn
- .transform(source_column.clone())
- .map_err(to_datafusion_error)?;
-
- partition_values.push(partition_value);
- }
-
- let struct_array = StructArray::try_new(expected_struct_fields,
partition_values, None)
- .map_err(|e| DataFusionError::ArrowError(e, None))?;
-
- Ok(Arc::new(struct_array))
- }
-}
-
-fn build_partition_type(
- partition_spec: &PartitionSpec,
- table_schema: &Schema,
-) -> DFResult<DataType> {
- let partition_struct_type = partition_spec
- .partition_type(table_schema)
- .map_err(to_datafusion_error)?;
-
-
iceberg::arrow::type_to_arrow_type(&iceberg::spec::Type::Struct(partition_struct_type))
- .map_err(to_datafusion_error)
-}
-
#[cfg(test)]
mod tests {
- use datafusion::arrow::array::Int32Array;
+ use datafusion::arrow::array::{ArrayRef, Int32Array, StructArray};
use datafusion::arrow::datatypes::{Field, Fields};
use datafusion::physical_plan::empty::EmptyExec;
- use iceberg::spec::{NestedField, PrimitiveType, StructType, Transform,
Type};
+ use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType,
Transform, Type};
use super::*;
@@ -291,20 +195,11 @@ mod tests {
.build()
.unwrap();
- let _arrow_schema = Arc::new(ArrowSchema::new(vec![
- Field::new("id", DataType::Int32, false),
- Field::new("name", DataType::Utf8, false),
- ]));
-
- let partition_type = build_partition_type(&partition_spec,
&table_schema).unwrap();
- let calculator = PartitionValueCalculator::new(
- partition_spec.clone(),
- table_schema,
- partition_type.clone(),
- )
- .unwrap();
+ let calculator = PartitionValueCalculator::try_new(&partition_spec,
&table_schema).unwrap();
- assert_eq!(calculator.partition_type, partition_type);
+ // Verify partition type
+ assert_eq!(calculator.partition_type().fields().len(), 1);
+ assert_eq!(calculator.partition_type().fields()[0].name,
"id_partition");
}
#[test]
@@ -318,11 +213,13 @@ mod tests {
.build()
.unwrap();
- let partition_spec =
iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
- .add_partition_field("id", "id_partition", Transform::Identity)
- .unwrap()
- .build()
- .unwrap();
+ let partition_spec = Arc::new(
+
iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
+ .add_partition_field("id", "id_partition", Transform::Identity)
+ .unwrap()
+ .build()
+ .unwrap(),
+ );
let arrow_schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Int32, false),
@@ -331,9 +228,7 @@ mod tests {
let input = Arc::new(EmptyExec::new(arrow_schema.clone()));
- let partition_type = build_partition_type(&partition_spec,
&table_schema).unwrap();
- let calculator =
- PartitionValueCalculator::new(partition_spec, table_schema,
partition_type).unwrap();
+ let calculator = PartitionValueCalculator::try_new(&partition_spec,
&table_schema).unwrap();
let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
Vec::with_capacity(arrow_schema.fields().len() + 1);
@@ -342,8 +237,8 @@ mod tests {
projection_exprs.push((column_expr, field.name().clone()));
}
- let partition_expr = Arc::new(PartitionExpr::new(calculator));
- projection_exprs.push((partition_expr,
PARTITION_VALUES_COLUMN.to_string()));
+ let partition_expr = Arc::new(PartitionExpr::new(calculator,
partition_spec));
+ projection_exprs.push((partition_expr,
PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
let projection = ProjectionExec::try_new(projection_exprs,
input).unwrap();
let result = Arc::new(projection);
@@ -384,11 +279,10 @@ mod tests {
])
.unwrap();
- let partition_type = build_partition_type(&partition_spec,
&table_schema).unwrap();
- let calculator =
- PartitionValueCalculator::new(partition_spec, table_schema,
partition_type.clone())
- .unwrap();
- let expr = PartitionExpr::new(calculator);
+ let partition_spec = Arc::new(partition_spec);
+ let calculator = PartitionValueCalculator::try_new(&partition_spec,
&table_schema).unwrap();
+ let partition_type = calculator.partition_arrow_type().clone();
+ let expr = PartitionExpr::new(calculator, partition_spec);
assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
assert!(!expr.nullable(&arrow_schema).unwrap());
@@ -469,9 +363,7 @@ mod tests {
])
.unwrap();
- let partition_type = build_partition_type(&partition_spec,
&table_schema).unwrap();
- let calculator =
- PartitionValueCalculator::new(partition_spec, table_schema,
partition_type).unwrap();
+ let calculator = PartitionValueCalculator::try_new(&partition_spec,
&table_schema).unwrap();
let array = calculator.calculate(&batch).unwrap();
let struct_array =
array.as_any().downcast_ref::<StructArray>().unwrap();