HyukjinKwon commented on a change in pull request #33142:
URL: https://github.com/apache/spark/pull/33142#discussion_r667633052



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
##########
@@ -170,65 +167,71 @@ class EquivalentExpressions {
       // can cause error like NPE.
       (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)
 
-    if (!skip && !addFunc(expr)) {
-      childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
-      commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
+    if (!skip && !addExprToMap(expr, map)) {
+      childrenToRecurse(expr).foreach(addExprTree(_, map))
+      commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map))
     }
   }
 
   /**
-   * Returns all of the expression trees that are equivalent to `e`. Returns
-   * an empty collection if there are none.
+   * Returns the state of the given expression in the `equivalenceMap`. Returns None if there
+   * are no equivalent expressions.
    */
-  def getEquivalentExprs(e: Expression): Seq[Expression] = {
-    equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq
+  def getExprState(e: Expression): Option[ExpressionStats] = {
+    equivalenceMap.get(ExpressionEquals(e))
+  }
+
+  // Exposed for testing.
+  private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
+    equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height)
   }
 
   /**
-   * Returns all the equivalent sets of expressions which appear more than 
given `repeatTimes`
-   * times.
+   * Returns a sequence of expressions that have more than one equivalent expression.
    */
-  def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
-    equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
-      .sortBy(_.head)(new ExpressionContainmentOrdering)
+  def getCommonSubexpressions: Seq[Expression] = {
+    getAllExprStates(1).map(_.expr)
   }
 
   /**
    * Returns the state of the data structure as a string. If `all` is false, skips sets of
    * equivalent expressions with cardinality 1.
    */
   def debugString(all: Boolean = false): String = {
-    val sb: mutable.StringBuilder = new StringBuilder()
+    val sb = new java.lang.StringBuilder()
     sb.append("Equivalent expressions:\n")
-    equivalenceMap.foreach { case (k, v) =>
-      if (all || v.length > 1) {
-        sb.append("  " + v.mkString(", ")).append("\n")
-      }
+    equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats =>
+      sb.append("  ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n')
     }
     sb.toString()
   }
 }
 
 /**
- * Orders `Expression` by parent/child relations. The child expression is 
smaller
- * than parent expression. If there is child-parent relationships among the 
subexpressions,
- * we want the child expressions come first than parent expressions, so we can 
replace
- * child expressions in parent expressions with subexpression evaluation. Note 
that
- * this is not for general expression ordering. For example, two irrelevant or 
semantically-equal
- * expressions will be considered as equal by this ordering. But for the usage 
here, the order of
- * irrelevant expressions does not matter.
+ * Wrapper around an Expression that provides semantic equality.
  */
-class ExpressionContainmentOrdering extends Ordering[Expression] {
-  override def compare(x: Expression, y: Expression): Int = {
-    if (x.find(_.semanticEquals(y)).isDefined) {
-      // `y` is child expression of `x`.
-      1
-    } else if (y.find(_.semanticEquals(x)).isDefined) {
-      // `x` is child expression of `y`.
-      -1
-    } else {
-      // Irrelevant or semantically-equal expressions
-      0
-    }
+case class ExpressionEquals(e: Expression) {
+  override def equals(o: Any): Boolean = o match {
+    case other: ExpressionEquals => e.semanticEquals(other.e)
+    case _ => false
+  }
+
+  override def hashCode: Int = e.semanticHash()
+}
+
+/**
+ * A wrapper in place of using Seq[Expression] to record a group of equivalent 
expressions.
+ *
+ * This saves a lot of memory when there are a lot of expressions in a same 
equivalence group.
+ * Instead of appending to a mutable list/buffer of Expressions, just update 
the "flattened"
+ * useCount in this wrapper in-place.
+ */
+case class ExpressionStats(expr: Expression)(var useCount: Int = 1) {
+  // This is used to do a fast pre-check for child-parent relationship. For 
example, expr1 can
+  // only be a parent of expr2 if expr1.height is larger than expr2.height.

Review comment:
       Oh, it's correct! Sorry for the false alarm.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to