rluvaton commented on code in PR #18152:
URL: https://github.com/apache/datafusion/pull/18152#discussion_r2466170902
##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +125,384 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>)
-> bool {
expr.as_any().is::<Column>()
}
+/// Builds a reusable [FilterPredicate] from a boolean selection mask.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+    // The predicate is applied several times, so it is always optimized
+    // before being built.
+    FilterBuilder::new(predicate).optimize().build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+ record_batch: &RecordBatch,
+ filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+ let filtered_columns = record_batch
+ .columns()
+ .iter()
+ .map(|a| filter_array(a, filter))
+ .collect::<std::result::Result<Vec<_>, _>>()?;
+ // SAFETY: since we start from a valid RecordBatch, there's no need to
revalidate the schema
+ // since the set of columns has not changed.
+ // The input column arrays all had the same length (since they're coming
from a valid RecordBatch)
+ // and the filtering them with the same filter will produces a new set of
arrays with identical
+ // lengths.
+ unsafe {
+ Ok(RecordBatch::new_unchecked(
+ record_batch.schema(),
+ filtered_columns,
+ filter.count(),
+ ))
+ }
+}
+
+// Thin free-function wrapper whose only purpose is to give `filter_array`
+// the same call style as `filter_record_batch` at the point of use.
+// When https://github.com/apache/arrow-rs/pull/8693 is available, replace
+// both with method calls on `FilterPredicate`.
+#[inline(always)]
+fn filter_array(
+    array: &dyn Array,
+    filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    FilterPredicate::filter(filter, array)
+}
+
+/// Merges elements by index from a list of [`ArrayData`], creating a new
[`ColumnarValue`] from
+/// those values.
+///
+/// Each element in `indices` is the index of an array in `values`. The
`indices` array is processed
+/// sequentially. The first occurrence of index value `n` will be mapped to
the first
+/// value of the array at index `n`. The second occurrence to the second
value, and so on.
+/// An index value of `usize::MAX` is used to indicate null values.
+///
+/// # Implementation notes
+///
+/// This algorithm is similar in nature to both `zip` and `interleave`, but
there are some important
+/// differences.
+///
+/// In contrast to `zip`, this function supports multiple input arrays.
Instead of a boolean
+/// selection vector, an index array is to take values from the input arrays,
and a special marker
+/// value is used to indicate null values.
+///
+/// In contrast to `interleave`, this function does not use pairs of indices.
The values in
+/// `indices` serve the same purpose as the first value in the pairs passed to
`interleave`.
+/// The index in the array is implicit and is derived from the number of times
a particular array
+/// index occurs.
+/// The more constrained indexing mechanism used by this algorithm makes it
easier to copy values
+/// in contiguous slices. In the example below, the two subsequent elements
from array `2` can be
+/// copied in a single operation from the source array instead of copying them
one by one.
+/// Long spans of null values are also especially cheap because they do not
need to be represented
+/// in an input array.
+///
+/// # Safety
+///
+/// This function does not check that the number of occurrences of any
particular array index matches
+/// the length of the corresponding input array. If an array contains more
values than required, the
+/// spurious values will be ignored. If an array contains fewer values than
necessary, this function
+/// will panic.
+///
+/// # Example
+///
+/// ```text
+/// ┌───────────┐ ┌─────────┐ ┌─────────┐
+/// │┌─────────┐│ │ MAX │ │ NULL │
+/// ││ A ││ ├─────────┤ ├─────────┤
+/// │└─────────┘│ │ 1 │ │ B │
+/// │┌─────────┐│ ├─────────┤ ├─────────┤
+/// ││ B ││ │ 0 │ merge(values, indices) │ A │
+/// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤
+/// │┌─────────┐│ │ MAX │ │ NULL │
+/// ││ C ││ ├─────────┤ ├─────────┤
+/// │├─────────┤│ │ 2 │ │ C │
+/// ││ D ││ ├─────────┤ ├─────────┤
+/// │└─────────┘│ │ 2 │ │ D │
+/// └───────────┘ └─────────┘ └─────────┘
+/// values indices result
+///
+/// ```
+fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) ->
Result<ArrayRef> {
+ let data_refs = values.iter().collect();
+ let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
+
+ // This loop extends the mutable array by taking slices from the partial
results.
+ //
+ // take_offsets keeps track of how many values have been taken from each
array.
+ let mut take_offsets = vec![0; values.len() + 1];
+ let mut start_row_ix = 0;
+ loop {
+ let array_ix = indices[start_row_ix];
+
+ // Determine the length of the slice to take.
+ let mut end_row_ix = start_row_ix + 1;
+ while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
+ end_row_ix += 1;
+ }
+ let slice_length = end_row_ix - start_row_ix;
+
+ // Extend mutable with either nulls or with values from the array.
+ match array_ix.index() {
+ None => mutable.extend_nulls(slice_length),
+ Some(index) => {
+ let start_offset = take_offsets[index];
+ let end_offset = start_offset + slice_length;
+ mutable.extend(index, start_offset, end_offset);
+ take_offsets[index] = end_offset;
+ }
+ }
+
+ if end_row_ix == indices.len() {
+ break;
+ } else {
+ // Set the start_row_ix for the next slice.
+ start_row_ix = end_row_ix;
+ }
+ }
+
+ Ok(make_array(mutable.freeze()))
+}
+
+/// An index into the partial results array that's more compact than `usize`.
+///
+/// `u32::MAX` is reserved as a special 'none' value. This is used instead of
+/// `Option` to keep the array of indices as compact as possible.
+#[derive(Copy, Clone, PartialEq, Eq)]
+struct PartialResultIndex {
+ /// Raw index value; equal to `NONE_VALUE` for the 'none' placeholder.
+ index: u32,
+}
+
+/// Raw value that `PartialResultIndex` stores to represent 'none'.
+const NONE_VALUE: u32 = u32::MAX;
+
+impl PartialResultIndex {
+ /// Returns the 'none' placeholder value.
+ fn none() -> Self {
+ Self { index: NONE_VALUE }
+ }
+
+ /// Creates a new partial result index.
+ ///
+ /// If the provide value is greater than or equal to `u32::MAX`
+ /// an error will be returned.
+ fn try_new(index: usize) -> Result<Self> {
+ let Ok(index) = u32::try_from(index) else {
+ return internal_err!("Partial result index exceeds limit");
+ };
+
+ if index == NONE_VALUE {
+ return internal_err!("Partial result index exceeds limit");
+ }
+
+ Ok(Self { index })
+ }
+
+ /// Determines if this index is the 'none' placeholder value or not.
+ fn is_none(&self) -> bool {
+ self.index == NONE_VALUE
+ }
+
+ /// Returns `Some(index)` if this value is not the 'none' placeholder,
`None` otherwise.
+ fn index(&self) -> Option<usize> {
+ if self.is_none() {
+ None
+ } else {
+ Some(self.index as usize)
+ }
+ }
+}
+
+impl Debug for PartialResultIndex {
+    /// Writes `null` for the 'none' placeholder and the numeric index otherwise.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        if !self.is_none() {
+            return write!(f, "{}", self.index);
+        }
+        write!(f, "null")
+    }
+}
+
+enum ResultState {
+ /// The final result needs to be computed by merging the the data in
`arrays`.
+ Partial {
+ // A Vec of partial results that should be merged.
`partial_result_indices` contains
+ // indexes into this vec.
+ arrays: Vec<ArrayData>,
+ // Indicates per result row from which array in `partial_results` a
value should be taken.
+ indices: Vec<PartialResultIndex>,
+ },
+ /// A single branch matched all input rows. When creating the final
result, no further merging
+ /// of partial results is necessary.
+ Complete(ColumnarValue),
+}
+
+/// A builder for constructing result arrays for CASE expressions.
+///
+/// Rather than building a monolithic array containing all results, it maintains a set of
+/// partial result arrays and a mapping that indicates for each row which partial array
+/// contains the result value for that row.
+///
+/// On finish(), the builder will merge all partial results into a single array if
+/// necessary. If all rows evaluated to the same array, that array can be returned
+/// directly without any merging overhead.
+struct ResultBuilder {
+    /// Data type of the arrays this builder produces.
+    data_type: DataType,
+    /// Number of rows in the final result.
+    row_count: usize,
+    /// The partial results gathered so far, or the single complete result.
+    state: ResultState,
+}
+
+impl ResultBuilder {
+    /// Constructs a ResultBuilder that will produce arrays of the given data type.
+    ///
+    /// `row_count` is the number of rows in the final result. Every row starts
+    /// out mapped to the 'none' index and no partial arrays are registered yet.
+    fn new(data_type: &DataType, row_count: usize) -> Self {
+        let indices = vec![PartialResultIndex::none(); row_count];
+        let state = Partial {
+            arrays: vec![],
+            indices,
+        };
+        Self {
+            data_type: data_type.clone(),
+            row_count,
+            state,
+        }
+    }
+
+ /// Adds a result for one branch of the case expression.
+ ///
+ /// `row_indices` should be a [UInt32Array] containing [RecordBatch]
relative row indices
+ /// for which `value` contains result values.
+ ///
+ /// If `value` is a scalar, the scalar value will be used as the value for
each row in `row_indices`.
+ ///
+ /// If `value` is an array, the values from the array and the indices from
`row_indices` will be
+ /// processed pairwise. The lengths of `value` and `row_indices` must
match.
+ ///
+ /// The diagram below shows a situation where a when expression matched
rows 1 and 4 of the
+ /// record batch. The then expression produced the value array `[A, D]`.
+ /// After adding this result, the result array will have been added to
`Partial::arrays` and
+ /// `Partial::indices` will have been updated at indexes 1 and 4.
+ ///
+ /// ```text
+ /// ┌─────────┐ ┌─────────┐┌───────────┐
┌─────────┐┌───────────┐
+ /// │ C │ │ MAX ││┌─────────┐│ │
MAX ││┌─────────┐│
+ /// ├─────────┤ ├─────────┤││ A ││
├─────────┤││ A ││
+ /// │ D │ │ MAX ││└─────────┘│ │
2 ││└─────────┘│
+ /// └─────────┘ ├─────────┤│┌─────────┐│ add_branch_result(
├─────────┤│┌─────────┐│
+ /// value │ 0 │││ B ││ row indices, │
0 │││ B ││
+ /// ├─────────┤│└─────────┘│ value
├─────────┤│└─────────┘│
+ /// │ MAX ││ │ ) │
MAX ││┌─────────┐│
+ /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶
├─────────┤││ C ││
+ /// │ 1 │ │ MAX ││ │ │
2 ││├─────────┤│
+ /// ├─────────┤ ├─────────┤│ │
├─────────┤││ D ││
+ /// │ 4 │ │ 1 ││ │ │
1 ││└─────────┘│
+ /// └─────────┘ └─────────┘└───────────┘
└─────────┘└───────────┘
+ /// row indices
+ /// partial partial
partial partial
+ /// indices arrays
indices arrays
+ /// ```
+ fn add_branch_result(
+ &mut self,
+ row_indices: &ArrayRef,
+ value: ColumnarValue,
+ ) -> Result<()> {
+ match value {
+ ColumnarValue::Array(a) => {
+ if a.len() != row_indices.len() {
+ internal_err!("Array length must match row indices length")
+ } else if row_indices.len() == self.row_count {
+ self.set_single_result(ColumnarValue::Array(a))
+ } else {
+ self.add_partial_result(row_indices, a.to_data())
+ }
+ }
+ ColumnarValue::Scalar(s) => {
+ if row_indices.len() == self.row_count {
+ self.set_single_result(ColumnarValue::Scalar(s))
+ } else {
+ self.add_partial_result(
+ row_indices,
+ s.to_array_of_size(row_indices.len())?.to_data(),
+ )
+ }
+ }
+ }
+ }
+
+ /// Adds a partial result array.
+ ///
+ /// This method adds the given array data as a partial result and updates
the index mapping
+ /// to indicate that the specified rows should take their values from this
array.
+ /// The partial results will be merged into a single array when finish()
is called.
+ fn add_partial_result(
+ &mut self,
+ row_indices: &ArrayRef,
+ row_values: ArrayData,
+ ) -> Result<()> {
+ match &mut self.state {
+ Partial { arrays, indices } => {
+ // This check is only active in debug builds because the callers of this
+ // method, `case_when_with_expr` and `case_when_no_expr`, already ensure
+ // that they only calculate a value for each row at most once.
+ #[cfg(debug_assertions)]
+ for row_ix in
row_indices.as_primitive::<UInt32Type>().values().iter() {
Review Comment:
Can you assert unconditionally (not only under `debug_assertions`) that
`row_indices` does not contain any nulls? The code here assumes that no nulls exist.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]