alamb commented on a change in pull request #1135:
URL: https://github.com/apache/arrow-datafusion/pull/1135#discussion_r742801716



##########
File path: datafusion/src/dataframe.rs
##########
@@ -211,6 +211,7 @@ pub trait DataFrame: Send + Sync {
         join_type: JoinType,
         left_cols: &[&str],
         right_cols: &[&str],
+        null_equal_safe: bool,

Review comment:
       I think we should document what `null_equal_safe` means in this 
public API.

##########
File path: datafusion/src/logical_plan/builder.rs
##########
@@ -488,6 +488,7 @@ impl LogicalPlanBuilder {
         right: &LogicalPlan,

Review comment:
       Likewise, I think we should add a doc comment about what 
`null_equal_safe` means in this context. 
   
   Given how specialized the use case of `null_equal_safe` is, what do you 
think about adding a specialized API for it? This would hide some of this 
complexity from most users of DataFusion.
   
   So something like
   
   ```rust
       /// Apply a join with on constraint and specified null equality. 
       /// If null_equal_safe is true then ...
       pub fn join_detailed(
           &self,
           right: &LogicalPlan,
           join_type: JoinType,
           join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>,
           null_equal_safe: bool) {
     ...
       } 
   
       /// Apply a join with on constraint
       pub fn join(
           &self,
           right: &LogicalPlan,
           join_type: JoinType,
           join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>) {
     self.join_detailed(right, join_type, join_keys, false)
       }
   ```
   

##########
File path: datafusion/src/logical_plan/plan.rs
##########
@@ -135,6 +135,8 @@ pub enum LogicalPlan {
         join_constraint: JoinConstraint,
         /// The output schema, containing fields from the left and right inputs
         schema: DFSchemaRef,
+        /// If null_equal_safe is true, null == null.
+        null_equal_safe: bool,

Review comment:
       I suggest renaming this field (and all other instances of 
`null_equal_safe`) to `null_equal_null` to make its meaning clearer.

##########
File path: datafusion/tests/sql.rs
##########
@@ -5385,3 +5385,70 @@ async fn query_nested_get_indexed_field() -> Result<()> {
     assert_eq!(expected, actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn intersect_with_null_not_equal() {
+    let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+            INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2";
+
+    let expected: &[&[&str]] = &[];
+
+    let mut ctx = create_join_context_qualified().unwrap();
+    let actual = execute(&mut ctx, sql).await;
+
+    // left and right shouldn't match anything
+    assert_eq!(expected, actual);
+}
+
+#[tokio::test]
+async fn intersect_with_null_equal() {
+    let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+            INTERSECT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2";
+
+    let expected: Vec<Vec<String>> = vec![vec!["NULL".to_string(), 
"1".to_string()]];
+
+    let mut ctx = create_join_context_qualified().unwrap();
+    let actual = execute(&mut ctx, sql).await;
+
+    // left and right shouldn't match anything

Review comment:
       This comment appears to be incorrect: in this test the left and right 
rows are expected to match, producing one output row.

##########
File path: datafusion/src/physical_plan/hash_join.rs
##########
@@ -751,12 +788,19 @@ fn build_join_indexes(
 }
 
 macro_rules! equal_rows_elem {
-    ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => 
{{
+    ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, 
$null_equal_safe: ident) => {{
         let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap();
         let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap();
 
         match (left_array.is_null($left), right_array.is_null($right)) {
             (false, false) => left_array.value($left) == 
right_array.value($right),
+            (true, true) => {

Review comment:
       @Dandandan, what do you think about the potential performance impact of 
doing this check for each input element?
   
   I think it is fine, but maybe there are some benchmarks we could run to find 
out.

##########
File path: datafusion/src/sql/planner.rs
##########
@@ -3542,11 +3554,11 @@ mod tests {
     }
 
     #[test]
-    fn only_union_all_supported() {
+    fn except_not_supported() {
         let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM 
orders";
         let err = logical_plan(sql).expect_err("query should have failed");
         assert_eq!(
-            "NotImplemented(\"Only UNION ALL and UNION [DISTINCT] are 
supported, found EXCEPT\")",
+            "NotImplemented(\"Only UNION ALL and UNION [DISTINCT] and 
INTERSECT and INTERSECT [DISTINCT] are supported, found EXCEPT\")",

Review comment:
       ❤️ 

##########
File path: datafusion/src/sql/planner.rs
##########
@@ -191,23 +191,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 left,
                 right,
                 all,
-            } => match (op, all) {
+            } => {
+                let left_plan = self.set_expr_to_plan(left.as_ref(), None, 
ctes)?;
+                let right_plan = self.set_expr_to_plan(right.as_ref(), None, 
ctes)?;
+                match (op, all) {
                 (SetOperator::Union, true) => {
-                    let left_plan = self.set_expr_to_plan(left.as_ref(), None, 
ctes)?;
-                    let right_plan = self.set_expr_to_plan(right.as_ref(), 
None, ctes)?;
                     union_with_alias(left_plan, right_plan, alias)
                 }
                 (SetOperator::Union, false) => {
-                    let left_plan = self.set_expr_to_plan(left.as_ref(), None, 
ctes)?;
-                    let right_plan = self.set_expr_to_plan(right.as_ref(), 
None, ctes)?;
                     let union_plan = union_with_alias(left_plan, right_plan, 
alias)?;
                     LogicalPlanBuilder::from(union_plan).distinct()?.build()
                 }
+                (SetOperator::Intersect, true) => {
+                    let join_keys = 
left_plan.schema().fields().iter().zip(right_plan.schema().fields().iter()).map(|(left_field,
 right_field)| ((Column::from_name(left_field.name())), 
(Column::from_name(right_field.name())))).unzip();
+                    LogicalPlanBuilder::from(left_plan).join(&right_plan, 
JoinType::Semi, join_keys, true)?.build()
+                }
+                (SetOperator::Intersect, false) => {
+                    let distinct_left_plan = 
LogicalPlanBuilder::from(left_plan).distinct()?.build()?;
+                    let join_keys = 
distinct_left_plan.schema().fields().iter().zip(right_plan.schema().fields().iter()).map(|(left_field,
 right_field)| ((Column::from_name(left_field.name())), 
(Column::from_name(right_field.name())))).unzip();
+                    
LogicalPlanBuilder::from(distinct_left_plan).join(&right_plan, JoinType::Semi, 
join_keys, true)?.build()

Review comment:
       This is just a cleanliness suggestion, definitely not needed: 
   ```suggestion
                       let join_keys = distinct_left_plan
                           .schema()
                           .fields()
                           .iter()
                           .zip(right_plan.schema().fields().iter())
                           .map(|(left_field, right_field)| {
                               ((Column::from_name(left_field.name())), 
                                (Column::from_name(right_field.name())))
                            })
                            .unzip();
                       LogicalPlanBuilder::from(left_plan)
                           .distinct()?
                           .build()?
                           .join(&right_plan, JoinType::Semi, join_keys, true)?
                           .build()
   ```

##########
File path: datafusion/src/physical_plan/hash_join.rs
##########
@@ -499,6 +510,8 @@ struct HashJoinStream {
     join_metrics: HashJoinMetrics,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
+    ///

Review comment:
       This doc comment appears to be empty; it should describe the new field.

##########
File path: datafusion/tests/sql.rs
##########
@@ -5385,3 +5385,70 @@ async fn query_nested_get_indexed_field() -> Result<()> {
     assert_eq!(expected, actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn intersect_with_null_not_equal() {
+    let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+            INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2";
+
+    let expected: &[&[&str]] = &[];
+
+    let mut ctx = create_join_context_qualified().unwrap();
+    let actual = execute(&mut ctx, sql).await;

Review comment:
       Is there a reason you used `execute` and `assert_eq!` in the first two 
tests but `execute_to_batches` and `assert_batches_eq!` in the subsequent 
tests? In other words, why are the tests inconsistent?

##########
File path: datafusion/tests/sql.rs
##########
@@ -5385,3 +5385,70 @@ async fn query_nested_get_indexed_field() -> Result<()> {
     assert_eq!(expected, actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn intersect_with_null_not_equal() {
+    let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+            INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2";
+
+    let expected: &[&[&str]] = &[];
+
+    let mut ctx = create_join_context_qualified().unwrap();
+    let actual = execute(&mut ctx, sql).await;
+
+    // left and right shouldn't match anything
+    assert_eq!(expected, actual);
+}
+
+#[tokio::test]
+async fn intersect_with_null_equal() {
+    let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+            INTERSECT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2";
+
+    let expected: Vec<Vec<String>> = vec![vec!["NULL".to_string(), 
"1".to_string()]];
+
+    let mut ctx = create_join_context_qualified().unwrap();
+    let actual = execute(&mut ctx, sql).await;
+
+    // left and right shouldn't match anything
+    assert_eq!(expected, actual);
+}
+
+#[tokio::test]
+async fn test_intersect_all() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_alltypes_parquet(&mut ctx).await;
+    // execute the query
+    let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 
0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+---------+------------+",
+        "| int_col | double_col |",
+        "+---------+------------+",
+        "| 1       | 10.1       |",
+        "| 1       | 10.1       |",
+        "| 1       | 10.1       |",
+        "| 1       | 10.1       |",
+        "+---------+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_intersect_distinct() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_alltypes_parquet(&mut ctx).await;
+    // execute the query

Review comment:
       👍 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to