tustvold commented on code in PR #3975:
URL: https://github.com/apache/arrow-datafusion/pull/3975#discussion_r1006427615
##########
datafusion/physical-expr/src/expressions/in_list.rs:
##########
@@ -714,193 +912,86 @@ impl PhysicalExpr for InListExpr {
};
match value_data_type {
- DataType::Float32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float32,
- Float32Array
- )
- }
- DataType::Float64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float64,
- Float64Array
- )
- }
- DataType::Int16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int16,
- Int16Array
- )
- }
- DataType::Int32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int32,
- Int32Array
- )
- }
- DataType::Int64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int64,
- Int64Array
- )
- }
- DataType::Int8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int8,
- Int8Array
- )
- }
- DataType::UInt16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt16,
- UInt16Array
- )
- }
- DataType::UInt32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt32,
- UInt32Array
- )
- }
- DataType::UInt64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt64,
- UInt64Array
- )
- }
- DataType::UInt8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt8,
- UInt8Array
- )
- }
- DataType::Date32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Date32,
- Date32Array
- )
- }
- DataType::Date64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Date64,
- Date64Array
- )
- }
- DataType::Boolean => Ok(make_contains!(
- array,
- list_values,
- self.negated,
- Boolean,
- BooleanArray
- )),
- DataType::Utf8 => {
- self.compare_utf8::<i32>(array, list_values, self.negated)
- }
- DataType::LargeUtf8 => {
- self.compare_utf8::<i64>(array, list_values, self.negated)
- }
- DataType::Binary => {
- self.compare_binary::<i32>(array, list_values,
self.negated)
- }
- DataType::LargeBinary => {
- self.compare_binary::<i64>(array, list_values,
self.negated)
- }
- DataType::Null => {
- let null_array = new_null_array(&DataType::Boolean,
array.len());
- Ok(ColumnarValue::Array(Arc::new(null_array)))
- }
- DataType::Decimal128(_, _) => {
- let decimal_array =
-
array.as_any().downcast_ref::<Decimal128Array>().unwrap();
- Ok(make_list_contains_decimal(
- decimal_array,
- list_values,
- self.negated,
- ))
+ DataType::Dictionary(_key_type, value_type) => {
+ // 1
+ let dict_array = array
+ .as_any()
+ .downcast_ref::<DictionaryArray<Int32Type>>()
+ .unwrap();
+ let keys = dict_array.keys();
+
+ let values_result = evaluate_set(&array,
list_values).unwrap();
+ let values = compute::take(values_result.as_ref(), keys,
None)?;
+ Ok(ColumnarValue::Array(values))
+
+ // 2
+ // let dict_array =
array.as_any().downcast_ref::<DictionaryArray<Int32Type>>().unwrap();
+ // let keys = dict_array.keys();
+
+ // let values_result =
dict_to_val_array::<Int32Type>(&array).unwrap();
+ // self.evaluate_non_dict(*value_type, values_result,
list_values)
}
- DataType::Timestamp(unit, _) => match unit {
- TimeUnit::Second => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampSecond,
- TimestampSecondArray
- )
- }
- TimeUnit::Millisecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampMillisecond,
- TimestampMillisecondArray
- )
- }
- TimeUnit::Microsecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampMicrosecond,
- TimestampMicrosecondArray
- )
- }
- TimeUnit::Nanosecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampNanosecond,
- TimestampNanosecondArray
- )
- }
- },
- datatype =>
Result::Err(DataFusionError::NotImplemented(format!(
- "InList does not support datatype {:?}.",
- datatype
- ))),
+ _ => self.evaluate_non_dict(value_data_type, array,
list_values),
}
}
}
}
+fn dict_to_val_array<K: ArrowDictionaryKeyType>(
+ array: &Arc<dyn Array>,
+) -> Result<Arc<dyn Array>> {
+ let dict_array =
array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+ let mut dict_vals = Vec::with_capacity(dict_array.len());
+ for i in 0..array.len() {
+ let (values_array, values_index) = get_dict_value::<K>(&array, i);
+ // Look up value from Index
+ let value = match values_index {
+ Some(values_index) => ScalarValue::try_from_array(values_array,
values_index),
+ // Entry was null, so return null
+ None => values_array.data_type().try_into(),
+ }?;
+ dict_vals.push(value);
+ }
+ let vals = ScalarValue::iter_to_array(dict_vals).unwrap();
+ Ok(vals)
+}
+
+// Return a boolean array indicating whether the value is in list_values
+fn evaluate_set(
+ array: &Arc<dyn Array>,
+ list_values: Vec<ColumnarValue>,
+) -> Result<Arc<dyn Array>> {
+ // convert value_list to an array
+ let scalars = list_values
+ .iter()
+ .map(|v| match v {
+ ColumnarValue::Scalar(scalar) => scalar.clone(),
+ ColumnarValue::Array(_array) => {
+ unimplemented!("InList does not yet support nested columns.")
+ }
+ })
+ .collect::<Vec<_>>();
+ let list_array = ScalarValue::iter_to_array(scalars).unwrap();
+
+ let cmp = build_compare(&array, &list_array).unwrap();
Review Comment:
`build_compare` is known not to handle nulls correctly; see
https://github.com/apache/arrow-rs/issues/2687
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]