This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d698d9deb6 fix: Support Dict types in `in_list` physical plans (#10031)
d698d9deb6 is described below
commit d698d9deb6d06214ea36cc59739993ddf6441b6a
Author: advancedxy <[email protected]>
AuthorDate: Sun Apr 14 04:05:38 2024 +0800
fix: Support Dict types in `in_list` physical plans (#10031)
* fix: Relax type check with dict types in in_list
* refine comments
* fix style, refine comments and address reviewer's comments
* refine comments
* address comments
---
.../physical-expr/src/expressions/in_list.rs | 126 ++++++++++++++++++++-
datafusion/sqllogictest/test_files/dictionary.slt | 39 +++++++
2 files changed, 161 insertions(+), 4 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs
b/datafusion/physical-expr/src/expressions/in_list.rs
index ecdb03e97e..07185b4d65 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -415,6 +415,18 @@ impl PartialEq<dyn Any> for InListExpr {
}
}
+/// Checks whether two data types are logically equal.
+///
+/// Dictionary types are compared by their value types only, so key types are
+/// ignored: `Dictionary(Int8, Utf8)`, `Dictionary(UInt16, Utf8)` and `Utf8`
+/// all compare equal. Used by [`in_list`] so that a dictionary-encoded
+/// expression can be matched against plain or differently-keyed dictionary
+/// list items. Value types themselves are compared with plain `==`.
+fn is_logically_eq(lhs: &DataType, rhs: &DataType) -> bool {
+ match (lhs, rhs) {
+ (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => {
+ v1.as_ref().eq(v2.as_ref())
+ }
+ // Dictionary on one side only: compare its value type with the other side.
+ (DataType::Dictionary(_, l), _) => l.as_ref().eq(rhs),
+ (_, DataType::Dictionary(_, r)) => lhs.eq(r.as_ref()),
+ _ => lhs.eq(rhs),
+ }
+}
+
/// Creates a unary expression InList
pub fn in_list(
expr: Arc<dyn PhysicalExpr>,
@@ -426,7 +438,7 @@ pub fn in_list(
let expr_data_type = expr.data_type(schema)?;
for list_expr in list.iter() {
let list_expr_data_type = list_expr.data_type(schema)?;
- if !expr_data_type.eq(&list_expr_data_type) {
+ if !is_logically_eq(&expr_data_type, &list_expr_data_type) {
return internal_err!(
"The data type inlist should be same, the value type is
{expr_data_type}, one of list expr type is {list_expr_data_type}"
);
@@ -499,7 +511,21 @@ mod tests {
macro_rules! in_list {
($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr,
$SCHEMA:expr) => {{
let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST,
$SCHEMA)?;
- let expr = in_list(cast_expr, cast_list_exprs, $NEGATED,
$SCHEMA).unwrap();
+ // Unlike `in_list_raw!`, the column and list are first passed through
+ // `in_list_cast`; evaluation and the result check are then delegated to
+ // `in_list_raw!` using the resulting expressions.
+ in_list_raw!(
+ $BATCH,
+ cast_list_exprs,
+ $NEGATED,
+ $EXPECTED,
+ cast_expr,
+ $SCHEMA
+ );
+ }};
+ }
+
+ // applies the in_list expr to an input batch and list without cast
+ macro_rules! in_list_raw {
+ ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr,
$SCHEMA:expr) => {{
+ let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
let result = expr
.evaluate(&$BATCH)?
.into_array($BATCH.num_rows())
@@ -540,7 +566,7 @@ mod tests {
&schema
);
- // expression: "a not in ("a", "b")"
+ // expression: "a in ("a", "b", null)"
let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
in_list!(
batch,
@@ -551,7 +577,7 @@ mod tests {
&schema
);
- // expression: "a not in ("a", "b")"
+ // expression: "a not in ("a", "b", null)"
let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
in_list!(
batch,
@@ -1314,4 +1340,96 @@ mod tests {
Ok(())
}
+
+ // Verifies that `in_list` accepts a `Dictionary(UInt16, Utf8)` column with
+ // list items that are plain `Utf8` literals or dictionary literals whose key
+ // types differ (`Int8`, `UInt16`). Both list shapes must give the same result.
+ // `in_list_raw!` is used so no cast is applied to the list items.
+ #[test]
+ fn in_list_utf8_with_dict_types() -> Result<()> {
+ // Builds a dictionary-typed Utf8 literal with the given key type.
+ fn dict_lit(key_type: DataType, value: &str) -> Arc<dyn PhysicalExpr> {
+ lit(ScalarValue::Dictionary(
+ Box::new(key_type),
+ Box::new(ScalarValue::new_utf8(value.to_string())),
+ ))
+ }
+
+ // Builds a null dictionary-typed Utf8 literal with the given key type.
+ fn null_dict_lit(key_type: DataType) -> Arc<dyn PhysicalExpr> {
+ lit(ScalarValue::Dictionary(
+ Box::new(key_type),
+ Box::new(ScalarValue::Utf8(None)),
+ ))
+ }
+
+ let schema = Schema::new(vec![Field::new(
+ "a",
+ DataType::Dictionary(Box::new(DataType::UInt16),
Box::new(DataType::Utf8)),
+ true,
+ )]);
+ // Column `a` rows: "a", "d", NULL.
+ let a: UInt16DictionaryArray =
+ vec![Some("a"), Some("d"), None].into_iter().collect();
+ let col_a = col("a", &schema)?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
+
+ // expression: "a in ("a", "b")"
+ // First list: plain Utf8 literals; second list: dictionary literals.
+ let lists = [
+ vec![lit("a"), lit("b")],
+ vec![
+ dict_lit(DataType::Int8, "a"),
+ dict_lit(DataType::UInt16, "b"),
+ ],
+ ];
+ for list in lists.iter() {
+ in_list_raw!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), Some(false), None],
+ col_a.clone(),
+ &schema
+ );
+ }
+
+ // expression: "a not in ("a", "b")"
+ for list in lists.iter() {
+ in_list_raw!(
+ batch,
+ list.clone(),
+ &true,
+ vec![Some(false), Some(true), None],
+ col_a.clone(),
+ &schema
+ );
+ }
+
+ // expression: "a in ("a", "b", null)"
+ // With a NULL in the list, a non-matching row ("d") evaluates to NULL
+ // instead of false, for both IN and NOT IN.
+ let lists = [
+ vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
+ vec![
+ dict_lit(DataType::Int8, "a"),
+ dict_lit(DataType::UInt16, "b"),
+ null_dict_lit(DataType::UInt16),
+ ],
+ ];
+ for list in lists.iter() {
+ in_list_raw!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+ }
+
+ // expression: "a not in ("a", "b", null)"
+ for list in lists.iter() {
+ in_list_raw!(
+ batch,
+ list.clone(),
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+ }
+
+ Ok(())
+ }
}
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt
b/datafusion/sqllogictest/test_files/dictionary.slt
index af7bf5cb16..891a09fbc1 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -87,6 +87,22 @@ f3 Utf8 YES
f4 Float64 YES
time Timestamp(Nanosecond, None) YES
+# in list with dictionary input
+query BBB
+SELECT
+ tag_id in ('1000'), '1000' in (tag_id, null),
arrow_cast('999','Dictionary(Int32, Utf8)') in (tag_id, null)
+FROM m1
+----
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
# Table m2 with a tag columns `tag_id` and `type`, a field column `f5`, and
`time`
statement ok
@@ -165,6 +181,29 @@ order by date_bin('30 minutes', time) DESC
3 400 600 500 2023-12-04T00:30:00
3 100 300 200 2023-12-04T00:00:00
+# in list over the m2 tag columns, with plain and dictionary-cast operands
+query BBBBBBBB
+SELECT
+ type in ('active', 'passive')
+ , 'active' in (type)
+ , 'active' in (type, null)
+ , arrow_cast('passive','Dictionary(Int8, Utf8)') in (type, null)
+ , tag_id in ('1000', '2000')
+ , tag_id in ('999')
+ , '1000' in (tag_id, null)
+ , arrow_cast('999','Dictionary(Int16, Utf8)') in (tag_id, null)
+FROM m2
+----
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true false NULL true true false true NULL
+true false NULL true true false true NULL
+true false NULL true true false true NULL
+true false NULL true true false true NULL
# Reproducer for https://github.com/apache/arrow-datafusion/issues/8738