alamb commented on a change in pull request #9038: URL: https://github.com/apache/arrow/pull/9038#discussion_r552193411
########## File path: rust/datafusion/tests/sql.rs ########## @@ -1849,3 +1849,89 @@ async fn string_expressions() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn in_list_array() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv_by_sql(&mut ctx).await; + let sql = "SELECT + c1 IN ('a', 'c') AS utf8_in_true + ,c1 IN ('x', 'y') AS utf8_in_false + ,c1 NOT IN ('x', 'y') AS utf8_not_in_true + ,c1 NOT IN ('a', 'c') AS utf8_not_in_false + ,CAST(CAST(c1 AS int) AS varchar) IN ('a', 'c') AS utf8_in_null + FROM aggregate_test_100 WHERE c12 < 0.05"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["true", "false", "true", "false", "NULL"], + vec!["true", "false", "true", "false", "NULL"], + vec!["true", "false", "true", "false", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + vec!["false", "false", "true", "true", "NULL"], + ]; + assert_eq!(expected, actual); + Ok(()) +} + +#[tokio::test] +async fn in_list_scalar() -> Result<()> { Review comment: ❤️ ########## File path: rust/datafusion/src/optimizer/utils.rs ########## @@ -416,6 +424,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec<Expr>) -> Result<Expr> Ok(expr) } } + Expr::InList { .. } => Ok(expr.clone()), Review comment: likewise here, I think we might want to include the `list` -- even though at the moment it only contains constants, it is a `Vec<Expr>` ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -3769,4 +4002,166 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; Ok(batch) } + + // applies the in_list expr to an input batch and list + macro_rules! 
in_list { + ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{ + let expr = in_list(col("a"), $LIST, $NEGATED).unwrap(); + let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = result + .as_any() + .downcast_ref::<BooleanArray>() + .expect("failed to downcast to BooleanArray"); + let expected = &BooleanArray::from($EXPECTED); + assert_eq!(expected, result); + }}; + } + + #[test] + fn in_list_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let a = StringArray::from(vec![Some("a"), Some("d"), None]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + ]; + in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + + // expression: "a not in ("a", "b")" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + ]; + in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + + // expression: "a not in ("a", "b")" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &false, vec![Some(true), None, None]); + + // expression: "a not in ("a", "b")" + let list = vec![ + lit(ScalarValue::Utf8(Some("a".to_string()))), + lit(ScalarValue::Utf8(Some("b".to_string()))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &true, vec![Some(false), None, None]); + + Ok(()) + } + + #[test] + fn in_list_int64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let a = Int64Array::from(vec![Some(0), Some(2), None]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![ + 
lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(1))), + ]; + in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + + // expression: "a not in (0, 1)" + let list = vec![ + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(1))), + ]; + in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + + // expression: "a in (0, 1, NULL)" + let list = vec![ + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &false, vec![Some(true), None, None]); + + // expression: "a not in (0, 1, NULL)" + let list = vec![ + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &true, vec![Some(false), None, None]); + + Ok(()) + } + + #[test] + fn in_list_float64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a in (0.0, 0.2)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + ]; + in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + + // expression: "a not in (0.0, 0.2)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + ]; + in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + + // expression: "a in (0.0, 0.2, NULL)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + lit(ScalarValue::Utf8(None)), + ]; + in_list!(batch, list, &false, vec![Some(true), None, None]); + + // expression: "a not in (0.0, 0.2, NULL)" + let list = vec![ + lit(ScalarValue::Float64(Some(0.0))), + lit(ScalarValue::Float64(Some(0.1))), + lit(ScalarValue::Utf8(None)), Review comment: I don't think it hurts, but given the coercion 
logic you added in the planner, I think the literals at this point should all be the same type as the expr value. In other words, can you really see `a NOT IN (NULL::Utf8)`? ########## File path: rust/datafusion/src/optimizer/utils.rs ########## @@ -305,6 +312,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> { low.as_ref().to_owned(), high.as_ref().to_owned(), ]), + Expr::InList { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), Review comment: Shouldn't this also include the exprs in `list`? ########## File path: rust/datafusion/src/physical_plan/planner.rs ########## @@ -882,6 +937,53 @@ mod tests { Ok(()) } + #[test] + fn in_list_types() -> Result<()> { + let testdata = arrow::util::test_util::arrow_test_data(); + let path = format!("{}/csv/aggregate_test_100.csv", testdata); + let options = CsvReadOptions::new().schema_infer_max_records(100); + + // expression: "a in ('a', 1)" + let list = vec![ + Expr::Literal(ScalarValue::Utf8(Some("a".to_string()))), + Expr::Literal(ScalarValue::Int64(Some(1))), + ]; + let logical_plan = LogicalPlanBuilder::scan_csv(&path, options, None)? + // filter clause needs the type coercion rule applied + .filter(col("c12").lt(lit(0.05)))? + .project(vec![col("c1").in_list(list, false)])? + .build()?; + let execution_plan = plan(&logical_plan)?; + // verify that the plan correctly adds cast from Int64(1) to Utf8 Review comment: 👍 ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org