Repository: spark
Updated Branches:
  refs/heads/master 65b75e66e -> 637a78f1d


[SPARK-13427][SQL] Support USING clause in JOIN.

## What changes were proposed in this pull request?

Support queries that join tables with a USING clause:
SELECT * FROM table1 JOIN table2 USING (<column_list>)

The USING clause simplifies the join condition when:

1) equijoin semantics is desired, and
2) the join columns have the same name on both sides.

Spark already supports NATURAL JOIN. This PR reuses the existing
natural-join infrastructure to build both the join condition and the
projection list. An illustrative example follows.
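
For illustration, the intended semantics (hypothetical tables t1 and t2,
each with columns c1 and c2; not part of this patch):

    // The USING form is equivalent to the ON form below it, except that
    // the join column c1 appears only once in the USING output.
    sqlContext.sql("SELECT * FROM t1 JOIN t2 USING (c1)")
    sqlContext.sql("SELECT t1.c1, t1.c2, t2.c2 FROM t1 JOIN t2 ON t1.c1 = t2.c1")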

## How was this patch tested?

Added unit tests in SQLQuerySuite, CatalystQlSuite, and ResolveNaturalJoinSuite.

Author: Dilip Biswal <dbis...@us.ibm.com>

Closes #11297 from dilipbiswal/spark-13427.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/637a78f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/637a78f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/637a78f1

Branch: refs/heads/master
Commit: 637a78f1d3dff00658324de3887d75c5ccd857be
Parents: 65b75e6
Author: Dilip Biswal <dbis...@us.ibm.com>
Authored: Thu Mar 17 10:01:41 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Thu Mar 17 10:01:41 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/parser/FromClauseParser.g      |  9 +-
 .../spark/sql/catalyst/parser/SparkSqlParser.g  |  1 +
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 93 ++++++++++++--------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  7 ++
 .../sql/catalyst/optimizer/Optimizer.scala      |  2 +
 .../spark/sql/catalyst/parser/CatalystQl.scala  | 49 +++++++----
 .../spark/sql/catalyst/plans/joinTypes.scala    |  8 ++
 .../catalyst/plans/logical/basicOperators.scala |  5 +-
 .../analysis/ResolveNaturalJoinSuite.scala      | 74 +++++++++++-----
 .../sql/catalyst/parser/CatalystQlSuite.scala   | 21 +++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 39 ++------
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 64 ++++++++++++++
 12 files changed, 259 insertions(+), 113 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
index e83f8a7..1bf461c 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
@@ -91,10 +91,17 @@ fromClause
 joinSource
 @init { gParent.pushMsg("join source", state); }
 @after { gParent.popMsg(state); }
-    : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )*
+    : fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )*
     | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+
     ;
 
+joinCond
+@init { gParent.pushMsg("join expression list", state); }
+@after { gParent.popMsg(state); }
+    : KW_ON! expression
+    | KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList)
+    ;
+
 uniqueJoinSource
 @init { gParent.pushMsg("unique join source", state); }
 @after { gParent.popMsg(state); }
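
For orientation (assumed AST shape, inferred from the grammar above rather
than verified output): for a fragment like "t1 JOIN t2 USING (c1, c2)", the
USING branch of joinCond rewrites to a TOK_USING node whose children are the
column names, which getJoinInfo in CatalystQl.scala later pattern-matches as
Token("TOK_USING", columnList :: Nil).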

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 1db3aed..f0c2368 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -387,6 +387,7 @@ TOK_SETCONFIG;
 TOK_DFS;
 TOK_ADDFILE;
 TOK_ADDJAR;
+TOK_USING;
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 53ea3cf..e4e934a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -87,7 +87,7 @@ class Analyzer(
       ResolveSubquery ::
       ResolveWindowOrder ::
       ResolveWindowFrame ::
-      ResolveNaturalJoin ::
+      ResolveNaturalAndUsingJoin ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
       ResolveAggregateFunctions ::
@@ -1329,48 +1329,69 @@ class Analyzer(
   }
 
   /**
-   * Removes natural joins by calculating output columns based on output from two sides,
-   * Then apply a Project on a normal Join to eliminate natural join.
+   * Removes natural or using joins by calculating output columns based on output from two sides,
+   * then applies a Project on a normal Join to eliminate the natural or using join.
    */
-  object ResolveNaturalJoin extends Rule[LogicalPlan] {
+  object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
+          if left.resolved && right.resolved && j.duplicateResolved =>
+        // Resolve the column names referenced in the USING clause from both legs of the join.
+        val lCols = usingCols.flatMap(col => left.resolveQuoted(col.name, resolver))
+        val rCols = usingCols.flatMap(col => right.resolveQuoted(col.name, resolver))
+        if ((lCols.length == usingCols.length) && (rCols.length == usingCols.length)) {
+          val joinNames = lCols.map(exp => exp.name)
+          commonNaturalJoinProcessing(left, right, joinType, joinNames, None)
+        } else {
+          j
+        }
       case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
         // find common column names from both sides
         val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
-        val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
-        val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
-        val joinPairs = leftKeys.zip(rightKeys)
-
-        // Add joinPairs to joinConditions
-        val newCondition = (condition ++ joinPairs.map {
-          case (l, r) => EqualTo(l, r)
-        }).reduceOption(And)
-
-        // columns not in joinPairs
-        val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
-        val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
-
-        // the output list looks like: join keys, columns from left, columns from right
-        val projectList = joinType match {
-          case LeftOuter =>
-            leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
-          case RightOuter =>
-            rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
-          case FullOuter =>
-            // in full outer join, joinCols should be non-null if there is.
-            val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
-            joinedCols ++
-              lUniqueOutput.map(_.withNullability(true)) ++
-              rUniqueOutput.map(_.withNullability(true))
-          case Inner =>
-            rightKeys ++ lUniqueOutput ++ rUniqueOutput
-          case _ =>
-            sys.error("Unsupported natural join type " + joinType)
-        }
-        // use Project to trim unnecessary fields
-        Project(projectList, Join(left, right, joinType, newCondition))
+        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)
+    }
+  }
+
+  private def commonNaturalJoinProcessing(
+     left: LogicalPlan,
+     right: LogicalPlan,
+     joinType: JoinType,
+     joinNames: Seq[String],
+     condition: Option[Expression]) = {
+    val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
+    val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
+    val joinPairs = leftKeys.zip(rightKeys)
+
+    val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)
+
+    // columns not in joinPairs
+    val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
+    val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
+
+    // the output list looks like: join keys, columns from left, columns from right
+    val projectList = joinType match {
+      case LeftOuter =>
+        leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
+      case LeftSemi =>
+        leftKeys ++ lUniqueOutput
+      case RightOuter =>
+        rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
+      case FullOuter =>
+        // in a full outer join, the join columns should be non-null whenever either side
+        // is, so coalesce the values from the two sides.
+        val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
+        joinedCols ++
+          lUniqueOutput.map(_.withNullability(true)) ++
+          rUniqueOutput.map(_.withNullability(true))
+      case Inner =>
+        leftKeys ++ lUniqueOutput ++ rUniqueOutput
+      case _ =>
+        sys.error("Unsupported natural join type " + joinType)
     }
+    // use Project to trim unnecessary fields
+    Project(projectList, Join(left, right, joinType, newCondition))
   }
+
+
 }
 
 /**

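For intuition, a minimal sketch of the rewrite the rule performs, using the
catalyst test DSL (illustrative only; it mirrors the fixtures in
ResolveNaturalJoinSuite below and is not part of the patch):

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.{Inner, UsingJoin}
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // Hypothetical relations sharing column "a".
    val t1 = LocalRelation('a.int, 'b.int)
    val t2 = LocalRelation('a.int, 'c.int)
    // The unresolved plan carries the USING metadata ...
    val before = t1.join(t2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
    // ... and resolution rewrites it to an equi-join plus a de-duplicating
    // Project, conceptually: t1.join(t2, Inner, Some(EqualTo(a, a))).select(a, b, c)
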
http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1e430c1..1d1e892 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.UsingJoin
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -109,6 +110,12 @@ trait CheckAnalysis {
               s"filter expression '${f.condition.sql}' " +
                 s"of type ${f.condition.dataType.simpleString} is not a 
boolean.")
 
+          case j @ Join(_, _, UsingJoin(_, cols), _) =>
+            val from = operator.inputSet.map(_.name).mkString(", ")
+            failAnalysis(
+              s"using columns [${cols.mkString(",")}] " +
+                s"can not be resolved given input columns: [$from] ")
+
           case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
             failAnalysis(
               s"join condition '${condition.sql}' " +

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d0e5859..c419b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1133,6 +1133,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
         case FullOuter => f // DO Nothing for Full Outer Join
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
       }
 
     // push down the join filter into sub query scanning if applicable
@@ -1168,6 +1169,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           Join(newLeft, newRight, LeftOuter, newJoinCond)
         case FullOuter => f
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
       }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
index 7d5a468..c188c5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
@@ -419,30 +419,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           sys.error(s"Unsupported join operation: $other")
         }
 
-        val joinType = joinToken match {
-          case "TOK_JOIN" => Inner
-          case "TOK_CROSSJOIN" => Inner
-          case "TOK_RIGHTOUTERJOIN" => RightOuter
-          case "TOK_LEFTOUTERJOIN" => LeftOuter
-          case "TOK_FULLOUTERJOIN" => FullOuter
-          case "TOK_LEFTSEMIJOIN" => LeftSemi
-          case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
-          case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
-          case "TOK_NATURALJOIN" => NaturalJoin(Inner)
-          case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
-          case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
-          case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
-        }
+        val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)
+
         Join(nodeToRelation(relation1),
           nodeToRelation(relation2),
           joinType,
-          other.headOption.map(nodeToExpr))
-
+          joinCondition)
       case _ =>
         noParseRule("Relation", node)
     }
   }
 
+  protected def getJoinInfo(
+     joinToken: String,
+     joinConditionToken: Seq[ASTNode],
+     node: ASTNode): (JoinType, Option[Expression]) = {
+    val joinType = joinToken match {
+      case "TOK_JOIN" => Inner
+      case "TOK_CROSSJOIN" => Inner
+      case "TOK_RIGHTOUTERJOIN" => RightOuter
+      case "TOK_LEFTOUTERJOIN" => LeftOuter
+      case "TOK_FULLOUTERJOIN" => FullOuter
+      case "TOK_LEFTSEMIJOIN" => LeftSemi
+      case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
+      case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
+      case "TOK_NATURALJOIN" => NaturalJoin(Inner)
+      case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
+      case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
+      case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
+    }
+
+    joinConditionToken match {
+      case Token("TOK_USING", columnList :: Nil) :: Nil =>
+        val colNames = columnList.children.collect {
+          case Token(name, Nil) => UnresolvedAttribute(name)
+        }
+        (UsingJoin(joinType, colNames), None)
+      /* Join expression specified using ON clause */
+      case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr))
+    }
+  }
+
   protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
     case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
       SortOrder(nodeToExpr(sortExpr), Ascending)

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 27a7532..9ca4f13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+
 object JoinType {
   def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
     case "inner" => Inner
@@ -66,3 +68,9 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
     "Unsupported natural join type " + tpe)
   override def sql: String = "NATURAL " + tpe.sql
 }
+
+case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
+  require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe),
+    "Unsupported using join type " + tpe)
+  override def sql: String = "USING " + tpe.sql
+}
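
A minimal sketch of constructing the new node (illustrative only, not part
of the patch):

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.plans.{Inner, NaturalJoin, UsingJoin}

    // UsingJoin carries the base join type plus the still-unresolved columns;
    // note that LeftSemi is accepted here, unlike for NaturalJoin.
    val uj = UsingJoin(Inner, Seq(UnresolvedAttribute("c1"), UnresolvedAttribute("c2")))
    // UsingJoin(NaturalJoin(Inner), Nil) would fail the require above.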

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09ea3fe..ccc9916 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -298,10 +298,11 @@ case class Join(
       condition.forall(_.dataType == BooleanType)
   }
 
-  // if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need
-  // to eliminate natural before we mark it resolved.
+  // if not a natural join, use `resolvedExceptNatural`. if it is a natural join or
+  // using join, we still need to eliminate natural or using before we mark it resolved.
   override lazy val resolved: Boolean = joinType match {
     case NaturalJoin(_) => false
+    case UsingJoin(_, _) => false
     case _ => resolvedExceptNatural
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index fcf4ac1..1423a87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
@@ -35,56 +36,81 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
   lazy val r3 = LocalRelation(aNotNull, bNotNull)
   lazy val r4 = LocalRelation(cNotNull, bNotNull)
 
-  test("natural inner join") {
-    val plan = r1.join(r2, NaturalJoin(Inner), None)
+  test("natural/using inner join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(Inner), None)
+    val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural left join") {
-    val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
+  test("natural/using left join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(LeftOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural right join") {
-    val plan = r1.join(r2, NaturalJoin(RightOuter), None)
+  test("natural/using right join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(RightOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(RightOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural full outer join") {
-    val plan = r1.join(r2, NaturalJoin(FullOuter), None)
+  test("natural/using full outer join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(FullOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(FullOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
       Alias(Coalesce(Seq(a, a)), "a")(), b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural inner join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(Inner), None)
+  test("natural/using inner join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(Inner), None)
+    val usingPlan = r3.join(r4, UsingJoin(Inner, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, aNotNull, cNotNull)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural left join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
+  test("natural/using left join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(LeftOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, aNotNull, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural right join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(RightOuter), None)
+  test("natural/using right join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(RightOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(RightOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, a, cNotNull)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural full outer join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(FullOuter), None)
+  test("natural/using full outer join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
+  }
+
+  test("using unresolved attribute") {
+    val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("d"))), None)
+    val error = intercept[AnalysisException] {
+      SimpleAnalyzer.checkAnalysis(usingPlan)
+    }
+    assert(error.message.contains(
+      "using columns ['d] can not be resolved given input columns: [b, a, c]"))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
index 048b4f1..c068e89 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
@@ -219,4 +219,25 @@ class CatalystQlSuite extends PlanTest {
     parser.parsePlan("select * from t where a = (select b from s)")
     parser.parsePlan("select * from t group by g having a > (select b from s)")
   }
+
+  test("using clause in JOIN") {
+    // Tests parsing of using clause for different join types.
+    parser.parsePlan("select * from t1 join t2 using (c1)")
+    parser.parsePlan("select * from t1 join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 left join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 right join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 full outer join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 join t2 using (c1) join t3 using (c2)")
+    // Tests errors
+    // (1) Empty using clause
+    // (2) Qualified columns in using
+    // (3) Both on and using clause
+    var error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using ()"))
+    assert(error.message.contains("cannot recognize input near ')'"))
+    error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using (t1.c1)"))
+    assert(error.message.contains("mismatched input '.'"))
+    error = intercept[AnalysisException](parser.parsePlan("select * from t1" +
+      " join t2 using (c1) on t1.c1 = t2.c1"))
+    assert(error.message.contains("missing EOF at 'on' near ')'"))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ac2ca3c..75f1ffd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -490,41 +490,12 @@ class Dataset[T] private[sql](
       Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
       .analyzed.asInstanceOf[Join]
 
-    val condition = usingColumns.map { col =>
-      catalyst.expressions.EqualTo(
-        withPlan(joined.left).resolve(col),
-        withPlan(joined.right).resolve(col))
-    }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) =>
-      catalyst.expressions.And(cond, eqTo)
-    }
-
-    // Project only one of the join columns.
-    val joinedCols = JoinType(joinType) match {
-      case Inner | LeftOuter | LeftSemi =>
-        usingColumns.map(col => withPlan(joined.left).resolve(col))
-      case RightOuter =>
-        usingColumns.map(col => withPlan(joined.right).resolve(col))
-      case FullOuter =>
-        usingColumns.map { col =>
-          val leftCol = withPlan(joined.left).resolve(col).toAttribute.withNullability(true)
-          val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true)
-          Alias(Coalesce(Seq(leftCol, rightCol)), col)()
-        }
-      case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.")
-    }
-    // The nullability of output of joined could be different than original column,
-    // so we can only compare them by exprId
-    val joinRefs = AttributeSet(condition.toSeq.flatMap(_.references))
-    val resultCols = joinedCols ++ joined.output.filterNot(joinRefs.contains(_))
     withPlan {
-      Project(
-        resultCols,
-        Join(
-          joined.left,
-          joined.right,
-          joinType = JoinType(joinType),
-          condition)
-      )
+      Join(
+        joined.left,
+        joined.right,
+        UsingJoin(JoinType(joinType), usingColumns.map(UnresolvedAttribute(_))),
+        None)
     }
   }
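
With this change, the Dataset join(right, usingColumns, joinType) API above no
longer builds the equi-join condition and projection itself; it emits an
unresolved UsingJoin and defers that work to ResolveNaturalAndUsingJoin.
Illustrative usage (hypothetical DataFrames df1 and df2 sharing column "c1",
registered as tables t1 and t2; not part of this patch):

    // Both forms now go through the same analyzer rule and emit a single
    // "c1" column in the output.
    val byApi = df1.join(df2, Seq("c1"), "left_outer")
    val bySql = sqlContext.sql("SELECT * FROM t1 LEFT JOIN t2 USING (c1)")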
 

http://git-wip-us.apache.org/repos/asf/spark/blob/637a78f1/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 3efe984..6716982 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2179,4 +2179,68 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         Row(4) :: Nil)
     }
   }
+
+  test("join with using clause") {
+    val df1 = Seq(("r1c1", "r1c2", "t1r1c3"),
+      ("r2c1", "r2c2", "t1r2c3"), ("r3c1x", "r3c2", "t1r3c3")).toDF("c1", "c2", "c3")
+    val df2 = Seq(("r1c1", "r1c2", "t2r1c3"),
+      ("r2c1", "r2c2", "t2r2c3"), ("r3c1y", "r3c2", "t2r3c3")).toDF("c1", "c2", "c3")
+    val df3 = Seq((null, "r1c2", "t3r1c3"),
+      ("r2c1", "r2c2", "t3r2c3"), ("r3c1y", "r3c2", "t3r3c3")).toDF("c1", "c2", "c3")
+    withTempTable("t1", "t2", "t3") {
+      df1.registerTempTable("t1")
+      df2.registerTempTable("t2")
+      df3.registerTempTable("t3")
+      // inner join with one using column
+      checkAnswer(
+        sql("SELECT * FROM t1 join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: Nil)
+
+      // inner join with two using columns
+      checkAnswer(
+        sql("SELECT * FROM t1 join t2 using (c1, c2)"),
+        Row("r1c1", "r1c2", "t1r1c3", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "t2r2c3") :: Nil)
+
+      // Left outer join with one using column.
+      checkAnswer(
+        sql("SELECT * FROM t1 left join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", null, null) :: Nil)
+
+      // Right outer join with one using column.
+      checkAnswer(
+        sql("SELECT * FROM t1 right join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") ::
+          Row("r3c1y", null, null, "r3c2", "t2r3c3") :: Nil)
+
+      // Full outer join with one using column.
+      checkAnswer(
+        sql("SELECT * FROM t1 full outer join t2 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", null, null) ::
+          Row("r3c1y", null,
+            null, "r3c2", "t2r3c3") :: Nil)
+
+      // Full outer join with null value in join column.
+      checkAnswer(
+        sql("SELECT * FROM t1 full outer join t3 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", null, null) ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t3r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", null, null) ::
+          Row("r3c1y", null, null, "r3c2", "t3r3c3") ::
+          Row(null, null, null, "r1c2", "t3r1c3") :: Nil)
+
+      // Self join with using columns.
+      checkAnswer(
+        sql("SELECT * FROM t1 join t1 using (c1)"),
+        Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t1r1c3") ::
+          Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t1r2c3") ::
+          Row("r3c1x", "r3c2", "t1r3c3", "r3c2", "t1r3c3") :: Nil)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
