Repository: spark
Updated Branches:
  refs/heads/master 31296628a -> 75ee12f09


[SPARK-8658][SQL] AttributeReference's equals method compares all the members

This fix changes the equals method of AttributeReference to compare all of its
specified fields when checking for equality.

Author: gatorsmile <gatorsm...@gmail.com>

Closes #9216 from gatorsmile/namedExpressEqual.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/75ee12f0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/75ee12f0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/75ee12f0

Branch: refs/heads/master
Commit: 75ee12f09c2645c1ad682764d512965f641eb5c2
Parents: 3129662
Author: gatorsmile <gatorsm...@gmail.com>
Authored: Mon Nov 16 15:22:12 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Nov 16 15:22:12 2015 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/namedExpressions.scala     |  4 +++-
 .../sql/catalyst/plans/logical/basicOperators.scala     | 10 +++++-----
 .../sql/catalyst/plans/physical/partitioning.scala      | 12 ++++++------
 3 files changed, 14 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/75ee12f0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index f80bcfc..e3dadda 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -194,7 +194,9 @@ case class AttributeReference(
   def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId
 
   override def equals(other: Any): Boolean = other match {
-    case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
+    case ar: AttributeReference =>
+      name == ar.name && dataType == ar.dataType && nullable == ar.nullable &&
+        metadata == ar.metadata && exprId == ar.exprId && qualifiers == ar.qualifiers
     case _ => false
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/75ee12f0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e2b97b2..45630a5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
+import scala.collection.mutable.ArrayBuffer
 
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -244,12 +244,12 @@ private[sql] object Expand {
    */
   private def buildNonSelectExprSet(
       bitmask: Int,
-      exprs: Seq[Expression]): OpenHashSet[Expression] = {
-    val set = new OpenHashSet[Expression](2)
+      exprs: Seq[Expression]): ArrayBuffer[Expression] = {
+    val set = new ArrayBuffer[Expression](2)
 
     var bit = exprs.length - 1
     while (bit >= 0) {
-      if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
+      if (((bitmask >> bit) & 1) == 0) set += exprs(bit)
       bit -= 1
     }
 
@@ -279,7 +279,7 @@ private[sql] object Expand {
 
       (child.output :+ gid).map(expr => expr transformDown {
         // TODO this causes a problem when a column is used both for grouping and aggregation.
-        case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+        case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
           // if the input attribute is in the Invalid Grouping Expression set for this group
           // replace it with constant null
           Literal.create(null, expr.dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/75ee12f0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 86b9417..f6fb31a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -235,17 +235,17 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
     case ClusteredDistribution(requiredClustering) =>
-      expressions.toSet.subsetOf(requiredClustering.toSet)
+      expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
     case _ => false
   }
 
   override def compatibleWith(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning => this == o
+    case o: HashPartitioning => this.semanticEquals(o)
     case _ => false
   }
 
   override def guarantees(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning => this == o
+    case o: HashPartitioning => this.semanticEquals(o)
     case _ => false
   }
 
@@ -276,17 +276,17 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
       val minSize = Seq(requiredOrdering.size, ordering.size).min
       requiredOrdering.take(minSize) == ordering.take(minSize)
     case ClusteredDistribution(requiredClustering) =>
-      ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet)
+      ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
     case _ => false
   }
 
   override def compatibleWith(other: Partitioning): Boolean = other match {
-    case o: RangePartitioning => this == o
+    case o: RangePartitioning => this.semanticEquals(o)
     case _ => false
   }
 
   override def guarantees(other: Partitioning): Boolean = other match {
-    case o: RangePartitioning => this == o
+    case o: RangePartitioning => this.semanticEquals(o)
     case _ => false
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to