tustvold commented on code in PR #3975:
URL: https://github.com/apache/arrow-datafusion/pull/3975#discussion_r1006426603
##########
datafusion/physical-expr/src/expressions/in_list.rs:
##########
@@ -714,193 +912,86 @@ impl PhysicalExpr for InListExpr {
};
match value_data_type {
- DataType::Float32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float32,
- Float32Array
- )
- }
- DataType::Float64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float64,
- Float64Array
- )
- }
- DataType::Int16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int16,
- Int16Array
- )
- }
- DataType::Int32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int32,
- Int32Array
- )
- }
- DataType::Int64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int64,
- Int64Array
- )
- }
- DataType::Int8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int8,
- Int8Array
- )
- }
- DataType::UInt16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt16,
- UInt16Array
- )
- }
- DataType::UInt32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt32,
- UInt32Array
- )
- }
- DataType::UInt64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt64,
- UInt64Array
- )
- }
- DataType::UInt8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt8,
- UInt8Array
- )
- }
- DataType::Date32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Date32,
- Date32Array
- )
- }
- DataType::Date64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Date64,
- Date64Array
- )
- }
- DataType::Boolean => Ok(make_contains!(
- array,
- list_values,
- self.negated,
- Boolean,
- BooleanArray
- )),
- DataType::Utf8 => {
- self.compare_utf8::<i32>(array, list_values, self.negated)
- }
- DataType::LargeUtf8 => {
- self.compare_utf8::<i64>(array, list_values, self.negated)
- }
- DataType::Binary => {
- self.compare_binary::<i32>(array, list_values,
self.negated)
- }
- DataType::LargeBinary => {
- self.compare_binary::<i64>(array, list_values,
self.negated)
- }
- DataType::Null => {
- let null_array = new_null_array(&DataType::Boolean,
array.len());
- Ok(ColumnarValue::Array(Arc::new(null_array)))
- }
- DataType::Decimal128(_, _) => {
- let decimal_array =
-
array.as_any().downcast_ref::<Decimal128Array>().unwrap();
- Ok(make_list_contains_decimal(
- decimal_array,
- list_values,
- self.negated,
- ))
+ DataType::Dictionary(_key_type, value_type) => {
+ // 1
+ let dict_array = array
+ .as_any()
+ .downcast_ref::<DictionaryArray<Int32Type>>()
+ .unwrap();
+ let keys = dict_array.keys();
+
+ let values_result = evaluate_set(&array,
list_values).unwrap();
+ let values = compute::take(values_result.as_ref(), keys,
None)?;
+ Ok(ColumnarValue::Array(values))
+
+ // 2
+ // let dict_array =
array.as_any().downcast_ref::<DictionaryArray<Int32Type>>().unwrap();
+ // let keys = dict_array.keys();
+
+ // let values_result =
dict_to_val_array::<Int32Type>(&array).unwrap();
+ // self.evaluate_non_dict(*value_type, values_result,
list_values)
}
- DataType::Timestamp(unit, _) => match unit {
- TimeUnit::Second => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampSecond,
- TimestampSecondArray
- )
- }
- TimeUnit::Millisecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampMillisecond,
- TimestampMillisecondArray
- )
- }
- TimeUnit::Microsecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampMicrosecond,
- TimestampMicrosecondArray
- )
- }
- TimeUnit::Nanosecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampNanosecond,
- TimestampNanosecondArray
- )
- }
- },
- datatype =>
Result::Err(DataFusionError::NotImplemented(format!(
- "InList does not support datatype {:?}.",
- datatype
- ))),
+ _ => self.evaluate_non_dict(value_data_type, array,
list_values),
}
}
}
}
+fn dict_to_val_array<K: ArrowDictionaryKeyType>(
+ array: &Arc<dyn Array>,
+) -> Result<Arc<dyn Array>> {
+ let dict_array =
array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+ let mut dict_vals = Vec::with_capacity(dict_array.len());
+ for i in 0..array.len() {
+ let (values_array, values_index) = get_dict_value::<K>(&array, i);
+ // Look up value from Index
+ let value = match values_index {
+ Some(values_index) => ScalarValue::try_from_array(values_array,
values_index),
+ // Entry was null, so return null
+ None => values_array.data_type().try_into(),
+ }?;
+ dict_vals.push(value);
+ }
+ let vals = ScalarValue::iter_to_array(dict_vals).unwrap();
+ Ok(vals)
+}
+
+// Return a boolean array indicating whether the value is in list_values
+fn evaluate_set(
Review Comment:
Why is this a new function? I think @alamb's suggestion was to recurse into
`InListExpr::evaluate` (by pulling it into a free function).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]