Repository: spark
Updated Branches:
  refs/heads/branch-1.6 5ccc1eb08 -> f38509a76


[SPARK-10371][SQL] Implement subexpr elimination for UnsafeProjections

This patch adds the building blocks for code-generating subexpression elimination and
implements it end to end for UnsafeProjection. The building blocks can be used to do the
same thing for other operators.

It introduces some utilities to compute common subexpressions. Expressions can be added to
this data structure. The expr and its children will be recursively matched against existing
expressions (ones previously added) and grouped into common groups. This is built using
the existing `semanticEquals`. It does not understand things like commutative or associative
expressions. This can be done as future work.

After building this data structure, the codegen process takes advantage of it by:
  1. Generating a helper function in the generated class that computes the common
     subexpression. This is done for all common subexpressions that have at least
     two occurrences and whose expression tree is sufficiently complex.
  2. When generating the apply() function, if the helper function exists, call that
     instead of regenerating the expression tree. Repeated calls to the helper function
     short-circuit the evaluation logic.

Author: Nong Li <n...@databricks.com>
Author: Nong Li <non...@gmail.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <mich...@databricks.com>

Closes #9480 from nongli/spark-10371.

(cherry picked from commit 87aedc48c01dffbd880e6ca84076ed47c68f88d0)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f38509a7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f38509a7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f38509a7

Branch: refs/heads/branch-1.6
Commit: f38509a763816f43a224653fe65e4645894c9fc4
Parents: 5ccc1eb
Author: Nong Li <n...@databricks.com>
Authored: Tue Nov 10 11:28:53 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Nov 10 11:29:05 2015 -0800

----------------------------------------------------------------------
 .../expressions/EquivalentExpressions.scala     | 106 +++++++++++++
 .../sql/catalyst/expressions/Expression.scala   |  50 +++++-
 .../sql/catalyst/expressions/Projection.scala   |  16 ++
 .../expressions/codegen/CodeGenerator.scala     | 110 ++++++++++++-
 .../codegen/GenerateUnsafeProjection.scala      |  36 ++++-
 .../catalyst/expressions/namedExpressions.scala |   4 +
 .../SubexpressionEliminationSuite.scala         | 153 +++++++++++++++++++
 .../scala/org/apache/spark/sql/SQLConf.scala    |   8 +
 .../apache/spark/sql/execution/SparkPlan.scala  |   5 +
 .../spark/sql/execution/basicOperators.scala    |   3 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  48 ++++++
 11 files changed, 523 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
new file mode 100644
index 0000000..e7380d2
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.mutable
+
+/**
+ * This class is used to compute equality of (sub)expression trees. Expressions can be added
+ * to this class and subsequently queried for expression equality. Expression trees are
+ * considered equal if for the same input(s), the same result is produced.
+ */
+class EquivalentExpressions {
+  /**
+   * Wrapper around an Expression that provides semantic equality.
+   */
+  case class Expr(e: Expression) {
+    val hash = e.semanticHash()
+    override def equals(o: Any): Boolean = o match {
+      case other: Expr => e.semanticEquals(other.e)
+      case _ => false
+    }
+    override def hashCode: Int = hash
+  }
+
+  // For each expression, the set of equivalent expressions.
+  private val equivalenceMap: mutable.HashMap[Expr, 
mutable.MutableList[Expression]] =
+      new mutable.HashMap[Expr, mutable.MutableList[Expression]]
+
+  /**
+   * Adds the expression to this data structure, grouping it with existing equivalent
+   * expressions. Non-recursive.
+   * Returns true if there was already a matching expression.
+   */
+  def addExpr(expr: Expression): Boolean = {
+    if (expr.deterministic) {
+      val e: Expr = Expr(expr)
+      val f = equivalenceMap.get(e)
+      if (f.isDefined) {
+        f.get.+= (expr)
+        true
+      } else {
+        equivalenceMap.put(e, mutable.MutableList(expr))
+        false
+      }
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Adds the expression to this data structure recursively. Stops if a matching expression
+   * is found. That is, if `root` has already been added, its children are not added.
+   * If ignoreLeaf is true, leaf nodes are ignored.
+   */
+  def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
+    val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
+    if (!skip && root.deterministic && !addExpr(root)) {
+     root.children.foreach(addExprTree(_, ignoreLeaf))
+    }
+  }
+
+  /**
+   * Returns all of the expression trees that are equivalent to `e`. Returns
+   * an empty collection if there are none.
+   */
+  def getEquivalentExprs(e: Expression): Seq[Expression] = {
+    equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList())
+  }
+
+  /**
+   * Returns all the equivalent sets of expressions.
+   */
+  def getAllEquivalentExprs: Seq[Seq[Expression]] = {
+    equivalenceMap.values.map(_.toSeq).toSeq
+  }
+
+  /**
+   * Returns the state of the data structure as a string. If all is false, skips sets of
+   * equivalent expressions with cardinality 1.
+   */
+  def debugString(all: Boolean = false): String = {
+    val sb: mutable.StringBuilder = new StringBuilder()
+    sb.append("Equivalent expressions:\n")
+    equivalenceMap.foreach { case (k, v) => {
+      if (all || v.length > 1) {
+        sb.append("  " + v.mkString(", ")).append("\n")
+      }
+    }}
+    sb.toString()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 96fcc79..7d5741e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] {
    * @return [[GeneratedExpressionCode]]
    */
   def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
-    val isNull = ctx.freshName("isNull")
-    val primitive = ctx.freshName("primitive")
-    val ve = GeneratedExpressionCode("", isNull, primitive)
-    ve.code = genCode(ctx, ve)
-    // Add `this` in the comment.
-    ve.copy(s"/* $this */\n" + ve.code)
+    val subExprState = ctx.subExprEliminationExprs.get(this)
+    if (subExprState.isDefined) {
+      // This expression is repeated, meaning the code to evaluate it has already been added
+      // as a function, `subExprState.fnName`. Just call that.
+      val code =
+        s"""
+           |/* $this */
+           |${subExprState.get.fnName}(${ctx.INPUT_ROW});
+           |""".stripMargin.trim
+      GeneratedExpressionCode(code, subExprState.get.code.isNull, 
subExprState.get.code.value)
+    } else {
+      val isNull = ctx.freshName("isNull")
+      val primitive = ctx.freshName("primitive")
+      val ve = GeneratedExpressionCode("", isNull, primitive)
+      ve.code = genCode(ctx, ve)
+      // Add `this` in the comment.
+      ve.copy(s"/* $this */\n" + ve.code.trim)
+    }
   }
 
   /**
@@ -145,12 +157,38 @@ abstract class Expression extends TreeNode[Expression] {
         case (i1, i2) => i1 == i2
       }
     }
+    // Non-deterministic expressions cannot be equal
+    if (!deterministic || !other.deterministic) return false
     val elements1 = this.productIterator.toSeq
     val elements2 = other.asInstanceOf[Product].productIterator.toSeq
     checkSemantic(elements1, elements2)
   }
 
   /**
+   * Returns the hash for this expression. Expressions that compute the same result, even if
+   * they differ cosmetically, should return the same hash.
+   */
+  def semanticHash() : Int = {
+    def computeHash(e: Seq[Any]): Int = {
+      // See http://stackoverflow.com/questions/113511/hash-code-implementation
+      var hash: Int = 17
+      e.foreach(i => {
+        val h: Int = i match {
+          case (e: Expression) => e.semanticHash()
+          case (Some(e: Expression)) => e.semanticHash()
+          case (t: Traversable[_]) => computeHash(t.toSeq)
+          case null => 0
+          case (o) => o.hashCode()
+        }
+        hash = hash * 37 + h
+      })
+      hash
+    }
+
+    computeHash(this.productIterator.toSeq)
+  }
+
+  /**
    * Checks the input data types, returns `TypeCheckResult.success` if it's 
valid,
    * or returns a `TypeCheckResult` with an error message if invalid.
    * Note: it's not valid to call this method until `childrenResolved == true`.

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 79dabe8..9f0b782 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -144,6 +144,22 @@ object UnsafeProjection {
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): 
UnsafeProjection = {
     create(exprs.map(BindReferences.bindReference(_, inputSchema)))
   }
+
+  /**
+    * Same as other create()'s but allowing enabling/disabling subexpression 
elimination.
+    * TODO: refactor the plumbing and clean this up.
+    */
+  def create(
+      exprs: Seq[Expression],
+      inputSchema: Seq[Attribute],
+      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+    val e = exprs.map(BindReferences.bindReference(_, inputSchema))
+      .map(_ transform {
+        case CreateStruct(children) => CreateStructUnsafe(children)
+        case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+    })
+    GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f0f7a6c..60a3d60 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -92,6 +92,33 @@ class CodeGenContext {
     addedFunctions += ((funcName, funcCode))
   }
 
+  /**
+   * Holds expressions that are equivalent. Used to perform subexpression 
elimination
+   * during codegen.
+   *
+   * For expressions that appear more than once, generate additional code to 
prevent
+   * recomputing the value.
+   *
+   * For example, consider two expressions generated from this SQL statement:
+   *  SELECT (col1 + col2), (col1 + col2) / col3.
+   *
+   *  equivalentExpressions will match the tree containing `col1 + col2` and 
it will only
+   *  be evaluated once.
+   */
+  val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+
+  // State used for subexpression elimination.
+  case class SubExprEliminationState(
+    val isLoaded: String, code: GeneratedExpressionCode, val fnName: String)
+
+  // For each expression that is participating in subexpression elimination, the state to use.
+  val subExprEliminationExprs: mutable.HashMap[Expression, 
SubExprEliminationState] =
+    mutable.HashMap[Expression, SubExprEliminationState]()
+
+  // The collection of isLoaded variables that need to be reset on each row.
+  val subExprIsLoadedVariables: mutable.ArrayBuffer[String] =
+    mutable.ArrayBuffer.empty[String]
+
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
@@ -317,6 +344,87 @@ class CodeGenContext {
       functions.map(name => s"$name($row);").mkString("\n")
     }
   }
+
+  /**
+   * Checks and sets up the state and codegen for subexpression elimination. This finds the
+   * common subexpressions, generates the functions that evaluate those expressions and populates
+   * the mapping of common subexpressions to the generated functions.
+   */
+  private def subexpressionElimination(expressions: Seq[Expression]) = {
+    // Add each expression tree and compute the common subexpressions.
+    expressions.foreach(equivalentExpressions.addExprTree(_))
+
+    // Get all the exprs that appear at least twice and set up the state for 
subexpression
+    // elimination.
+    val commonExprs = 
equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
+    commonExprs.foreach(e => {
+      val expr = e.head
+      val isLoaded = freshName("isLoaded")
+      val isNull = freshName("isNull")
+      val primitive = freshName("primitive")
+      val fnName = freshName("evalExpr")
+
+      // Generate the code for this expression tree and wrap it in a function.
+      val code = expr.gen(this)
+      val fn =
+        s"""
+           |private void $fnName(InternalRow ${INPUT_ROW}) {
+           |  if (!$isLoaded) {
+           |    ${code.code.trim}
+           |    $isLoaded = true;
+           |    $isNull = ${code.isNull};
+           |    $primitive = ${code.value};
+           |  }
+           |}
+           """.stripMargin
+      code.code = fn
+      code.isNull = isNull
+      code.value = primitive
+
+      addNewFunction(fnName, fn)
+
+      // Add a state and a mapping of the common subexpressions that are associated with this
+      // state. Adding this expression to subExprEliminationExprs means it will call `fn`
+      // when it is code generated. This decision should be a cost based one.
+      //
+      // The cost of doing subexpression elimination is:
+      //   1. Extra function call, although this is probably *good* as the JIT 
can decide to
+      //      inline or not.
+      //   2. Extra branch to check isLoaded. This branch is likely to be 
predicted correctly
+      //      very often. The reason it is not loaded is because of a prior 
branch.
+      //   3. Extra store into isLoaded.
+      // The benefit of doing subexpression elimination is:
+      //   1. Running the expression logic. Even for a simple expression, it 
is likely more than 3
+      //      above.
+      //   2. Less code.
+      // Currently, we will do this for all non-leaf expression trees (i.e. expr trees with
+      // at least two nodes) as the cost of doing it is expected to be low.
+
+      // Maintain the loaded value and isNull as member variables. This is 
necessary if the codegen
+      // function is split across multiple functions.
+      // TODO: maintaining this as a local variable probably allows the 
compiler to do better
+      // optimizations.
+      addMutableState("boolean", isLoaded, s"$isLoaded = false;")
+      addMutableState("boolean", isNull, s"$isNull = false;")
+      addMutableState(javaType(expr.dataType), primitive,
+        s"$primitive = ${defaultValue(expr.dataType)};")
+      subExprIsLoadedVariables += isLoaded
+
+      val state = SubExprEliminationState(isLoaded, code, fnName)
+      e.foreach(subExprEliminationExprs.put(_, state))
+    })
+  }
+
+  /**
+   * Generates code for expressions. If doSubexpressionElimination is true, subexpression
+   * elimination will be performed. Subexpression elimination assumes that the code for each
+   * expression will be combined in the `expressions` order.
+   */
+  def generateExpressions(expressions: Seq[Expression],
+      doSubexpressionElimination: Boolean = false): 
Seq[GeneratedExpressionCode] = {
+    if (doSubexpressionElimination) subexpressionElimination(expressions)
+    expressions.map(e => e.gen(this))
+  }
 }
 
 /**
@@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: 
AnyRef] extends Loggin
   }
 
   protected def declareAddedFunctions(ctx: CodeGenContext): String = {
-    ctx.addedFunctions.map { case (funcName, funcCode) => funcCode 
}.mkString("\n")
+    ctx.addedFunctions.map { case (funcName, funcCode) => funcCode 
}.mkString("\n").trim
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 2136f82..9ef2261 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -139,9 +139,9 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         s"""
           ${input.code}
           if (${input.isNull}) {
-            $setNull
+            ${setNull.trim}
           } else {
-            $writeField
+            ${writeField.trim}
           }
         """
     }
@@ -149,7 +149,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     s"""
       $rowWriter.initialize($bufferHolder, ${inputs.length});
       ${ctx.splitExpressions(row, writeFields)}
-    """
+    """.trim
   }
 
   // TODO: if the nullability of array element is correct, we can use it to 
save null check.
@@ -275,8 +275,11 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     """
   }
 
-  def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): 
GeneratedExpressionCode = {
-    val exprEvals = expressions.map(e => e.gen(ctx))
+  def createCode(
+      ctx: CodeGenContext,
+      expressions: Seq[Expression],
+      useSubexprElimination: Boolean = false): GeneratedExpressionCode = {
+    val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
     val exprTypes = expressions.map(_.dataType)
 
     val result = ctx.freshName("result")
@@ -285,10 +288,15 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     val holderClass = classOf[BufferHolder].getName
     ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new 
$holderClass();")
 
+    // Reset the isLoaded flag for each row.
+    val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = 
false;" }.mkString("\n")
+
     val code =
       s"""
         $bufferHolder.reset();
+        $subexprReset
         ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, 
bufferHolder)}
+
         $result.pointTo($bufferHolder.buffer, ${expressions.length}, 
$bufferHolder.totalSize());
       """
     GeneratedExpressionCode(code, "false", result)
@@ -300,10 +308,21 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): 
Seq[Expression] =
     in.map(BindReferences.bindReference(_, inputSchema))
 
+  def generate(
+    expressions: Seq[Expression],
+    subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+    create(canonicalize(expressions), subexpressionEliminationEnabled)
+  }
+
   protected def create(expressions: Seq[Expression]): UnsafeProjection = {
-    val ctx = newCodeGenContext()
+    create(expressions, false)
+  }
 
-    val eval = createCode(ctx, expressions)
+  private def create(
+      expressions: Seq[Expression],
+      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+    val ctx = newCodeGenContext()
+    val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
 
     val code = s"""
       public Object generate($exprType[] exprs) {
@@ -315,6 +334,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         private $exprType[] expressions;
 
         ${declareMutableStates(ctx)}
+
         ${declareAddedFunctions(ctx)}
 
         public SpecificUnsafeProjection($exprType[] expressions) {
@@ -328,7 +348,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         }
 
         public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
-          ${eval.code}
+          ${eval.code.trim}
           return ${eval.value};
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 9ab5c29..f80bcfc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -203,6 +203,10 @@ case class AttributeReference(
     case _ => false
   }
 
+  override def semanticHash(): Int = {
+    this.exprId.hashCode()
+  }
+
   override def hashCode: Int = {
     // See http://stackoverflow.com/questions/113511/hash-code-implementation
     var h = 17

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
new file mode 100644
index 0000000..9de066e
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.IntegerType
+
+class SubexpressionEliminationSuite extends SparkFunSuite {
+  test("Semantic equals and hash") {
+    val id = ExprId(1)
+    val a: AttributeReference = AttributeReference("name", IntegerType)()
+    val b1 = a.withName("name2").withExprId(id)
+    val b2 = a.withExprId(id)
+
+    assert(b1 != b2)
+    assert(a != b1)
+    assert(b1.semanticEquals(b2))
+    assert(!b1.semanticEquals(a))
+    assert(a.hashCode != b1.hashCode)
+    assert(b1.hashCode == b2.hashCode)
+    assert(b1.semanticHash() == b2.semanticHash())
+  }
+
+  test("Expression Equivalence - basic") {
+    val equivalence = new EquivalentExpressions
+    assert(equivalence.getAllEquivalentExprs.isEmpty)
+
+    val oneA = Literal(1)
+    val oneB = Literal(1)
+    val twoA = Literal(2)
+    var twoB = Literal(2)
+
+    assert(equivalence.getEquivalentExprs(oneA).isEmpty)
+    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+
+    // Add oneA and test if it is returned. Since it is a group of one, it 
does not.
+    assert(!equivalence.addExpr(oneA))
+    assert(equivalence.getEquivalentExprs(oneA).size == 1)
+    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+    assert(equivalence.addExpr((oneA)))
+    assert(equivalence.getEquivalentExprs(oneA).size == 2)
+
+    // Add B and make sure they can see each other.
+    assert(equivalence.addExpr(oneB))
+    // Use exists and reference equality because of how equals is defined.
+    assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB))
+    assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA))
+    assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
+    assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
+    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+    assert(equivalence.getAllEquivalentExprs.size == 1)
+    assert(equivalence.getAllEquivalentExprs.head.size == 3)
+    assert(equivalence.getAllEquivalentExprs.head.contains(oneA))
+    assert(equivalence.getAllEquivalentExprs.head.contains(oneB))
+
+    val add1 = Add(oneA, oneB)
+    val add2 = Add(oneA, oneB)
+
+    equivalence.addExpr(add1)
+    equivalence.addExpr(add2)
+
+    assert(equivalence.getAllEquivalentExprs.size == 2)
+    assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
+    assert(equivalence.getEquivalentExprs(add2).size == 2)
+    assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
+  }
+
+  test("Expression Equivalence - Trees") {
+    val one = Literal(1)
+    val two = Literal(2)
+
+    val add = Add(one, two)
+    val abs = Abs(add)
+    val add2 = Add(add, add)
+
+    var equivalence = new EquivalentExpressions
+    equivalence.addExprTree(add, true)
+    equivalence.addExprTree(abs, true)
+    equivalence.addExprTree(add2, true)
+
+    // Should only have one equivalence for `one + two`
+    assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1)
+    assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4)
+
+    // Set up the expressions
+    //   one * two,
+    //   (one * two) * (one * two)
+    //   sqrt( (one * two) * (one * two) )
+    //   (one * two) + sqrt( (one * two) * (one * two) )
+    equivalence = new EquivalentExpressions
+    val mul = Multiply(one, two)
+    val mul2 = Multiply(mul, mul)
+    val sqrt = Sqrt(mul2)
+    val sum = Add(mul2, sqrt)
+    equivalence.addExprTree(mul, true)
+    equivalence.addExprTree(mul2, true)
+    equivalence.addExprTree(sqrt, true)
+    equivalence.addExprTree(sum, true)
+
+    // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * 
two) ) should be found
+    assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3)
+    assert(equivalence.getEquivalentExprs(mul).size == 3)
+    assert(equivalence.getEquivalentExprs(mul2).size == 3)
+    assert(equivalence.getEquivalentExprs(sqrt).size == 2)
+    assert(equivalence.getEquivalentExprs(sum).size == 1)
+
+    // Some expressions inspired by TPCH-Q1
+    // sum(l_quantity) as sum_qty,
+    // sum(l_extendedprice) as sum_base_price,
+    // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
+    // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+    // avg(l_extendedprice) as avg_price,
+    // avg(l_discount) as avg_disc
+    equivalence = new EquivalentExpressions
+    val quantity = Literal(1)
+    val price = Literal(1.1)
+    val discount = Literal(.24)
+    val tax = Literal(0.1)
+    equivalence.addExprTree(quantity, false)
+    equivalence.addExprTree(price, false)
+    equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), 
false)
+    equivalence.addExprTree(
+      Multiply(
+        Multiply(price, Subtract(Literal(1), discount)),
+        Add(Literal(1), tax)), false)
+    equivalence.addExprTree(price, false)
+    equivalence.addExprTree(discount, false)
+    // quantity, price, discount and (price * (1 - discount))
+    assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4)
+  }
+
+  test("Expression equivalence - non deterministic") {
+    val sum = Add(Rand(0), Rand(0))
+    val equivalence = new EquivalentExpressions
+    equivalence.addExpr(sum)
+    equivalence.addExpr(sum)
+    assert(equivalence.getAllEquivalentExprs.isEmpty)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index b731418..89e196c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -268,6 +268,11 @@ private[spark] object SQLConf {
     doc = "When true, use the new optimized Tungsten physical execution 
backend.",
     isPublic = false)
 
+  val SUBEXPRESSION_ELIMINATION_ENABLED = 
booleanConf("spark.sql.subexpressionElimination.enabled",
+    defaultValue = Some(true),  // use CODEGEN_ENABLED as default
+    doc = "When true, common subexpressions will be eliminated.",
+    isPublic = false)
+
   val DIALECT = stringConf(
     "spark.sql.dialect",
     defaultValue = Some("sql"),
@@ -541,6 +546,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED))
 
+  private[spark] def subexpressionEliminationEnabled: Boolean =
+    getConf(SUBEXPRESSION_ELIMINATION_ENABLED, codegenEnabled)
+
   private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
   private[spark] def defaultSizeInBytes: Long =

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 8bb293a..8650ac5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -66,6 +66,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   } else {
     false
   }
+  val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) {
+    sqlContext.conf.subexpressionEliminationEnabled
+  } else {
+    false
+  }
 
   /**
    * Whether the "prepare" method is called.

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 145de0d..303d636 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -70,7 +70,8 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
   protected override def doExecute(): RDD[InternalRow] = {
     val numRows = longMetric("numRows")
     child.execute().mapPartitions { iter =>
-      val project = UnsafeProjection.create(projectList, child.output)
+      val project = UnsafeProjection.create(projectList, child.output,
+        subexpressionEliminationEnabled)
       iter.map { row =>
         numRows += 1
         project(row)

http://git-wip-us.apache.org/repos/asf/spark/blob/f38509a7/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 441a0c6..19e850a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1970,4 +1970,52 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
     }
   }
+
+  test("Common subexpression elimination") {
+    // select from a table to prevent constant folding.
+    val df = sql("SELECT a, b from testData2 limit 1")
+    checkAnswer(df, Row(1, 1))
+
+    checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+    checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+    // No elimination here: "a + a + 1" parses as (a + a) + 1, so "a + 1" is not a subexpression.
+    checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+    checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+    // Identity udf that tracks the number of times it is called.
+    val countAcc = sparkContext.accumulator(0, "CallCount")
+    sqlContext.udf.register("testUdf", (x: Int) => {
+      countAcc.++=(1)
+      x
+    })
+
+    // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+    // is correct.
+    def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+      countAcc.setValue(0)
+      checkAnswer(df, expectedResult)
+      assert(countAcc.value == expectedCount)
+    }
+
+    verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+    // Would be nice if semantic equals for `+` understood commutative
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+    // Try disabling it via configuration.
+    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to