This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e9431fc694 Optimize merging of partial case expression results (#18152)
e9431fc694 is described below
commit e9431fc694e1fc905f86ea00c689ccecf6065e9d
Author: Pepijn Van Eeckhoudt <[email protected]>
AuthorDate: Tue Oct 28 16:20:22 2025 +0100
Optimize merging of partial case expression results (#18152)
## Which issue does this PR close?
- Improvement in the context of
https://github.com/apache/datafusion/issues/18075
- Continues the work started in #17898
## Rationale for this change
Case evaluation currently uses `PhysicalExpr::evaluate_selection` for
each branch of the case expression. This implementation works, but
because `evaluate_selection` is not specific to the `CASE` logic, we're
missing some optimisation opportunities. The main consequence is that
too much work is spent filtering record batches and scattering
results. This PR introduces specialised filtering logic and result
interleaving for `CASE`.
A more detailed description and diagrams are available at
https://github.com/apache/datafusion/issues/18075#issuecomment-3422326710
## What changes are included in this PR?
Rewrite the `case_when_no_expr` and `case_when_with_expr` evaluation
loops to avoid as much unnecessary work as possible. In particular, the
rows that still need to be evaluated are retained across loop iterations.
This allows the record batch that needs to be filtered to shrink as the
loop progresses, which reduces the number of rows that need to be
refiltered. If a `WHEN` predicate does not match any rows at all,
filtering is avoided entirely; if it matches all remaining rows, no
filtering is needed either and evaluation stops early.
The final result is also no longer merged on every loop iteration.
Instead, an index vector is built up and used to compose the final
result once, using a custom 'multi-zip'/'interleave'-like operation.
## Are these changes tested?
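Covered by existing unit tests and SLTs. New SLT cases in `case.slt`
check that a `WHEN` branch is never evaluated for rows already matched
by an earlier branch (e.g. `WHEN 1 / 0`, or `WHEN 1 / a = 1` when `a = 0`).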
## Are there any user-facing changes?
No
---
datafusion/physical-expr/src/expressions/case.rs | 732 +++++++++++++++++++----
datafusion/sqllogictest/test_files/case.slt | 22 +
2 files changed, 636 insertions(+), 118 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/case.rs
b/datafusion/physical-expr/src/expressions/case.rs
index 2db599047b..0b4c3af1d9 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -15,25 +15,28 @@
// specific language governing permissions and limitations
// under the License.
+use super::{Column, Literal};
+use crate::expressions::case::ResultState::{Complete, Empty, Partial};
use crate::expressions::try_cast;
use crate::PhysicalExpr;
-use std::borrow::Cow;
-use std::hash::Hash;
-use std::{any::Any, sync::Arc};
-
use arrow::array::*;
use arrow::compute::kernels::zip::zip;
-use arrow::compute::{and, and_not, is_null, not, nullif, or,
prep_null_mask_filter};
-use arrow::datatypes::{DataType, Schema};
+use arrow::compute::{
+ is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder,
FilterPredicate,
+};
+use arrow::datatypes::{DataType, Schema, UInt32Type};
+use arrow::error::ArrowError;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::ColumnarValue;
-
-use super::{Column, Literal};
use datafusion_physical_expr_common::datum::compare_with_eq;
use itertools::Itertools;
+use std::borrow::Cow;
+use std::fmt::{Debug, Formatter};
+use std::hash::Hash;
+use std::{any::Any, sync::Arc};
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -98,7 +101,7 @@ pub struct CaseExpr {
}
impl std::fmt::Display for CaseExpr {
- fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "CASE ")?;
if let Some(e) = &self.expr {
write!(f, "{e} ")?;
@@ -122,6 +125,419 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>)
-> bool {
expr.as_any().is::<Column>()
}
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+ let mut filter_builder = FilterBuilder::new(predicate);
+ // Always optimize the filter since we use them multiple times.
+ filter_builder = filter_builder.optimize();
+ filter_builder.build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+ record_batch: &RecordBatch,
+ filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+ let filtered_columns = record_batch
+ .columns()
+ .iter()
+ .map(|a| filter_array(a, filter))
+ .collect::<std::result::Result<Vec<_>, _>>()?;
+ // SAFETY: since we start from a valid RecordBatch, there's no need to
revalidate the schema
+ // since the set of columns has not changed.
+ // The input column arrays all had the same length (since they're coming
from a valid RecordBatch)
+ // and the filtering them with the same filter will produces a new set of
arrays with identical
+ // lengths.
+ unsafe {
+ Ok(RecordBatch::new_unchecked(
+ record_batch.schema(),
+ filtered_columns,
+ filter.count(),
+ ))
+ }
+}
+
+// This function exists purely to be able to use the same call style
+// for `filter_record_batch` and `filter_array` at the point of use.
+// When https://github.com/apache/arrow-rs/pull/8693 is available, replace
+// both with method calls on `FilterPredicate`.
+#[inline(always)]
+fn filter_array(
+ array: &dyn Array,
+ filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+ filter.filter(array)
+}
+
+/// Merges elements by index from a list of [`ArrayData`], creating a new
[`ColumnarValue`] from
+/// those values.
+///
+/// Each element in `indices` is the index of an array in `values`. The
`indices` array is processed
+/// sequentially. The first occurrence of index value `n` will be mapped to
the first
+/// value of the array at index `n`. The second occurrence to the second
value, and so on.
+/// An index value where `PartialResultIndex::is_none` is `true` is used to
indicate null values.
+///
+/// # Implementation notes
+///
+/// This algorithm is similar in nature to both `zip` and `interleave`, but
there are some important
+/// differences.
+///
+/// In contrast to `zip`, this function supports multiple input arrays.
Instead of a boolean
+/// selection vector, an index array is to take values from the input arrays,
and a special marker
+/// value is used to indicate null values.
+///
+/// In contrast to `interleave`, this function does not use pairs of indices.
The values in
+/// `indices` serve the same purpose as the first value in the pairs passed to
`interleave`.
+/// The index in the array is implicit and is derived from the number of times
a particular array
+/// index occurs.
+/// The more constrained indexing mechanism used by this algorithm makes it
easier to copy values
+/// in contiguous slices. In the example below, the two subsequent elements
from array `2` can be
+/// copied in a single operation from the source array instead of copying them
one by one.
+/// Long spans of null values are also especially cheap because they do not
need to be represented
+/// in an input array.
+///
+/// # Safety
+///
+/// This function does not check that the number of occurrences of any
particular array index matches
+/// the length of the corresponding input array. If an array contains more
values than required, the
+/// spurious values will be ignored. If an array contains fewer values than
necessary, this function
+/// will panic.
+///
+/// # Example
+///
+/// ```text
+/// ┌───────────┐ ┌─────────┐ ┌─────────┐
+/// │┌─────────┐│ │ None │ │ NULL │
+/// ││ A ││ ├─────────┤ ├─────────┤
+/// │└─────────┘│ │ 1 │ │ B │
+/// │┌─────────┐│ ├─────────┤ ├─────────┤
+/// ││ B ││ │ 0 │ merge(values, indices) │ A │
+/// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤
+/// │┌─────────┐│ │ None │ │ NULL │
+/// ││ C ││ ├─────────┤ ├─────────┤
+/// │├─────────┤│ │ 2 │ │ C │
+/// ││ D ││ ├─────────┤ ├─────────┤
+/// │└─────────┘│ │ 2 │ │ D │
+/// └───────────┘ └─────────┘ └─────────┘
+/// values indices result
+///
+/// ```
+fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) ->
Result<ArrayRef> {
+ #[cfg(debug_assertions)]
+ for ix in indices {
+ if let Some(index) = ix.index() {
+ assert!(
+ index < values.len(),
+ "Index out of bounds: {} >= {}",
+ index,
+ values.len()
+ );
+ }
+ }
+
+ let data_refs = values.iter().collect();
+ let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
+
+ // This loop extends the mutable array by taking slices from the partial
results.
+ //
+ // take_offsets keeps track of how many values have been taken from each
array.
+ let mut take_offsets = vec![0; values.len() + 1];
+ let mut start_row_ix = 0;
+ loop {
+ let array_ix = indices[start_row_ix];
+
+ // Determine the length of the slice to take.
+ let mut end_row_ix = start_row_ix + 1;
+ while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
+ end_row_ix += 1;
+ }
+ let slice_length = end_row_ix - start_row_ix;
+
+ // Extend mutable with either nulls or with values from the array.
+ match array_ix.index() {
+ None => mutable.extend_nulls(slice_length),
+ Some(index) => {
+ let start_offset = take_offsets[index];
+ let end_offset = start_offset + slice_length;
+ mutable.extend(index, start_offset, end_offset);
+ take_offsets[index] = end_offset;
+ }
+ }
+
+ if end_row_ix == indices.len() {
+ break;
+ } else {
+ // Set the start_row_ix for the next slice.
+ start_row_ix = end_row_ix;
+ }
+ }
+
+ Ok(make_array(mutable.freeze()))
+}
+
+/// An index into the partial results array that's more compact than `usize`.
+///
+/// `u32::MAX` is reserved as a special 'none' value. This is used instead of
+/// `Option` to keep the array of indices as compact as possible.
+#[derive(Copy, Clone, PartialEq, Eq)]
+struct PartialResultIndex {
+ index: u32,
+}
+
+const NONE_VALUE: u32 = u32::MAX;
+
+impl PartialResultIndex {
+ /// Returns the 'none' placeholder value.
+ fn none() -> Self {
+ Self { index: NONE_VALUE }
+ }
+
+ fn zero() -> Self {
+ Self { index: 0 }
+ }
+
+ /// Creates a new partial result index.
+ ///
+ /// If the provided value is greater than or equal to `u32::MAX`
+ /// an error will be returned.
+ fn try_new(index: usize) -> Result<Self> {
+ let Ok(index) = u32::try_from(index) else {
+ return internal_err!("Partial result index exceeds limit");
+ };
+
+ if index == NONE_VALUE {
+ return internal_err!("Partial result index exceeds limit");
+ }
+
+ Ok(Self { index })
+ }
+
+ /// Determines if this index is the 'none' placeholder value or not.
+ fn is_none(&self) -> bool {
+ self.index == NONE_VALUE
+ }
+
+ /// Returns `Some(index)` if this value is not the 'none' placeholder,
`None` otherwise.
+ fn index(&self) -> Option<usize> {
+ if self.is_none() {
+ None
+ } else {
+ Some(self.index as usize)
+ }
+ }
+}
+
+impl Debug for PartialResultIndex {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ if self.is_none() {
+ write!(f, "null")
+ } else {
+ write!(f, "{}", self.index)
+ }
+ }
+}
+
+enum ResultState {
+ /// The final result is an array containing only null values.
+ Empty,
+ /// The final result needs to be computed by merging the data in `arrays`.
+ Partial {
+ // A `Vec` of partial results that should be merged.
+ // `partial_result_indices` contains indexes into this vec.
+ arrays: Vec<ArrayData>,
+ // Indicates per result row from which array in `partial_results` a
value should be taken.
+ indices: Vec<PartialResultIndex>,
+ },
+ /// A single branch matched all input rows. When creating the final
result, no further merging
+ /// of partial results is necessary.
+ Complete(ColumnarValue),
+}
+
+/// A builder for constructing result arrays for CASE expressions.
+///
+/// Rather than building a monolithic array containing all results, it
maintains a set of
+/// partial result arrays and a mapping that indicates for each row which
partial array
+/// contains the result value for that row.
+///
+/// On finish(), the builder will merge all partial results into a single
array if necessary.
+/// If all rows evaluated to the same array, that array can be returned
directly without
+/// any merging overhead.
+struct ResultBuilder {
+ data_type: DataType,
+ /// The number of rows in the final result.
+ row_count: usize,
+ state: ResultState,
+}
+
+impl ResultBuilder {
+ /// Creates a new ResultBuilder that will produce arrays of the given data
type.
+ ///
+ /// The `row_count` parameter indicates the number of rows in the final
result.
+ fn new(data_type: &DataType, row_count: usize) -> Self {
+ Self {
+ data_type: data_type.clone(),
+ row_count,
+ state: Empty,
+ }
+ }
+
+ /// Adds a result for one branch of the case expression.
+ ///
+ /// `row_indices` should be a [UInt32Array] containing [RecordBatch]
relative row indices
+ /// for which `value` contains result values.
+ ///
+ /// If `value` is a scalar, the scalar value will be used as the value for
each row in `row_indices`.
+ ///
+ /// If `value` is an array, the values from the array and the indices from
`row_indices` will be
+ /// processed pairwise. The lengths of `value` and `row_indices` must
match.
+ ///
+ /// The diagram below shows a situation where a when expression matched
rows 1 and 4 of the
+ /// record batch. The then expression produced the value array `[A, D]`.
+ /// After adding this result, the result array will have been added to
`partial arrays` and
+ /// `partial indices` will have been updated at indexes `1` and `4`.
+ ///
+ /// ```text
+ /// ┌─────────┐ ┌─────────┐┌───────────┐
┌─────────┐┌───────────┐
+ /// │ C │ │ 0: None ││┌ 0 ──────┐│ │
0: None ││┌ 0 ──────┐│
+ /// ├─────────┤ ├─────────┤││ A ││
├─────────┤││ A ││
+ /// │ D │ │ 1: None ││└─────────┘│ │
1: 2 ││└─────────┘│
+ /// └─────────┘ ├─────────┤│┌ 1 ──────┐│ add_branch_result(
├─────────┤│┌ 1 ──────┐│
+ /// matching │ 2: 0 │││ B ││ row indices, │
2: 0 │││ B ││
+ /// 'then' values ├─────────┤│└─────────┘│ value
├─────────┤│└─────────┘│
+ /// │ 3: None ││ │ ) │
3: None ││┌ 2 ──────┐│
+ /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶
├─────────┤││ C ││
+ /// │ 1 │ │ 4: None ││ │ │
4: 2 ││├─────────┤│
+ /// ├─────────┤ ├─────────┤│ │
├─────────┤││ D ││
+ /// │ 4 │ │ 5: 1 ││ │ │
5: 1 ││└─────────┘│
+ /// └─────────┘ └─────────┘└───────────┘
└─────────┘└───────────┘
+ /// row indices partial partial
partial partial
+ /// indices arrays
indices arrays
+ /// ```
+ fn add_branch_result(
+ &mut self,
+ row_indices: &ArrayRef,
+ value: ColumnarValue,
+ ) -> Result<()> {
+ match value {
+ ColumnarValue::Array(a) => {
+ if a.len() != row_indices.len() {
+ internal_err!("Array length must match row indices length")
+ } else if row_indices.len() == self.row_count {
+ self.set_complete_result(ColumnarValue::Array(a))
+ } else {
+ self.add_partial_result(row_indices, a.to_data())
+ }
+ }
+ ColumnarValue::Scalar(s) => {
+ if row_indices.len() == self.row_count {
+ self.set_complete_result(ColumnarValue::Scalar(s))
+ } else {
+ self.add_partial_result(
+ row_indices,
+ s.to_array_of_size(row_indices.len())?.to_data(),
+ )
+ }
+ }
+ }
+ }
+
+ /// Adds a partial result array.
+ ///
+ /// This method adds the given array data as a partial result and updates
the index mapping
+ /// to indicate that the specified rows should take their values from this
array.
+ /// The partial results will be merged into a single array when finish()
is called.
+ fn add_partial_result(
+ &mut self,
+ row_indices: &ArrayRef,
+ row_values: ArrayData,
+ ) -> Result<()> {
+ if row_indices.null_count() != 0 {
+ return internal_err!("Row indices must not contain nulls");
+ }
+
+ match &mut self.state {
+ Empty => {
+ let array_index = PartialResultIndex::zero();
+ let mut indices = vec![PartialResultIndex::none();
self.row_count];
+ for row_ix in
row_indices.as_primitive::<UInt32Type>().values().iter() {
+ indices[*row_ix as usize] = array_index;
+ }
+
+ self.state = Partial {
+ arrays: vec![row_values],
+ indices,
+ };
+
+ Ok(())
+ }
+ Partial { arrays, indices } => {
+ let array_index = PartialResultIndex::try_new(arrays.len())?;
+
+ arrays.push(row_values);
+
+ for row_ix in
row_indices.as_primitive::<UInt32Type>().values().iter() {
+ // This is check is only active for debug config because
the callers of this method,
+ // `case_when_with_expr` and `case_when_no_expr`, already
ensure that
+ // they only calculate a value for each row at most once.
+ #[cfg(debug_assertions)]
+ if !indices[*row_ix as usize].is_none() {
+ return internal_err!("Duplicate value for row {}",
*row_ix);
+ }
+
+ indices[*row_ix as usize] = array_index;
+ }
+ Ok(())
+ }
+ Complete(_) => internal_err!(
+ "Cannot add a partial result when complete result is already
set"
+ ),
+ }
+ }
+
+ /// Sets a result that applies to all rows.
+ ///
+ /// This is an optimization for cases where all rows evaluate to the same
result.
+ /// When a complete result is set, the builder will return it directly
from finish()
+ /// without any merging overhead.
+ fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
+ match &self.state {
+ Empty => {
+ self.state = Complete(value);
+ Ok(())
+ }
+ Partial { .. } => {
+ internal_err!(
+ "Cannot set a complete result when there are already
partial results"
+ )
+ }
+ Complete(_) => internal_err!("Complete result already set"),
+ }
+ }
+
+ /// Finishes building the result and returns the final array.
+ fn finish(self) -> Result<ColumnarValue> {
+ match self.state {
+ Empty => {
+ // No complete result and no partial results.
+ // This can happen for case expressions with no else branch
where no rows
+ // matched.
+ Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
+ &self.data_type,
+ )?))
+ }
+ Partial { arrays, indices } => {
+ // Merge partial results into a single array.
+ Ok(ColumnarValue::Array(merge(&arrays, &indices)?))
+ }
+ Complete(v) => {
+ // If we have a complete result, we can just return it.
+ Ok(v)
+ }
+ }
+ }
+}
+
impl CaseExpr {
/// Create a new CASE WHEN expression
pub fn try_new(
@@ -196,82 +612,146 @@ impl CaseExpr {
/// END
fn case_when_with_expr(&self, batch: &RecordBatch) ->
Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;
- let expr = self.expr.as_ref().unwrap();
- let base_value = expr.evaluate(batch)?;
- let base_value = base_value.into_array(batch.num_rows())?;
- let base_nulls = is_null(base_value.as_ref())?;
-
- // start with nulls as default output
- let mut current_value = new_null_array(&return_type, batch.num_rows());
- // We only consider non-null values while comparing with whens
- let mut remainder = not(&base_nulls)?;
- let mut non_null_remainder_count = remainder.true_count();
- for i in 0..self.when_then_expr.len() {
- // If there are no rows left to process, break out of the loop
early
- if non_null_remainder_count == 0 {
- break;
- }
+ let mut result_builder = ResultBuilder::new(&return_type,
batch.num_rows());
+
+ // `remainder_rows` contains the indices of the rows that need to be
evaluated
+ let mut remainder_rows: ArrayRef =
+ Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as
u32));
+ // `remainder_batch` contains the rows themselves that need to be
evaluated
+ let mut remainder_batch = Cow::Borrowed(batch);
+
+ // evaluate the base expression
+ let mut base_values = self
+ .expr
+ .as_ref()
+ .unwrap()
+ .evaluate(batch)?
+ .into_array(batch.num_rows())?;
- let when_predicate = &self.when_then_expr[i].0;
- let when_value = when_predicate.evaluate_selection(batch,
&remainder)?;
- let when_value = when_value.into_array(batch.num_rows())?;
- // build boolean array representing which rows match the "when"
value
- let when_match = compare_with_eq(
- &when_value,
- &base_value,
- // The types of case and when expressions will be coerced to
match.
- // We only need to check if the base_value is nested.
- base_value.data_type().is_nested(),
- )?;
- // Treat nulls as false
- let when_match = match when_match.null_count() {
- 0 => Cow::Borrowed(&when_match),
- _ => Cow::Owned(prep_null_mask_filter(&when_match)),
- };
- // Make sure we only consider rows that have not been matched yet
- let when_value = and(&when_match, &remainder)?;
+ // Fill in a result value already for rows where the base expression
value is null
+ // Since each when expression is tested against the base expression
using the equality
+ // operator, null base values can never match any when expression. `x
= NULL` is falsy,
+ // for all possible values of `x`.
+ if base_values.null_count() > 0 {
+ // Use `is_not_null` since this is a cheap clone of the null
buffer from 'base_value'.
+ // We already checked there are nulls, so we can be sure a new
buffer will not be
+ // created.
+ let base_not_nulls = is_not_null(base_values.as_ref())?;
+ let base_all_null = base_values.null_count() ==
remainder_batch.num_rows();
+
+ // If there is an else expression, use that as the default value
for the null rows
+ // Otherwise the default `null` value from the result builder will
be used.
+ if let Some(e) = self.else_expr() {
+ let expr = try_cast(Arc::clone(e), &batch.schema(),
return_type.clone())?;
- // If the predicate did not match any rows, continue to the next
branch immediately
- let when_match_count = when_value.true_count();
- if when_match_count == 0 {
- continue;
+ if base_all_null {
+ // All base values were null, so no need to filter
+ let nulls_value = expr.evaluate(&remainder_batch)?;
+ result_builder.add_branch_result(&remainder_rows,
nulls_value)?;
+ } else {
+ // Filter out the null rows and evaluate the else
expression for those
+ let nulls_filter = create_filter(¬(&base_not_nulls)?);
+ let nulls_batch =
+ filter_record_batch(&remainder_batch, &nulls_filter)?;
+ let nulls_rows = filter_array(&remainder_rows,
&nulls_filter)?;
+ let nulls_value = expr.evaluate(&nulls_batch)?;
+ result_builder.add_branch_result(&nulls_rows,
nulls_value)?;
+ }
}
- let then_expression = &self.when_then_expr[i].1;
- let then_value = then_expression.evaluate_selection(batch,
&when_value)?;
+ // All base values are null, so we can return early
+ if base_all_null {
+ return result_builder.finish();
+ }
- current_value = match then_value {
- ColumnarValue::Scalar(ScalarValue::Null) => {
- nullif(current_value.as_ref(), &when_value)?
- }
- ColumnarValue::Scalar(then_value) => {
- zip(&when_value, &then_value.to_scalar()?, ¤t_value)?
+ // Remove the null rows from the remainder batch
+ let not_null_filter = create_filter(&base_not_nulls);
+ remainder_batch =
+ Cow::Owned(filter_record_batch(&remainder_batch,
¬_null_filter)?);
+ remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?;
+ base_values = filter_array(&base_values, ¬_null_filter)?;
+ }
+
+ // The types of case and when expressions will be coerced to match.
+ // We only need to check if the base_value is nested.
+ let base_value_is_nested = base_values.data_type().is_nested();
+
+ for i in 0..self.when_then_expr.len() {
+ // Evaluate the 'when' predicate for the remainder batch
+ // This results in a boolean array with the same length as the
remaining number of rows
+ let when_expr = &self.when_then_expr[i].0;
+ let when_value = match when_expr.evaluate(&remainder_batch)? {
+ ColumnarValue::Array(a) => {
+ compare_with_eq(&a, &base_values, base_value_is_nested)
}
- ColumnarValue::Array(then_value) => {
- zip(&when_value, &then_value, ¤t_value)?
+ ColumnarValue::Scalar(s) => {
+ let scalar = Scalar::new(s.to_array()?);
+ compare_with_eq(&scalar, &base_values,
base_value_is_nested)
}
- };
+ }?;
- remainder = and_not(&remainder, &when_value)?;
- non_null_remainder_count -= when_match_count;
- }
+ // `true_count` ignores `true` values where the validity bit is
not set, so there's
+ // no need to call `prep_null_mask_filter`.
+ let when_true_count = when_value.true_count();
- if let Some(e) = self.else_expr() {
- // null and unmatched tuples should be assigned else value
- remainder = or(&base_nulls, &remainder)?;
+ // If the 'when' predicate did not match any rows, continue to the
next branch immediately
+ if when_true_count == 0 {
+ continue;
+ }
- if remainder.true_count() > 0 {
- // keep `else_expr`'s data type and return type consistent
- let expr = try_cast(Arc::clone(e), &batch.schema(),
return_type.clone())?;
+ // If the 'when' predicate matched all remaining rows, there is no
need to filter
+ if when_true_count == remainder_batch.num_rows() {
+ let then_expression = &self.when_then_expr[i].1;
+ let then_value = then_expression.evaluate(&remainder_batch)?;
+ result_builder.add_branch_result(&remainder_rows, then_value)?;
+ return result_builder.finish();
+ }
+
+ // Filter the remainder batch based on the 'when' value
+ // This results in a batch containing only the rows that need to
be evaluated
+ // for the current branch
+ // Still no need to call `prep_null_mask_filter` since
`create_filter` will already do
+ // this unconditionally.
+ let then_filter = create_filter(&when_value);
+ let then_batch = filter_record_batch(&remainder_batch,
&then_filter)?;
+ let then_rows = filter_array(&remainder_rows, &then_filter)?;
- let else_ = expr
- .evaluate_selection(batch, &remainder)?
- .into_array(batch.num_rows())?;
- current_value = zip(&remainder, &else_, ¤t_value)?;
+ let then_expression = &self.when_then_expr[i].1;
+ let then_value = then_expression.evaluate(&then_batch)?;
+ result_builder.add_branch_result(&then_rows, then_value)?;
+
+ // If this is the last 'when' branch and there is no 'else'
expression, there's no
+ // point in calculating the remaining rows.
+ if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
+ return result_builder.finish();
}
+
+ // Prepare the next when branch (or the else branch)
+ let next_selection = match when_value.null_count() {
+ 0 => not(&when_value),
+ _ => {
+ // `prep_null_mask_filter` is required to ensure the not
operation treats nulls
+ // as false
+ not(&prep_null_mask_filter(&when_value))
+ }
+ }?;
+ let next_filter = create_filter(&next_selection);
+ remainder_batch =
+ Cow::Owned(filter_record_batch(&remainder_batch,
&next_filter)?);
+ remainder_rows = filter_array(&remainder_rows, &next_filter)?;
+ base_values = filter_array(&base_values, &next_filter)?;
+ }
+
+ // If we reached this point, some rows were left unmatched.
+ // Check if those need to be evaluated using the 'else' expression.
+ if let Some(e) = self.else_expr() {
+ // keep `else_expr`'s data type and return type consistent
+ let expr = try_cast(Arc::clone(e), &batch.schema(),
return_type.clone())?;
+ let else_value = expr.evaluate(&remainder_batch)?;
+ result_builder.add_branch_result(&remainder_rows, else_value)?;
}
- Ok(ColumnarValue::Array(current_value))
+ result_builder.finish()
}
/// This function evaluates the form of CASE where each WHEN expression is
a boolean
@@ -283,70 +763,86 @@ impl CaseExpr {
/// END
fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;
+ let mut result_builder = ResultBuilder::new(&return_type,
batch.num_rows());
- // start with nulls as default output
- let mut current_value = new_null_array(&return_type, batch.num_rows());
- let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
- let mut remainder_count = batch.num_rows();
- for i in 0..self.when_then_expr.len() {
- // If there are no rows left to process, break out of the loop
early
- if remainder_count == 0 {
- break;
- }
+ // `remainder_rows` contains the indices of the rows that need to be
evaluated
+ let mut remainder_rows: ArrayRef =
+ Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
+ // `remainder_batch` contains the rows themselves that need to be
evaluated
+ let mut remainder_batch = Cow::Borrowed(batch);
+ for i in 0..self.when_then_expr.len() {
+ // Evaluate the 'when' predicate for the remainder batch
+ // This results in a boolean array with the same length as the
remaining number of rows
let when_predicate = &self.when_then_expr[i].0;
- let when_value = when_predicate.evaluate_selection(batch,
&remainder)?;
- let when_value = when_value.into_array(batch.num_rows())?;
+ let when_value = when_predicate
+ .evaluate(&remainder_batch)?
+ .into_array(remainder_batch.num_rows())?;
let when_value = as_boolean_array(&when_value).map_err(|_| {
internal_datafusion_err!("WHEN expression did not return a
BooleanArray")
})?;
- // Treat 'NULL' as false value
- let when_value = match when_value.null_count() {
- 0 => Cow::Borrowed(when_value),
- _ => Cow::Owned(prep_null_mask_filter(when_value)),
- };
- // Make sure we only consider rows that have not been matched yet
- let when_value = and(&when_value, &remainder)?;
- // If the predicate did not match any rows, continue to the next
branch immediately
- let when_match_count = when_value.true_count();
- if when_match_count == 0 {
+ // `true_count` ignores `true` values where the validity bit is
not set, so there's
+ // no need to call `prep_null_mask_filter`.
+ let when_true_count = when_value.true_count();
+
+ // If the 'when' predicate did not match any rows, continue to the
next branch immediately
+ if when_true_count == 0 {
continue;
}
+ // If the 'when' predicate matched all remaining rows, there is no
need to filter
+ if when_true_count == remainder_batch.num_rows() {
+ let then_expression = &self.when_then_expr[i].1;
+ let then_value = then_expression.evaluate(&remainder_batch)?;
+ result_builder.add_branch_result(&remainder_rows, then_value)?;
+ return result_builder.finish();
+ }
+
+ // Filter the remainder batch based on the 'when' value
+ // This results in a batch containing only the rows that need to
be evaluated
+ // for the current branch
+ // Still no need to call `prep_null_mask_filter` since
`create_filter` will already do
+ // this unconditionally.
+ let then_filter = create_filter(when_value);
+ let then_batch = filter_record_batch(&remainder_batch,
&then_filter)?;
+ let then_rows = filter_array(&remainder_rows, &then_filter)?;
+
let then_expression = &self.when_then_expr[i].1;
- let then_value = then_expression.evaluate_selection(batch,
&when_value)?;
+ let then_value = then_expression.evaluate(&then_batch)?;
+ result_builder.add_branch_result(&then_rows, then_value)?;
- current_value = match then_value {
- ColumnarValue::Scalar(ScalarValue::Null) => {
- nullif(current_value.as_ref(), &when_value)?
- }
- ColumnarValue::Scalar(then_value) => {
- zip(&when_value, &then_value.to_scalar()?, ¤t_value)?
- }
- ColumnarValue::Array(then_value) => {
- zip(&when_value, &then_value, ¤t_value)?
- }
- };
+ // If this is the last 'when' branch and there is no 'else'
expression, there's no
+ // point in calculating the remaining rows.
+ if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
+ return result_builder.finish();
+ }
- // Succeed tuples should be filtered out for short-circuit
evaluation,
- // null values for the current when expr should be kept
- remainder = and_not(&remainder, &when_value)?;
- remainder_count -= when_match_count;
+ // Prepare the next when branch (or the else branch)
+ let next_selection = match when_value.null_count() {
+ 0 => not(when_value),
+ _ => {
+ // `prep_null_mask_filter` is required to ensure the not
operation treats nulls
+ // as false
+ not(&prep_null_mask_filter(when_value))
+ }
+ }?;
+ let next_filter = create_filter(&next_selection);
+ remainder_batch =
+ Cow::Owned(filter_record_batch(&remainder_batch,
&next_filter)?);
+ remainder_rows = filter_array(&remainder_rows, &next_filter)?;
}
+ // If we reached this point, some rows were left unmatched.
+ // Check if those need to be evaluated using the 'else' expression.
if let Some(e) = self.else_expr() {
- if remainder_count > 0 {
- // keep `else_expr`'s data type and return type consistent
- let expr = try_cast(Arc::clone(e), &batch.schema(),
return_type.clone())?;
- let else_ = expr
- .evaluate_selection(batch, &remainder)?
- .into_array(batch.num_rows())?;
- current_value = zip(&remainder, &else_, ¤t_value)?;
- }
+ // keep `else_expr`'s data type and return type consistent
+ let expr = try_cast(Arc::clone(e), &batch.schema(),
return_type.clone())?;
+ let else_value = expr.evaluate(&remainder_batch)?;
+ result_builder.add_branch_result(&remainder_rows, else_value)?;
}
- Ok(ColumnarValue::Array(current_value))
+ result_builder.finish()
}
/// This function evaluates the specialized case of:
@@ -587,7 +1083,7 @@ impl PhysicalExpr for CaseExpr {
}
}
- fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "CASE ")?;
if let Some(e) = &self.expr {
e.fmt_sql(f)?;
diff --git a/datafusion/sqllogictest/test_files/case.slt
b/datafusion/sqllogictest/test_files/case.slt
index 352300e753..4eaa87b0b5 100644
--- a/datafusion/sqllogictest/test_files/case.slt
+++ b/datafusion/sqllogictest/test_files/case.slt
@@ -595,3 +595,25 @@ SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2
END FROM (VALUES (NUL
----
2
2
+
+# The `WHEN 1/0` is not effectively reachable in this query and should never
be executed
+query T
+SELECT CASE a WHEN 1 THEN 'a' WHEN 2 THEN 'b' WHEN 1 / 0 THEN 'c' ELSE 'd' END
FROM (VALUES (1), (2)) t(a)
+----
+a
+b
+
+# The `WHEN 1/0` is not effectively reachable in this query and should never
be executed
+query T
+SELECT CASE WHEN a = 1 THEN 'a' WHEN a = 2 THEN 'b' WHEN a = 1 / 0 THEN 'c'
ELSE 'd' END FROM (VALUES (1), (2)) t(a)
+----
+a
+b
+
+# The `WHEN 1/0` is not effectively reachable in this query and should never
be executed
+query T
+SELECT CASE WHEN a = 0 THEN 'a' WHEN 1 / a = 1 THEN 'b' ELSE 'c' END FROM
(VALUES (0), (1), (2)) t(a)
+----
+a
+b
+c
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]