Repository: spark
Updated Branches:
  refs/heads/master 6a47114bc -> 4e42842e8


[SPARK-8164] transformExpressions should support nested expression sequence

Currently we only support `Seq[Expression]`; we should also handle cases like
`Seq[Seq[Expression]]` so that the unnecessary `GroupExpression` wrapper can be removed.
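
For context, a minimal standalone sketch of the idea (plain Scala; `Expr`, `Lit`, and
`transformNested` are illustrative stand-ins, not the real Catalyst API). A single-level
`seq.map` only rewrites expressions that are direct elements of a `Seq`, so recursing
through `Traversable`s is what lets a rule reach `Seq[Seq[Expression]]`:

```scala
sealed trait Expr
case class Lit(value: Any) extends Expr

// Recursively walk any Traversable structure, applying `rule` to every
// Expr it contains, at any nesting depth.
def transformNested(arg: Any)(rule: Expr => Expr): Any = arg match {
  case e: Expr             => rule(e)
  case seq: Traversable[_] => seq.map(transformNested(_)(rule))
  case other               => other
}

val nested: Seq[Seq[Expr]] = Seq(Seq(Lit(1)), Seq(Lit(2)))
// Rewrites every literal even though each sits two levels deep:
transformNested(nested) { case Lit(v) => Lit(v.toString) }
// => Seq(Seq(Lit("1")), Seq(Lit("2")))
```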

Author: Wenchen Fan <cloud0...@outlook.com>

Closes #6706 from cloud-fan/clean and squashes the following commits:

60a1193 [Wenchen Fan] support nested expression sequence and remove GroupExpression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4e42842e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4e42842e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4e42842e

Branch: refs/heads/master
Commit: 4e42842e82e058d54329bd66185d8a7e77ab335a
Parents: 6a47114
Author: Wenchen Fan <cloud0...@outlook.com>
Authored: Wed Jun 10 18:22:47 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Jun 10 18:22:47 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  6 +++---
 .../sql/catalyst/expressions/Expression.scala   | 12 -----------
 .../spark/sql/catalyst/plans/QueryPlan.scala    | 22 +++++++++-----------
 .../catalyst/plans/logical/basicOperators.scala |  2 +-
 .../sql/catalyst/trees/TreeNodeSuite.scala      | 14 +++++++++++++
 .../org/apache/spark/sql/execution/Expand.scala |  4 ++--
 6 files changed, 30 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4e42842e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c4f12cf..cbd8def 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -172,8 +172,8 @@ class Analyzer(
      * expressions which equal GroupBy expressions with Literal(null), if those expressions
      * are not set for this grouping set (according to the bit mask).
      */
-    private[this] def expand(g: GroupingSets): Seq[GroupExpression] = {
-      val result = new scala.collection.mutable.ArrayBuffer[GroupExpression]
+    private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
+      val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
 
       g.bitmasks.foreach { bitmask =>
         // get the non selected grouping attributes according to the bit mask
@@ -194,7 +194,7 @@ class Analyzer(
             Literal.create(bitmask, IntegerType)
         })
 
-        result += GroupExpression(substitution)
+        result += substitution
       }
 
       result.toSeq
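
As a rough illustration of the expansion above (standalone Scala, not Catalyst; the
bit convention is simplified and may not match Catalyst's exactly), each bitmask
selects which GROUP BY expressions stay live in one grouping set, the rest are
nulled out, and the mask itself is appended, analogous to the
`Literal.create(bitmask, IntegerType)` column above:

```scala
// Hypothetical rendering of grouping-set expansion: for ROLLUP(a, b)
// the groupings are (a, b), (a), and (), one output Seq per bitmask.
val groupBy = Seq("a", "b")
val bitmasks = Seq(3, 1, 0) // bit i set here means "column i is live"
val expanded: Seq[Seq[Any]] = bitmasks.map { mask =>
  groupBy.zipWithIndex.map { case (e, i) =>
    if ((mask & (1 << i)) != 0) e else null // null out non-selected exprs
  } :+ mask // the grouping id column
}
// expanded == Seq(Seq("a", "b", 3), Seq("a", null, 1), Seq(null, null, 0))
```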

http://git-wip-us.apache.org/repos/asf/spark/blob/4e42842e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a05794f..63dd5f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -239,18 +239,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
   }
 }
 
-// TODO Semantically we probably not need GroupExpression
-// All we need is holding the Seq[Expression], and ONLY used in doing the
-// expressions transformation correctly. Probably will be removed since it's
-// not like a real expressions.
-case class GroupExpression(children: Seq[Expression]) extends Expression {
-  self: Product =>
-  override def eval(input: Row): Any = throw new UnsupportedOperationException
-  override def nullable: Boolean = false
-  override def foldable: Boolean = false
-  override def dataType: DataType = throw new UnsupportedOperationException
-}
-
 /**
  * Expressions that require a specific `DataType` as input should implement 
this trait
  * so that the proper type conversions can be performed in the analyzer.

http://git-wip-us.apache.org/repos/asf/spark/blob/4e42842e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index eff5c61..2f545bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -81,17 +81,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       }
     }
 
-    val newArgs = productIterator.map {
+    def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpressionDown(e)
       case Some(e: Expression) => Some(transformExpressionDown(e))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
-      case seq: Traversable[_] => seq.map {
-        case e: Expression => transformExpressionDown(e)
-        case other => other
-      }
+      case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
-    }.toArray
+    }
+
+    val newArgs = productIterator.map(recursiveTransform).toArray
 
     if (changed) makeCopy(newArgs) else this
   }
@@ -114,17 +113,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       }
     }
 
-    val newArgs = productIterator.map {
+    def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpressionUp(e)
       case Some(e: Expression) => Some(transformExpressionUp(e))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
-      case seq: Traversable[_] => seq.map {
-        case e: Expression => transformExpressionUp(e)
-        case other => other
-      }
+      case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
-    }.toArray
+    }
+
+    val newArgs = productIterator.map(recursiveTransform).toArray
 
     if (changed) makeCopy(newArgs) else this
   }
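
Continuing the standalone `transformNested` sketch from the commit message above:
because the transform recurses into every `Traversable` element, the nesting depth
is arbitrary, not just two levels:

```scala
val deep = Seq(Seq(Seq(Lit(1))), Seq(Lit(2)))
transformNested(deep) { case Lit(v) => Lit(v.toString) }
// => Seq(Seq(Seq(Lit("1"))), Seq(Lit("2")))
```

Note that `Map` values are still deliberately passed through untouched
(`case m: Map[_, _] => m`), exactly as before this patch.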

http://git-wip-us.apache.org/repos/asf/spark/blob/4e42842e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e77e5c2..963c782 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -226,7 +226,7 @@ case class Window(
  * @param child       Child operator
  */
 case class Expand(
-    projections: Seq[GroupExpression],
+    projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
   override def statistics: Statistics = {

http://git-wip-us.apache.org/repos/asf/spark/blob/4e42842e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 67db3d5..8ec79c3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -31,6 +31,11 @@ case class Dummy(optKey: Option[Expression]) extends Expression {
   override def eval(input: Row): Any = null.asInstanceOf[Any]
 }
 
+case class ComplexPlan(exprs: Seq[Seq[Expression]])
+  extends org.apache.spark.sql.catalyst.plans.logical.LeafNode {
+  override def output: Seq[Attribute] = Nil
+}
+
 class TreeNodeSuite extends SparkFunSuite {
   test("top node changed") {
     val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -220,4 +225,13 @@ class TreeNodeSuite extends SparkFunSuite {
       assert(expected === actual)
     }
   }
+
+  test("transformExpressions on nested expression sequence") {
+    val plan = ComplexPlan(Seq(Seq(Literal(1)), Seq(Literal(2))))
+    val actual = plan.transformExpressions {
+      case Literal(value, _) => Literal(value.toString)
+    }
+    val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
+    assert(expected === actual)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4e42842e/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index f16ca36..4b601c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit
  */
 @DeveloperApi
 case class Expand(
-    projections: Seq[GroupExpression],
+    projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
     child: SparkPlan)
   extends UnaryNode {
@@ -49,7 +49,7 @@ case class Expand(
       // workers via closure. However we can't assume the Projection
       // is serializable because of the code gen, so we have to
       // create the projections within each of the partition processing.
-      val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray
+      val groups = projections.map(ee => newProjection(ee, child.output)).toArray
 
       new Iterator[Row] {
         private[this] var result: Row = _
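
The comment in this hunk carries a general Spark pattern worth noting: objects that
are not serializable (here, potentially code-generated `Projection`s) must be
constructed inside the partition-processing closure rather than captured from the
driver. A generic sketch of that pattern with the plain RDD API; the object name and
doubling function are made up for illustration:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object PerPartitionInit {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local").setAppName("sketch"))
    val data = sc.parallelize(1 to 4, numSlices = 2)
    val doubled = data.mapPartitions { iter =>
      // Built here, once per partition on the executor -- mirroring how
      // Expand builds its projections inside execute() instead of
      // shipping a non-serializable generated Projection from the driver.
      val project: Int => Int = _ * 2
      iter.map(project)
    }
    println(doubled.collect().toSeq) // Seq(2, 4, 6, 8)
    sc.stop()
  }
}
```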

