alamb commented on code in PR #3975:
URL: https://github.com/apache/arrow-datafusion/pull/3975#discussion_r1006872108


##########
datafusion/physical-expr/src/expressions/in_list.rs:
##########
@@ -714,193 +912,86 @@ impl PhysicalExpr for InListExpr {
             };
 
             match value_data_type {
-                DataType::Float32 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Float32,
-                        Float32Array
-                    )
-                }
-                DataType::Float64 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Float64,
-                        Float64Array
-                    )
-                }
-                DataType::Int16 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Int16,
-                        Int16Array
-                    )
-                }
-                DataType::Int32 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Int32,
-                        Int32Array
-                    )
-                }
-                DataType::Int64 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Int64,
-                        Int64Array
-                    )
-                }
-                DataType::Int8 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Int8,
-                        Int8Array
-                    )
-                }
-                DataType::UInt16 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        UInt16,
-                        UInt16Array
-                    )
-                }
-                DataType::UInt32 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        UInt32,
-                        UInt32Array
-                    )
-                }
-                DataType::UInt64 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        UInt64,
-                        UInt64Array
-                    )
-                }
-                DataType::UInt8 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        UInt8,
-                        UInt8Array
-                    )
-                }
-                DataType::Date32 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Date32,
-                        Date32Array
-                    )
-                }
-                DataType::Date64 => {
-                    make_contains_primitive!(
-                        array,
-                        list_values,
-                        self.negated,
-                        Date64,
-                        Date64Array
-                    )
-                }
-                DataType::Boolean => Ok(make_contains!(
-                    array,
-                    list_values,
-                    self.negated,
-                    Boolean,
-                    BooleanArray
-                )),
-                DataType::Utf8 => {
-                    self.compare_utf8::<i32>(array, list_values, self.negated)
-                }
-                DataType::LargeUtf8 => {
-                    self.compare_utf8::<i64>(array, list_values, self.negated)
-                }
-                DataType::Binary => {
-                    self.compare_binary::<i32>(array, list_values, 
self.negated)
-                }
-                DataType::LargeBinary => {
-                    self.compare_binary::<i64>(array, list_values, 
self.negated)
-                }
-                DataType::Null => {
-                    let null_array = new_null_array(&DataType::Boolean, 
array.len());
-                    Ok(ColumnarValue::Array(Arc::new(null_array)))
-                }
-                DataType::Decimal128(_, _) => {
-                    let decimal_array =
-                        
array.as_any().downcast_ref::<Decimal128Array>().unwrap();
-                    Ok(make_list_contains_decimal(
-                        decimal_array,
-                        list_values,
-                        self.negated,
-                    ))
+                DataType::Dictionary(_key_type, value_type) => {
+                    // 1
+                    let dict_array = array
+                        .as_any()
+                        .downcast_ref::<DictionaryArray<Int32Type>>()
+                        .unwrap();
+                    let keys = dict_array.keys();
+
+                    let values_result = evaluate_set(&array, 
list_values).unwrap();
+                    let values = compute::take(values_result.as_ref(), keys, 
None)?;
+                    Ok(ColumnarValue::Array(values))
+
+                    // 2
+                    // let dict_array = 
array.as_any().downcast_ref::<DictionaryArray<Int32Type>>().unwrap();
+                    // let keys = dict_array.keys();
+
+                    // let values_result = 
dict_to_val_array::<Int32Type>(&array).unwrap();
+                    // self.evaluate_non_dict(*value_type, values_result, 
list_values)
                 }
-                DataType::Timestamp(unit, _) => match unit {
-                    TimeUnit::Second => {
-                        make_contains_primitive!(
-                            array,
-                            list_values,
-                            self.negated,
-                            TimestampSecond,
-                            TimestampSecondArray
-                        )
-                    }
-                    TimeUnit::Millisecond => {
-                        make_contains_primitive!(
-                            array,
-                            list_values,
-                            self.negated,
-                            TimestampMillisecond,
-                            TimestampMillisecondArray
-                        )
-                    }
-                    TimeUnit::Microsecond => {
-                        make_contains_primitive!(
-                            array,
-                            list_values,
-                            self.negated,
-                            TimestampMicrosecond,
-                            TimestampMicrosecondArray
-                        )
-                    }
-                    TimeUnit::Nanosecond => {
-                        make_contains_primitive!(
-                            array,
-                            list_values,
-                            self.negated,
-                            TimestampNanosecond,
-                            TimestampNanosecondArray
-                        )
-                    }
-                },
-                datatype => 
Result::Err(DataFusionError::NotImplemented(format!(
-                    "InList does not support datatype {:?}.",
-                    datatype
-                ))),
+                _ => self.evaluate_non_dict(value_data_type, array, 
list_values),
             }
         }
     }
 }
 
+/// Expands a dictionary-encoded array into an array of its decoded values,
+/// one output row per input row, by looking up each row's dictionary entry
+/// and collecting the resulting `ScalarValue`s with `ScalarValue::iter_to_array`.
+///
+/// Rows whose key is null become a null `ScalarValue` of the dictionary's
+/// value type.
+///
+/// NOTE(review): both `unwrap()` calls panic rather than returning `Err`:
+/// the downcast panics if `array` is not a `DictionaryArray<K>`, and
+/// `iter_to_array` panics if it fails. Consider propagating these as errors
+/// with `?` instead.
+/// NOTE(review): this converts row-at-a-time through `ScalarValue`;
+/// `compute::take(values, keys, None)` (as used in `evaluate`) would likely
+/// produce the same result vectorized. TODO confirm.
+fn dict_to_val_array<K: ArrowDictionaryKeyType>(
+    array: &Arc<dyn Array>,
+) -> Result<Arc<dyn Array>> {
+    let dict_array = 
array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+    let mut dict_vals = Vec::with_capacity(dict_array.len());
+    for i in 0..array.len() {
+        // presumably yields the dictionary values array plus the key for row i
+        // (None when the key is null). Confirm against `get_dict_value`.
+        let (values_array, values_index) = get_dict_value::<K>(&array, i);
+        // Look up value from Index
+        let value = match values_index {
+            Some(values_index) => ScalarValue::try_from_array(values_array, 
values_index),
+            // Entry was null, so return null
+            None => values_array.data_type().try_into(),
+        }?;
+        dict_vals.push(value);
+    }
+    let vals = ScalarValue::iter_to_array(dict_vals).unwrap();
+    Ok(vals)
+}
+
+// Return a boolean array indicating whether the value is in list_values
+fn evaluate_set(

Review Comment:
   Yeah, I was suggesting that, since InList is already implemented for 
non-dictionary types, we reuse that implementation (though I think that will 
require some restructuring of how `evaluate` is written).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to