Repository: spark
Updated Branches:
  refs/heads/master 287c0ac77 -> dc4d577c6


[SPARK-3198] [SQL] Remove the TreeNode.id

The id property of the TreeNode API does provide a fast way to compare two 
TreeNodes, but it is a performance bottleneck during expression object 
creation in a multi-threaded environment, because every node construction 
increments a shared AtomicLong (a memory barrier).
Fortunately, tree node comparison only happens once in master, so even if we 
remove it, overall performance will not be affected.
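
A minimal Scala sketch of the trade-off (illustration only, not code from this
patch; IdAllocator is a hypothetical name):

    // The removed pattern: every TreeNode construction paid for one
    // getAndIncrement on a single shared AtomicLong, a CAS plus memory
    // barrier that all constructing threads contend on.
    object IdAllocator {
      private val currentId = new java.util.concurrent.atomic.AtomicLong
      def nextId(): Long = currentId.getAndIncrement()
    }

    // The replacement: JVM reference equality, with no construction-time cost.
    def sameInstance(a: AnyRef, b: AnyRef): Boolean = a eq b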

Author: Cheng Hao <hao.ch...@intel.com>

Closes #2155 from chenghao-intel/treenode and squashes the following commits:

7cf2cd2 [Cheng Hao] Remove the implicit keyword for TreeNodeRef and some other small issues
5873415 [Cheng Hao] Remove the TreeNode.id


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dc4d577c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dc4d577c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dc4d577c

Branch: refs/heads/master
Commit: dc4d577c6549df58f42c0e22cac354554d169896
Parents: 287c0ac
Author: Cheng Hao <hao.ch...@intel.com>
Authored: Fri Aug 29 15:32:26 2014 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Aug 29 15:32:26 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/planning/patterns.scala  | 11 +++++----
 .../spark/sql/catalyst/plans/QueryPlan.scala    | 12 +++++-----
 .../spark/sql/catalyst/trees/TreeNode.scala     | 24 ++------------------
 .../spark/sql/catalyst/trees/package.scala      | 11 +++++++++
 .../sql/catalyst/trees/TreeNodeSuite.scala      |  5 +++-
 .../sql/execution/GeneratedAggregate.scala      | 10 ++++----
 .../spark/sql/execution/debug/package.scala     |  7 +++---
 .../apache/spark/sql/execution/pythonUdfs.scala |  2 +-
 8 files changed, 40 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 90923fe..f0fd9a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.planning
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 
@@ -134,8 +135,8 @@ object PartialAggregation {
       // Only do partial aggregation if supported by all aggregate expressions.
       if (allAggregates.size == partialAggregates.size) {
        // Create a map of expressions to their partial evaluations for all aggregate expressions.
-        val partialEvaluations: Map[Long, SplitEvaluation] =
-          partialAggregates.map(a => (a.id, a.asPartial)).toMap
+        val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] =
+          partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap
 
        // We need to pass all grouping expressions though so the grouping can happen a second
        // time. However some of them might be unnamed so we alias them allowing them to be
@@ -148,8 +149,8 @@ object PartialAggregation {
        // Replace aggregations with a new expression that computes the result from the already
        // computed partial evaluations and grouping values.
        val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
-          case e: Expression if partialEvaluations.contains(e.id) =>
-            partialEvaluations(e.id).finalEvaluation
+          case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
+            partialEvaluations(new TreeNodeRef(e)).finalEvaluation
           case e: Expression if namedGroupingExpressions.contains(e) =>
             namedGroupingExpressions(e).toAttribute
         }).asInstanceOf[Seq[NamedExpression]]

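To illustrate why the map is keyed by TreeNodeRef rather than by the expression
itself, a REPL-style sketch (Sum and Literal are real Catalyst expressions, but
the snippet is an illustration, not code from this patch): structurally equal
but distinct aggregate instances must map to distinct keys.

    import org.apache.spark.sql.catalyst.expressions.{Literal, Sum}
    import org.apache.spark.sql.catalyst.trees.TreeNodeRef

    val a1 = Sum(Literal(1))
    val a2 = Sum(Literal(1))            // structurally equal, distinct instance
    val partials = Map(new TreeNodeRef(a1) -> "partial evaluation of a1")

    assert(a1 == a2)                                 // case-class equality holds
    assert(partials.contains(new TreeNodeRef(a1)))   // same instance: found
    assert(!partials.contains(new TreeNodeRef(a2)))  // distinct instance: not found
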
http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 1e177e2..af9e4d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -50,11 +50,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
 
     @inline def transformExpressionDown(e: Expression) = {
       val newE = e.transformDown(rule)
-      if (newE.id != e.id && newE != e) {
+      if (newE.fastEquals(e)) {
+        e
+      } else {
         changed = true
         newE
-      } else {
-        e
       }
     }
 
@@ -82,11 +82,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
 
     @inline def transformExpressionUp(e: Expression) = {
       val newE = e.transformUp(rule)
-      if (newE.id != e.id && newE != e) {
+      if (newE.fastEquals(e)) {
+        e
+      } else {
         changed = true
         newE
-      } else {
-        e
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 96ce359..2013ae4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,11 +19,6 @@ package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.sql.catalyst.errors._
 
-object TreeNode {
-  private val currentId = new java.util.concurrent.atomic.AtomicLong
-  protected def nextId() = currentId.getAndIncrement()
-}
-
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
 private class MutableInt(var i: Int)
 
@@ -34,28 +29,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
   def children: Seq[BaseType]
 
   /**
-   * A globally unique id for this specific instance. Not preserved across copies.
-   * Unlike `equals`, `id` can be used to differentiate distinct but structurally
-   * identical branches of a tree.
-   */
-  val id = TreeNode.nextId()
-
-  /**
-   * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance.  Unlike
-   * `equals` this function will return false for different instances of structurally identical
-   * trees.
-   */
-  def sameInstance(other: TreeNode[_]): Boolean = {
-    this.id == other.id
-  }
-
-  /**
    * Faster version of equality which short-circuits when two treeNodes are the same instance.
    * We don't just override Object.Equals, as doing so prevents the scala compiler from
    * generating case class `equals` methods
    */
   def fastEquals(other: TreeNode[_]): Boolean = {
-    sameInstance(other) || this == other
+    this.eq(other) || this == other
   }
 
   /**
@@ -393,3 +372,4 @@ trait UnaryNode[BaseType <: TreeNode[BaseType]] {
   def child: BaseType
   def children = child :: Nil
 }
+

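A small sketch of the new fastEquals semantics (illustration only, using
Literal for convenience): reference equality is tried first, then ordinary
case-class equality.

    import org.apache.spark.sql.catalyst.expressions.Literal

    val x = Literal(1)
    val y = Literal(1)
    assert(x.fastEquals(x))  // short-circuits on eq, no structural comparison
    assert(x.fastEquals(y))  // eq fails, falls back to case-class equality
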
http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
index d725a92..79a8e06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
@@ -37,4 +37,15 @@ package object trees extends Logging {
  // Since we want tree nodes to be lightweight, we create one logger for all treenode instances.
   protected override def logName = "catalyst.trees"
 
+  /**
+   * A wrapper around [[TreeNode]] that provides reference equality for use in hash-based collections.
+   */
+  class TreeNodeRef(val obj: TreeNode[_]) {
+    override def equals(o: Any) = o match {
+      case that: TreeNodeRef => that.obj.eq(obj)
+      case _ => false
+    }
+
+    override def hashCode = if (obj == null) 0 else obj.hashCode
+  }
 }

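Note the design choice above: hashCode delegates to obj.hashCode, which is
structural for case-class nodes, so equal-but-distinct nodes hash to the same
bucket while the eq-based equals still keeps their keys distinct. A sketch
(illustration only):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.trees.TreeNodeRef

    val p = Literal(0)
    val q = Literal(0)                  // structurally equal, distinct instance
    val rp = new TreeNodeRef(p)
    val rq = new TreeNodeRef(q)
    assert(rp.hashCode == rq.hashCode)  // structural hash: same bucket
    assert(rp != rq)                    // eq-based equals: still distinct keys
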
http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 2962025..036fd3f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -51,7 +51,10 @@ class TreeNodeSuite extends FunSuite {
     val after = before transform { case Literal(5, _) => Literal(1)}
 
     assert(before === after)
-    assert(before.map(_.id) === after.map(_.id))
+    // Ensure that the objects after the transformation are the same objects as before it.
+    before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach {
+      case (b, a) => assert(b eq a)
+    }
   }
 
   test("collect") {

http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 31ad5e8..b3edd50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.trees._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.types._
@@ -141,9 +142,10 @@ case class GeneratedAggregate(
 
     val computationSchema = computeFunctions.flatMap(_.schema)
 
-    val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map {
-      case (agg, func) => agg.id -> func.result
-    }.toMap
+    val resultMap: Map[TreeNodeRef, Expression] = 
+      aggregatesToCompute.zip(computeFunctions).map {
+        case (agg, func) => new TreeNodeRef(agg) -> func.result
+      }.toMap
 
     val namedGroups = groupingExpressions.zipWithIndex.map {
       case (ne: NamedExpression, _) => (ne, ne)
@@ -156,7 +158,7 @@ case class GeneratedAggregate(
     // The set of expressions that produce the final output given the aggregation buffer and the
     // grouping expressions.
     val resultExpressions = aggregateExpressions.map(_.transform {
-      case e: Expression if resultMap.contains(e.id) => resultMap(e.id)
+      case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
       case e: Expression if groupMap.contains(e) => groupMap(e)
     })
 

http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5b896c5..8ff757b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.SparkContext._
 import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 
 /**
  * :: DeveloperApi ::
@@ -43,10 +44,10 @@ package object debug {
   implicit class DebugQuery(query: SchemaRDD) {
     def debug(): Unit = {
       val plan = query.queryExecution.executedPlan
-      val visited = new collection.mutable.HashSet[Long]()
+      val visited = new collection.mutable.HashSet[TreeNodeRef]()
       val debugPlan = plan transform {
-        case s: SparkPlan if !visited.contains(s.id) =>
-          visited += s.id
+        case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) =>
+          visited += new TreeNodeRef(s)
           DebugNode(s)
       }
       println(s"Results returned: ${debugPlan.execute().count()}")

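The visited set relies on the same identity semantics: HashSet.add returns
false for an instance that is already recorded, so each plan node is wrapped
in DebugNode at most once. A sketch (illustration only; a Literal stands in
for a SparkPlan here):

    import scala.collection.mutable
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.trees.TreeNodeRef

    val node = Literal(1)
    val visited = new mutable.HashSet[TreeNodeRef]()
    assert(visited.add(new TreeNodeRef(node)))   // first sighting: added
    assert(!visited.add(new TreeNodeRef(node)))  // already recorded: rejected
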
http://git-wip-us.apache.org/repos/asf/spark/blob/dc4d577c/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index aef6ebf..3dc8be2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -98,7 +98,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
         logical.Project(
           l.output,
           l.transformExpressions {
-            case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute
+            case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
           }.withNewChildren(newChildren))
       }
   }

