pepijnve commented on code in PR #18183:
URL: https://github.com/apache/datafusion/pull/18183#discussion_r2541841784


##########
datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs:
##########
@@ -0,0 +1,402 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod boolean_lookup_table;
+mod bytes_like_lookup_table;
+mod primitive_lookup_table;
+
+use 
crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap;
+use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::{
+    BytesDictionaryHelper, BytesLikeIndexMap, BytesViewDictionaryHelper,
+    FixedBinaryHelper, FixedBytesDictionaryHelper, GenericBytesHelper,
+    GenericBytesViewHelper,
+};
+use 
crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap;
+use crate::expressions::case::CaseBody;
+use crate::expressions::Literal;
+use arrow::array::{downcast_integer, downcast_primitive, ArrayRef, 
UInt32Array};
+use arrow::datatypes::{
+    ArrowDictionaryKeyType, BinaryViewType, DataType, GenericBinaryType,
+    GenericStringType, StringViewType,
+};
+use datafusion_common::DataFusionError;
+use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, 
ScalarValue};
+use indexmap::IndexMap;
+use std::fmt::Debug;
+
+/// Optimization for CASE expressions with literal WHEN and THEN clauses
+///
+/// for this form:
+/// ```sql
+/// CASE <expr_a>
+///     WHEN <literal_a> THEN <literal_e>
+///     WHEN <literal_b> THEN <literal_f>
+///     WHEN <literal_c> THEN <literal_g>
+///     WHEN <literal_d> THEN <literal_h>
+///     ELSE <optional-fallback_literal>
+/// END
+/// ```
+///
+/// # Improvement idea
+/// TODO - consider unwrapping `IN` expressions into multiple equality comparisons
+/// so that they can use this optimization as well, e.g.
+/// ```sql
+/// -- Before
+/// CASE
+///     WHEN (<expr_a> = <literal_a>) THEN <literal_e>
+///     WHEN (<expr_a> in (<literal_b>, <literal_c>)) THEN <literal_f>
+///     WHEN (<expr_a> = <literal_d>) THEN <literal_g>
+///     ELSE <optional-fallback_literal>
+/// END
+///
+/// -- After
+/// CASE
+///     WHEN (<expr_a> = <literal_a>) THEN <literal_e>
+///     WHEN (<expr_a> = <literal_b>) THEN <literal_f>
+///     WHEN (<expr_a> = <literal_c>) THEN <literal_f>
+///     WHEN (<expr_a> = <literal_d>) THEN <literal_g>
+///     ELSE <optional-fallback_literal>
+/// END
+/// ```
+///
+#[derive(Debug)]
+pub(in super::super) struct LiteralLookupTable {
+    /// The lookup table to use for evaluating the CASE expression
+    lookup: Box<dyn WhenLiteralIndexMap>,
+
+    else_index: u32,
+
+    /// [`ArrayRef`] where `array[i] = then_literals[i]`;
+    /// the last value in the array is the ELSE value.
+    ///
+    /// The final output is built by taking from this array using the
+    /// indices returned by the lookup table.
+    then_and_else_values: ArrayRef,
+}
+
+impl LiteralLookupTable {
+    pub(in super::super) fn maybe_new(body: &CaseBody) -> Option<Self> {
+        // We can't use the optimization if we don't have any when then pairs
+        if body.when_then_expr.is_empty() {
+            return None;
+        }
+
+        // If there is only one WHEN/THEN pair, this optimization is not useful
+        if body.when_then_expr.len() == 1 {
+            return None;
+        }
+
+        // Try to downcast all the WHEN/THEN expressions to literals
+        let when_then_exprs_maybe_literals = body
+            .when_then_expr
+            .iter()
+            .map(|(when, then)| {
+                let when_maybe_literal = 
when.as_any().downcast_ref::<Literal>();
+                let then_maybe_literal = 
then.as_any().downcast_ref::<Literal>();
+
+                when_maybe_literal.zip(then_maybe_literal)
+            })
+            .collect::<Vec<_>>();
+
+        // If not all the WHEN/THEN expressions are literals,
+        // we cannot use this optimization
+        if when_then_exprs_maybe_literals.contains(&None) {
+            return None;
+        }
+
+        let when_then_exprs_scalars = when_then_exprs_maybe_literals
+            .into_iter()
+            // Unwrap the options as we have already checked there is no None
+            .flatten()
+            .map(|(when_lit, then_lit)| {
+                (when_lit.value().clone(), then_lit.value().clone())
+            })
+            // Drop NULL WHEN literals, keeping only the non-null ones: a NULL can never
+            // be matched - `CASE NULL WHEN NULL THEN ... ELSE ... END` always goes to ELSE
+            .filter(|(when_lit, _)| !when_lit.is_null())
+            .collect::<Vec<_>>();
+
+        if when_then_exprs_scalars.is_empty() {
+            // All WHEN literals were nulls, so cannot use optimization
+            //
+            // instead, another optimization would be to go straight to the 
ELSE clause
+            return None;
+        }
+
+        // Keep only the first occurrence of each WHEN literal (as the first match is used).
+        // NULL WHEN literals were already filtered out above, since they can never be matched.
+        let (when, then): (Vec<ScalarValue>, Vec<ScalarValue>) = {
+            let mut map = IndexMap::with_capacity(body.when_then_expr.len());
+
+            for (when, then) in when_then_exprs_scalars.into_iter() {
+                // Don't overwrite existing entries as we want to keep the 
first occurrence
+                if !map.contains_key(&when) {
+                    map.insert(when, then);
+                }
+            }
+
+            map.into_iter().unzip()
+        };
+
+        let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr 
{
+            let literal = else_expr.as_any().downcast_ref::<Literal>()?;
+
+            literal.value().clone()
+        } else {
+            let Ok(null_scalar) = 
ScalarValue::try_new_null(&then[0].data_type()) else {
+                return None;
+            };
+
+            null_scalar
+        };
+
+        {
+            let data_type = when[0].data_type();
+
+            // If not all the WHEN literals are the same data type we cannot 
use this optimization
+            if when.iter().any(|l| l.data_type() != data_type) {
+                return None;
+            }
+        }
+
+        {
+            let data_type = then[0].data_type();
+
+            // If not all the then and the else literals are the same data 
type we cannot use this optimization
+            if then.iter().any(|l| l.data_type() != data_type) {
+                return None;
+            }
+
+            if else_value.data_type() != data_type {
+                return None;
+            }
+        }
+
+        let then_and_else_values = ScalarValue::iter_to_array(
+            then.iter()
+                // The else is in the end
+                .chain(std::iter::once(&else_value))
+                .cloned(),
+        )
+        .ok()?;
+        // The else expression is in the end
+        let else_index = then_and_else_values.len() as u32 - 1;
+
+        let lookup = try_creating_lookup_table(when).ok()?;
+
+        Some(Self {
+            lookup,
+            then_and_else_values,
+            else_index,
+        })
+    }
+
+    pub(in super::super) fn map_input_to_output(
+        &self,
+        expr_array: &ArrayRef,

Review Comment:
   Ok, I understand where you're coming from. For the lookup tables, I was 
mostly trying to understand them as generic associative arrays, without taking 
the usage context into account. Viewed that way, the name `expr_array` refers 
to context that isn't really available at this point in the code. 
Sorry if I'm being too nitpicky about this stuff.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to