Repository: spark Updated Branches: refs/heads/master 6690924c4 -> 65a4bc143
[SPARK-21274][SQL] Implement INTERSECT ALL clause ## What changes were proposed in this pull request? Implements INTERSECT ALL clause through query rewrites using existing operators in Spark. Please refer to [Link](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE) for the design. Input Query ``` SQL SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2 ``` Rewritten Query ```SQL SELECT c1 FROM ( SELECT replicate_row(min_count, c1) FROM ( SELECT c1, IF (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count FROM ( SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt FROM ( SELECT c1, true as vcol1, null as vcol2 FROM ut1 UNION ALL SELECT c1, null as vcol1, true as vcol2 FROM ut2 ) AS union_all GROUP BY c1 HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1 ) ) ) ``` ## How was this patch tested? Added test cases in SQLQueryTestSuite, DataFrameSuite, SetOperationSuite Author: Dilip Biswal <dbis...@us.ibm.com> Closes #21886 from dilipbiswal/dkb_intersect_all_final. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/65a4bc14 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/65a4bc14 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/65a4bc14 Branch: refs/heads/master Commit: 65a4bc143ab5dc2ced589dc107bbafa8a7290931 Parents: 6690924 Author: Dilip Biswal <dbis...@us.ibm.com> Authored: Sun Jul 29 22:11:01 2018 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Sun Jul 29 22:11:01 2018 -0700 ---------------------------------------------------------------------- python/pyspark/sql/dataframe.py | 22 ++ .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../analysis/UnsupportedOperationChecker.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 81 ++++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 7 +- .../catalyst/optimizer/SetOperationSuite.scala | 32 ++- .../sql/catalyst/parser/PlanParserSuite.scala | 1 - .../scala/org/apache/spark/sql/Dataset.scala | 19 +- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../sql-tests/inputs/intersect-all.sql | 123 ++++++++++ .../sql-tests/results/intersect-all.sql.out | 241 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 54 +++++ .../org/apache/spark/sql/test/SQLTestData.scala | 13 + 15 files changed, 599 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b2e0a5b..07fb260 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1500,6 +1500,28 @@ class DataFrame(object): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + 
@since(2.4) + def intersectAll(self, other): + """ Return a new :class:`DataFrame` containing rows in both this dataframe and other + dataframe while preserving duplicates. + + This is equivalent to `INTERSECT ALL` in SQL. + >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.intersectAll(df2).sort("C1", "C2").show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | b| 3| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx) + @since(1.3) def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8abb1c7..9965cd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -914,7 +914,7 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) - case i @ Intersect(left, right) if !i.duplicateResolved => + case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => e.copy(right = dedupRight(left, right)) http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f9edca5..7dd26b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -325,11 +325,11 @@ object TypeCoercion { assert(newChildren.length == 2) Except(newChildren.head, newChildren.last, isAll) - case s @ Intersect(left, right) if s.childrenResolved && + case s @ Intersect(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) - Intersect(newChildren.head, newChildren.last) + Intersect(newChildren.head, newChildren.last, isAll) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c9a3ee4..cff4cee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -309,7 +309,7 @@ object UnsupportedOperationChecker { case Except(left, right, _) if right.isStreaming => throwError("Except on a streaming DataFrame/Dataset on the right 
is not supported") - case Intersect(left, right) if left.isStreaming && right.isStreaming => + case Intersect(left, right, _) if left.isStreaming && right.isStreaming => throwError("Intersect between two streaming DataFrames/Datasets is not supported") case GroupingSets(_, _, child, _) if child.isStreaming => http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 193f659..105623c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -136,6 +136,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, RewriteExcepAll, + RewriteIntersectAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, @@ -1402,7 +1403,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { */ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Intersect(left, right) => + case Intersect(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) @@ -1489,6 +1490,84 @@ object RewriteExcepAll extends Rule[LogicalPlan] { } /** + * Replaces logical [[Intersect]] operator using a combination of Union, Aggregate + * and Generate operator. 
+ * + * Input Query : + * {{{ + * SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2 + * }}} + * + * Rewritten Query: + * {{{ + * SELECT c1 + * FROM ( + * SELECT replicate_row(min_count, c1) + * FROM ( + * SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count + * FROM ( + * SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt + * FROM ( + * SELECT true as vcol1, null as vcol2, c1 FROM ut1 + * UNION ALL + * SELECT null as vcol1, true as vcol2, c1 FROM ut2 + * ) AS union_all + * GROUP BY c1 + * HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1 + * ) + * ) + * ) + * }}} + */ +object RewriteIntersectAll extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right, true) => + assert(left.output.size == right.output.size) + + val trueVcol1 = Alias(Literal(true), "vcol1")() + val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")() + + val trueVcol2 = Alias(Literal(true), "vcol2")() + val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")() + + // Add a projection on the top of left and right plans to project out + // the additional virtual columns. + val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left) + val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right) + + val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols) + + // Expressions to compute count and minimum of both the counts. 
+ val vCol1AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")() + val vCol2AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")() + val ifExpression = Alias(If( + GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute), + vCol2AggrExpr.toAttribute, + vCol1AggrExpr.toAttribute + ), "min_count")() + + val aggregatePlan = Aggregate(left.output, + Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan) + val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)), + GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan) + val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan) + + // Apply the replicator to replicate rows based on min_count + val genRowPlan = Generate( + ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output), + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, + left.output, + projectMinPlan + ) + Project(left.output, genRowPlan) + } +} + +/** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8b3c068..8a8db6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -533,7 +533,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.UNION => Distinct(Union(left, right)) case SqlBaseParser.INTERSECT if all => - throw new ParseException("INTERSECT ALL is not supported.", ctx) + Intersect(left, right, isAll = true) case SqlBaseParser.INTERSECT => Intersect(left, right) case SqlBaseParser.EXCEPT if all => http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 498a13a..13b5130 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -164,7 +164,12 @@ object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { +case class Intersect( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean = 
false) extends SetOperation(left, right) { + + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index f002aa3..cb744be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.BooleanType class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -166,4 +167,33 @@ class SetOperationSuite extends PlanTest { )) comparePlans(expectedPlan, rewrittenPlan) } + + test("INTERSECT ALL rewrite") { + val input = Intersect(testRelation, testRelation2, isAll = true) + val rewrittenPlan = RewriteIntersectAll(input) + val leftRelation = testRelation + .select(Literal(true).as("vcol1"), Literal(null, 
BooleanType).as("vcol2"), 'a, 'b, 'c) + val rightRelation = testRelation2 + .select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f) + val planFragment = leftRelation.union(rightRelation) + .groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"), + count('vcol2).as("vcol2_count"), 'a, 'b, 'c) + .where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)), + GreaterThanOrEqual('vcol2_count, Literal(1L)))) + .select('a, 'b, 'c, + If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count")) + .analyze + val multiplerAttr = planFragment.output.last + val output = planFragment.output.dropRight(1) + val expectedPlan = Project(output, + Generate( + ReplicateRows(Seq(multiplerAttr) ++ output), + Nil, + false, + None, + output, + planFragment + )) + comparePlans(expectedPlan, rewrittenPlan) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 629e3c4..9be0ec5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -70,7 +70,6 @@ class PlanParserSuite extends AnalysisTest { intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") assertEqual("select * from a minus distinct select * from b", a.except(b)) assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) } 
http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e6a3b0a..d36c8d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1935,6 +1935,23 @@ class Dataset[T] private[sql]( } /** + * Returns a new Dataset containing rows only in both this Dataset and another Dataset while + * preserving the duplicates. + * This is equivalent to `INTERSECT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. Also as standard + * in SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Intersect(logicalPlan, other.logicalPlan, isAll = true) + } + + + /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. * This is equivalent to `EXCEPT DISTINCT` in SQL. 
* @@ -1961,7 +1978,7 @@ class Dataset[T] private[sql]( * @since 2.4.0 */ def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier, isAll = true) + Except(logicalPlan, other.logicalPlan, isAll = true) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3f5fd3d..75eff8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -529,9 +529,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") - case logical.Intersect(left, right) => + case logical.Intersect(left, right, false) => throw new IllegalStateException( - "logical intersect operator should have been replaced by semi-join in the optimizer") + "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.Intersect(left, right, true) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by union, aggregate" + + "and generate operators in the optimizer") case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql new file mode 100644 index 0000000..ff4395c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql @@ -0,0 +1,123 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v); + +-- Basic INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- INTERSECT ALL same table in both branches +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1; + +-- Empty left relation +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3; + +-- Type Coerced INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT); + +-- Error as types of two side are not compatible +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2; + +-- Basic +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- Chain of different `set operations +-- We need to parenthesize the following two queries to enforce +-- certain order of evaluation of operators. After fix to +-- SPARK-24966 this can be removed. 
+SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +); + +-- Chain of different `set operations +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +); + +-- Join under intersect all +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Join under intersect all (2) +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Group by under intersect all +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out new file mode 100644 index 0000000..792791b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -0,0 +1,241 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- 
!query 2 schema +struct<k:int,v:int> +-- !query 2 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 3 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1 +-- !query 3 schema +struct<k:int,v:int> +-- !query 3 output +1 2 +1 2 +1 3 +1 3 + + +-- !query 4 +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct<k:int,v:int> +-- !query 4 output + + + +-- !query 5 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3 +-- !query 5 schema +struct<k:int,v:int> +-- !query 5 output + + + +-- !query 6 +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT) +-- !query 6 schema +struct<k:bigint,v:bigint> +-- !query 6 output +1 2 + + +-- !query 7 +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the compatible column types. array<int> <> int at the first column of the second table; + + +-- !query 8 +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 9 +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 9 schema +struct<k:int,v:int> +-- !query 9 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 10 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +) +-- !query 10 schema +struct<k:int,v:int> +-- !query 10 output +1 2 +1 2 +1 3 +2 3 +NULL NULL +NULL NULL + + +-- !query 11 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +) +-- !query 11 schema +struct<k:int,v:int> +-- !query 11 output +1 3 + + +-- 
!query 12 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 12 schema +struct<k:int,v:int> +-- !query 12 output +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +2 3 + + +-- !query 13 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 13 schema +struct<k:int,v:int> +-- !query 13 output + + + +-- !query 14 +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k +-- !query 14 schema +struct<v:int> +-- !query 14 output +2 +3 +NULL + + +-- !query 15 +DROP VIEW IF EXISTS tab1 +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +DROP VIEW IF EXISTS tab2 +-- !query 16 schema +struct<> +-- !query 16 output + http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index af07359..b0e22a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -749,6 +749,60 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) 
:: + Row(null) :: Nil) + + // Duplicate nulls are preserved. + checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) http://git-wip-us.apache.org/repos/asf/spark/blob/65a4bc14/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260..deea9db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -136,6 +136,19 @@ private[sql] trait SQLTestData { self => df } + protected lazy val lowerCaseDataWithDuplicates: DataFrame = { + val df = spark.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") 
:: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.createOrReplaceTempView("lowerCaseData") + df + } + protected lazy val arrayData: RDD[ArrayData] = { val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org