rluvaton commented on code in PR #18183: URL: https://github.com/apache/datafusion/pull/18183#discussion_r2541633473
########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs: ########## @@ -0,0 +1,402 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::{ + BytesDictionaryHelper, BytesLikeIndexMap, BytesViewDictionaryHelper, + FixedBinaryHelper, FixedBytesDictionaryHelper, GenericBytesHelper, + GenericBytesViewHelper, +}; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap; +use crate::expressions::case::CaseBody; +use crate::expressions::Literal; +use arrow::array::{downcast_integer, downcast_primitive, ArrayRef, UInt32Array}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, BinaryViewType, DataType, GenericBinaryType, + GenericStringType, StringViewType, +}; +use datafusion_common::DataFusionError; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use indexmap::IndexMap; +use std::fmt::Debug; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// 
+/// for this form: +/// ```sql +/// CASE <expr_a> +/// WHEN <literal_a> THEN <literal_e> +/// WHEN <literal_b> THEN <literal_f> +/// WHEN <literal_c> THEN <literal_g> +/// WHEN <literal_d> THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. +/// ```sql +/// -- Before +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> in (<literal_b>, <literal_c>)) THEN <literal_f> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_g> +/// ELSE <optional-fallback_literal> +/// END +/// +/// -- After +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> = <literal_b>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_c>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_g> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Box<dyn WhenLiteralIndexMap>, + + else_index: u32, + + /// [`ArrayRef`] where `array[i] = then_literals[i]` + /// the last value in the array is the else_expr + /// + /// This will be used to take from based on the indices returned by the lookup table to build the final output + then_and_else_values: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new(body: &CaseBody) -> Option<Self> { + // We can't use the optimization if we don't have any when then pairs + if body.when_then_expr.is_empty() { + return None; + } + + // If we only have 1 then this optimization is not useful + if body.when_then_expr.len() == 1 { + return None; + } + + // Try to downcast all the WHEN/THEN expressions to literals + let when_then_exprs_maybe_literals = body + .when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = 
when.as_any().downcast_ref::<Literal>(); + let then_maybe_literal = then.as_any().downcast_ref::<Literal>(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::<Vec<_>>(); + + // If not all the WHEN/THEN expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have already checked there is no None + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + // Only keep non-null WHEN literals + // as null literals cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) + .collect::<Vec<_>>(); + + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... 
END always goes to ELSE) + let (when, then): (Vec<ScalarValue>, Vec<ScalarValue>) = { + let mut map = IndexMap::with_capacity(body.when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries as we want to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; + + let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr { + let literal = else_expr.as_any().downcast_ref::<Literal>()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then[0].data_type()) else { + return None; + }; + + null_scalar + }; + + { + let data_type = when[0].data_type(); + + // If not all the WHEN literals are the same data type we cannot use this optimization + if when.iter().any(|l| l.data_type() != data_type) { + return None; + } + } + + { + let data_type = then[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_value.data_type() != data_type { + return None; + } + } + + let then_and_else_values = ScalarValue::iter_to_array( + then.iter() + // The else is at the end + .chain(std::iter::once(&else_value)) + .cloned(), + ) + .ok()?; + // The else expression is at the end + let else_index = then_and_else_values.len() as u32 - 1; + + let lookup = try_creating_lookup_table(when).ok()?; + + Some(Self { + lookup, + then_and_else_values, + else_index, + }) + } + + pub(in super::super) fn map_input_to_output( + &self, + expr_array: &ArrayRef, + ) -> datafusion_common::Result<ArrayRef> { + let take_indices = self.lookup.map_to_indices(expr_array, self.else_index)?; + + // Zero-copy conversion + let take_indices = UInt32Array::from(take_indices); + + // An optimized version would depend on the type of the values_to_take_from + // For example, if the type is view we 
can just keep pointing to the same value (similar to dictionary) + // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array + let output = + arrow::compute::take(&self.then_and_else_values, &take_indices, None) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(output) + } +} + +/// Lookup table for mapping literal values to their corresponding indices in the THEN clauses Review Comment: updated -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
