Repository: spark
Updated Branches:
  refs/heads/master a6ee2f794 -> 357900311


[SPARK-6247][SQL] Fix resolution of ambiguous joins caused by new aliases

We need to handle ambiguous `exprId`s that are produced by new aliases as well 
as those caused by leaf nodes (`MultiInstanceRelation`).
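
For illustration, a hedged sketch of the failing pattern (adapted from the new 
"self join with alias in agg" test added below; it assumes a `SQLContext` named 
`sqlContext` with its implicits and `org.apache.spark.sql.functions._` in scope):

    import org.apache.spark.sql.functions._
    import sqlContext.implicits._

    // Both sides of the self join expose the alias `strCount` with the same
    // exprId, so `x.strCount` is ambiguous until the analyzer rewrites one side.
    Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
      .groupBy("str")
      .agg($"str", count("str").as("strCount"))
      .registerTempTable("df")

    sqlContext.sql(
      """SELECT x.str, SUM(x.strCount)
        |FROM df x JOIN df y ON x.str = y.str
        |GROUP BY x.str""".stripMargin)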

Attempting to fix this revealed a bug in `equals` for `Alias`: these objects 
compared equal even when their expression ids did not match. Additionally, 
`LocalRelation` did not correctly provide statistics, and some tests in 
`catalyst` and `hive` were not using the helper functions for comparing plans.
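
For reference, a minimal sketch of the `Alias` equality problem (`Literal` here 
just stands in for an arbitrary child expression): `exprId` sits in `Alias`'s 
second parameter list, so the case-class-generated `equals` ignored it:

    import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}

    val a1 = Alias(Literal(1), "a")()  // gets a fresh exprId
    val a2 = Alias(Literal(1), "a")()  // gets a different fresh exprId
    // Before this patch a1 == a2 evaluated to true despite the differing ids;
    // the explicit equals override added below also compares exprId (and
    // qualifiers), so the comparison now returns false.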

Based on #4991 by chenghao-intel

Author: Michael Armbrust <mich...@databricks.com>

Closes #5062 from marmbrus/selfJoins and squashes the following commits:

8e9b84b [Michael Armbrust] check qualifier too
8038a36 [Michael Armbrust] handle aggs too
0b9c687 [Michael Armbrust] fix more tests
c3c574b [Michael Armbrust] revert change.
725f1ab [Michael Armbrust] add statistics
a925d08 [Michael Armbrust] check for conflicting attributes in join resolution
b022ef7 [Michael Armbrust] Handle project aliases.
d8caa40 [Michael Armbrust] test case: SPARK-6247
f9c67c2 [Michael Armbrust] Check for duplicate attributes in join resolution.
898af73 [Michael Armbrust] Fix Alias equality.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35790031
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35790031
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35790031

Branch: refs/heads/master
Commit: 3579003115fa3217cff6aa400729d96b0c7b257b
Parents: a6ee2f7
Author: Michael Armbrust <mich...@databricks.com>
Authored: Tue Mar 17 19:47:51 2015 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Mar 17 19:47:51 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 30 ++++++++++++++++---
 .../catalyst/expressions/namedExpressions.scala |  6 ++++
 .../catalyst/plans/logical/LocalRelation.scala  |  3 ++
 .../catalyst/plans/logical/basicOperators.scala |  7 +++++
 .../analysis/HiveTypeCoercionSuite.scala        | 10 ++++---
 .../spark/sql/catalyst/plans/PlanTest.scala     | 11 +++++--
 .../spark/sql/ColumnExpressionSuite.scala       |  6 +++-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 31 ++++++++++++++++++++
 .../spark/sql/catalyst/plans/PlanTest.scala     |  4 ++-
 9 files changed, 96 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7753331..92d3db0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -237,22 +237,33 @@ class Analyzer(catalog: Catalog,
      // Special handling for cases when self-join introduce duplicate expression ids.
      case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
        val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+        logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")

-        val (oldRelation, newRelation, attributeRewrites) = right.collect {
+        val (oldRelation, newRelation) = right.collect {
+          // Handle base relations that might appear more than once.
          case oldVersion: MultiInstanceRelation
              if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
            val newVersion = oldVersion.newInstance()
-            val newAttributes = AttributeMap(oldVersion.output.zip(newVersion.output))
-            (oldVersion, newVersion, newAttributes)
+            (oldVersion, newVersion)
+
+          // Handle projects that create conflicting aliases.
+          case oldVersion @ Project(projectList, _)
+              if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+            (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
+
+          case oldVersion @ Aggregate(_, aggregateExpressions, _)
+              if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+            (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
        }.head // Only handle first case found, others will be fixed on the next pass.

+        val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
        val newRight = right transformUp {
          case r if r == oldRelation => newRelation
+        } transformUp {
          case other => other transformExpressions {
            case a: Attribute => attributeRewrites.get(a).getOrElse(a)
          }
        }
-
        j.copy(right = newRight)
 
       case q: LogicalPlan =>
@@ -272,6 +283,17 @@ class Analyzer(catalog: Catalog,
         }
     }
 
+    def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
+      expressions.map {
+        case a: Alias => Alias(a.child, a.name)()
+        case other => other
+      }
+    }
+
+    def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
+      AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
+    }
+
     /**
      * Returns true if `exprs` contains a [[Star]].
      */

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 62c062b..17f7f9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -124,6 +124,12 @@ case class Alias(child: Expression, name: String)
   override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
 
   override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
+
+  override def equals(other: Any): Boolean = other match {
+    case a: Alias =>
+      name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers
+    case _ => false
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 92bd057..bb79dc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -54,4 +54,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil)
      otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
     case _ => false
   }
+
+  override lazy val statistics =
+    Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 624912d..1e7b449 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -108,6 +108,13 @@ case class Join(
         left.output ++ right.output
     }
   }
+
+  def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty
+
+  // Joins are only resolved if they don't introduce ambiguous expression ids.
+  override lazy val resolved: Boolean = {
+    childrenResolved && !expressions.exists(!_.resolved) && selfJoinResolved
+  }
 }
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 85798d0..ecbb542 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.plans.PlanTest
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.types._
 
-class HiveTypeCoercionSuite extends FunSuite {
+class HiveTypeCoercionSuite extends PlanTest {
 
   test("tightest common bound for types") {
    def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
@@ -106,7 +106,8 @@ class HiveTypeCoercionSuite extends FunSuite {
     val booleanCasts = new HiveTypeCoercion { }.BooleanCasts
     def ruleTest(initial: Expression, transformed: Expression) {
       val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
-      assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+      comparePlans(
+        booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)),
         Project(Seq(Alias(transformed, "a")()), testRelation))
     }
    // Remove superfluous boolean -> boolean casts.
@@ -119,7 +120,8 @@ class HiveTypeCoercionSuite extends FunSuite {
     val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
     def ruleTest(initial: Expression, transformed: Expression) {
       val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
-      assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+      comparePlans(
+        fac(Project(Seq(Alias(initial, "a")()), testRelation)),
         Project(Seq(Alias(transformed, "a")()), testRelation))
     }
     ruleTest(

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 7d609b9..4888404 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{NoRelation, Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.util._
 
 /**
@@ -36,6 +36,8 @@ class PlanTest extends FunSuite {
     plan transformAllExpressions {
       case a: AttributeReference =>
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+      case a: Alias =>
+        Alias(a.child, a.name)(exprId = ExprId(0))
     }
   }
 
@@ -50,4 +52,9 @@ class PlanTest extends FunSuite {
          |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
         """.stripMargin)
   }
+
+  /** Fails the test if the two expressions do not match */
+  protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
+    comparePlans(Filter(e1, NoRelation), Filter(e2, NoRelation))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 3036fbc..a53ae97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Project, NoRelation}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -311,7 +313,9 @@ class ColumnExpressionSuite extends QueryTest {
   }
 
   test("lift alias out of cast") {
-    assert(col("1234").as("name").cast("int").expr === col("1234").cast("int").as("name").expr)
+    compareExpressions(
+      col("1234").as("name").cast("int").expr,
+      col("1234").cast("int").as("name").expr)
   }
 
   test("columns can be compared") {

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4dedcd3..a3c0076 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -36,6 +36,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   import org.apache.spark.sql.test.TestSQLContext.implicits._
   val sqlCtx = TestSQLContext
 
+  test("self join with aliases") {
+    Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df")
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT x.str, COUNT(*)
+          |FROM df x JOIN df y ON x.str = y.str
+          |GROUP BY x.str
+        """.stripMargin),
+      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+  }
+
+  test("self join with alias in agg") {
+      Seq(1,2,3)
+        .map(i => (i, i.toString))
+        .toDF("int", "str")
+        .groupBy("str")
+        .agg($"str", count("str").as("strCount"))
+        .registerTempTable("df")
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT x.str, SUM(x.strCount)
+          |FROM df x JOIN df y ON x.str = y.str
+          |GROUP BY x.str
+        """.stripMargin),
+      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+  }
+
   test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
     checkAnswer(
       sql("SELECT a FROM testData2 SORT BY a"),

http://git-wip-us.apache.org/repos/asf/spark/blob/35790031/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 44ee5ab..98f1c0e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.scalatest.FunSuite
@@ -38,6 +38,8 @@ class PlanTest extends FunSuite {
     plan transformAllExpressions {
       case a: AttributeReference =>
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+      case a: Alias =>
+        Alias(a.child, a.name)(exprId = ExprId(0))
     }
   }
 

