alamb commented on a change in pull request #9038:
URL: https://github.com/apache/arrow/pull/9038#discussion_r552193411



##########
File path: rust/datafusion/tests/sql.rs
##########
@@ -1849,3 +1849,89 @@ async fn string_expressions() -> Result<()> {
     assert_eq!(expected, actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn in_list_array() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv_by_sql(&mut ctx).await;
+    let sql = "SELECT
+            c1 IN ('a', 'c') AS utf8_in_true
+            ,c1 IN ('x', 'y') AS utf8_in_false
+            ,c1 NOT IN ('x', 'y') AS utf8_not_in_true
+            ,c1 NOT IN ('a', 'c') AS utf8_not_in_false
+            ,CAST(CAST(c1 AS int) AS varchar) IN ('a', 'c') AS utf8_in_null
+        FROM aggregate_test_100 WHERE c12 < 0.05";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![
+        vec!["true", "false", "true", "false", "NULL"],
+        vec!["true", "false", "true", "false", "NULL"],
+        vec!["true", "false", "true", "false", "NULL"],
+        vec!["false", "false", "true", "true", "NULL"],
+        vec!["false", "false", "true", "true", "NULL"],
+        vec!["false", "false", "true", "true", "NULL"],
+        vec!["false", "false", "true", "true", "NULL"],
+    ];
+    assert_eq!(expected, actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn in_list_scalar() -> Result<()> {

Review comment:
       ❤️ 

##########
File path: rust/datafusion/src/optimizer/utils.rs
##########
@@ -416,6 +424,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec<Expr>) -> Result<Expr>
                 Ok(expr)
             }
         }
+        Expr::InList { .. } => Ok(expr.clone()),

Review comment:
       likewise here, I think we might want to include the `list` -- even 
though at the moment it only contains constants, it is a `Vec<Expr>`

##########
File path: rust/datafusion/src/physical_plan/expressions.rs
##########
@@ -3769,4 +4002,166 @@ mod tests {
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
         Ok(batch)
     }
+
+    // applies the in_list expr to an input batch and list
+    macro_rules! in_list {
+        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{
+            let expr = in_list(col("a"), $LIST, $NEGATED).unwrap();
+            let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
+            let result = result
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .expect("failed to downcast to BooleanArray");
+            let expected = &BooleanArray::from($EXPECTED);
+            assert_eq!(expected, result);
+        }};
+    }
+
+    #[test]
+    fn in_list_utf8() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+        let a = StringArray::from(vec![Some("a"), Some("d"), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
+        // expression: "a in ("a", "b")"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+
+        // expression: "a not in ("a", "b")"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+
+        // expression: "a in ("a", "b", NULL)"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), None, None]);
+
+        // expression: "a not in ("a", "b", NULL)"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), None, None]);
+
+        Ok(())
+    }
+
+    #[test]
+    fn in_list_int64() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
+        let a = Int64Array::from(vec![Some(0), Some(2), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
+        // expression: "a in (0, 1)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+
+        // expression: "a not in (0, 1)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+
+        // expression: "a in (0, 1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), None, None]);
+
+        // expression: "a not in (0, 1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), None, None]);
+
+        Ok(())
+    }
+
+    #[test]
+    fn in_list_float64() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
+        let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
+        // expression: "a in (0.0, 0.1)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+
+        // expression: "a not in (0.0, 0.1)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+
+        // expression: "a in (0.0, 0.1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), None, None]);
+
+        // expression: "a not in (0.0, 0.1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+            lit(ScalarValue::Utf8(None)),

Review comment:
       I don't think it hurts, but given the coercion logic you added in the 
planner, I think the literals at this point should all be the same type as the 
expr value. In other words, can you really see `a NOT IN (NULL::Utf8)`?

##########
File path: rust/datafusion/src/optimizer/utils.rs
##########
@@ -305,6 +312,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
             low.as_ref().to_owned(),
             high.as_ref().to_owned(),
         ]),
+        Expr::InList { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),

Review comment:
       Shouldn't this also include the exprs in `list`?

##########
File path: rust/datafusion/src/physical_plan/planner.rs
##########
@@ -882,6 +937,53 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn in_list_types() -> Result<()> {
+        let testdata = arrow::util::test_util::arrow_test_data();
+        let path = format!("{}/csv/aggregate_test_100.csv", testdata);
+        let options = CsvReadOptions::new().schema_infer_max_records(100);
+
+        // expression: "a in ('a', 1)"
+        let list = vec![
+            Expr::Literal(ScalarValue::Utf8(Some("a".to_string()))),
+            Expr::Literal(ScalarValue::Int64(Some(1))),
+        ];
+        let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)?
+            // filter clause needs the type coercion rule applied
+            .filter(col("c12").lt(lit(0.05)))?
+            .project(vec![col("c1").in_list(list, false)])?
+            .build()?;
+        let execution_plan = plan(&logical_plan)?;
+        // verify that the plan correctly adds cast from Int64(1) to Utf8

Review comment:
       👍 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to