Repository: spark
Updated Branches:
  refs/heads/master 5c9117a3e -> 661c21049


[SPARK-15381] [SQL] physical object operator should define reference correctly

## What changes were proposed in this pull request?

Whole-stage codegen relies on `SparkPlan.references` for some of its optimizations.
Physical object operators should be consistent with their logical counterparts and
define `references` correctly.
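
The core of the change is a pair of physical-plan traits, `ObjectProducerExec` and
`ObjectConsumerExec`, added to `execution/objects.scala` (full diff below). A condensed
sketch of the two traits, with their helper methods omitted:

```scala
// Physical counterpart of the logical ObjectProducer: produces a single object column.
trait ObjectProducerExec extends SparkPlan {
  protected def outputObjAttr: Attribute

  override def output: Seq[Attribute] = outputObjAttr :: Nil
  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
}

// Physical counterpart of the logical ObjectConsumer: consumes a single object column.
trait ObjectConsumerExec extends UnaryExecNode {
  assert(child.output.length == 1)

  // Report the child's only column as referenced so it is never pruned away,
  // even though no expression in this node points at it.
  override def references: AttributeSet = child.outputSet
}
```

The concrete object operators (`DeserializeToObjectExec`, `SerializeFromObjectExec`,
`MapPartitionsExec`, `MapElementsExec`, and so on) now mix in these traits instead of
redefining `output`/`producedAttributes` themselves.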

## How was this patch tested?

A new regression test in `DatasetSuite`; a condensed sketch of the scenario it covers is shown below.
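
The scenario, adapted from the new test in the diff below (assuming a `QueryTest`-style
harness with `checkAnswer` and the test implicits in scope):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Map over a DataFrame, then select its columns in a different order.
// This exercises the whole-stage-codegen path that relies on the physical
// object operators reporting their references correctly.
val df = Seq(1 -> 2).toDF("a", "b")
checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), Row(2, 1))
```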

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13167 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/661c2104
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/661c2104
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/661c2104

Branch: refs/heads/master
Commit: 661c21049b62ebfaf788dcbc31d62a09e206265b
Parents: 5c9117a
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed May 18 21:43:07 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Wed May 18 21:43:07 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/plans/logical/object.scala     | 10 +--
 .../spark/sql/execution/SparkStrategies.scala   |  2 +-
 .../apache/spark/sql/execution/objects.scala    | 94 +++++++++++---------
 .../org/apache/spark/sql/DatasetSuite.scala     |  5 ++
 4 files changed, 64 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/661c2104/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 84339f4..98ce5dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -94,7 +94,7 @@ case class DeserializeToObject(
  */
 case class SerializeFromObject(
     serializer: Seq[NamedExpression],
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer {
+    child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 }
@@ -118,7 +118,7 @@ object MapPartitions {
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
 object MapPartitionsInR {
   def apply(
@@ -152,7 +152,7 @@ case class MapPartitionsInR(
     inputSchema: StructType,
     outputSchema: StructType,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
   override lazy val schema = outputSchema
 }
 
@@ -175,7 +175,7 @@ object MapElements {
 case class MapElements(
     func: AnyRef,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
@@ -215,7 +215,7 @@ case class AppendColumnsWithObject(
     func: Any => Any,
     childSerializer: Seq[NamedExpression],
     newColumnsSerializer: Seq[NamedExpression],
-    child: LogicalPlan) extends UnaryNode with ObjectConsumer {
+    child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/661c2104/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index faf359f..5cfb6d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -303,7 +303,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           "logical except operator should have been replaced by anti-join in the optimizer")
 
       case logical.DeserializeToObject(deserializer, objAttr, child) =>
-        execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
+        execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
       case logical.SerializeFromObject(serializer, child) =>
         execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
       case logical.MapPartitions(f, objAttr, child) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/661c2104/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 3ff9913..5fced94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,17 +28,41 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.{DataType, ObjectType}
 
+
+/**
+ * Physical version of `ObjectProducer`.
+ */
+trait ObjectProducerExec extends SparkPlan {
+  // The attribute that references the single object field this operator outputs.
+  protected def outputObjAttr: Attribute
+
+  override def output: Seq[Attribute] = outputObjAttr :: Nil
+
+  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+  def outputObjectType: DataType = outputObjAttr.dataType
+}
+
+/**
+ * Physical version of `ObjectConsumer`.
+ */
+trait ObjectConsumerExec extends UnaryExecNode {
+  assert(child.output.length == 1)
+
+  // This operator always needs all columns of its child, even if it doesn't reference them directly.
+  override def references: AttributeSet = child.outputSet
+
+  def inputObjectType: DataType = child.output.head.dataType
+}
+
 /**
  * Takes the input row from child and turns it into object using the given deserializer expression.
  * The output of this operator is a single-field safe row containing the deserialized object.
  */
-case class DeserializeToObject(
+case class DeserializeToObjectExec(
     deserializer: Expression,
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport {
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -70,7 +94,7 @@ case class DeserializeToObject(
  */
 case class SerializeFromObjectExec(
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+    child: SparkPlan) extends ObjectConsumerExec with CodegenSupport {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
@@ -102,7 +126,7 @@ case class SerializeFromObjectExec(
 /**
  * Helper functions for physical operators that work with user defined objects.
  */
-trait ObjectOperator extends SparkPlan {
+object ObjectOperator {
   def deserializeRowToObject(
       deserializer: Expression,
       inputSchema: Seq[Attribute]): InternalRow => Any = {
@@ -141,15 +165,12 @@ case class MapPartitionsExec(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends ObjectConsumerExec with ObjectProducerExec {
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       func(iter.map(getObject)).map(outputObject)
     }
   }
@@ -166,10 +187,7 @@ case class MapElementsExec(
     func: AnyRef,
     outputObjAttr: Attribute,
     child: SparkPlan)
-  extends UnaryExecNode with ObjectOperator with CodegenSupport {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+  extends ObjectConsumerExec with ObjectProducerExec with CodegenSupport {
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -202,8 +220,8 @@ case class MapElementsExec(
     }
 
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
       iter.map(row => outputObject(callFunc(getObject(row))))
     }
   }
@@ -218,7 +236,7 @@ case class AppendColumnsExec(
     func: Any => Any,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)
 
@@ -226,9 +244,9 @@ case class AppendColumnsExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = deserializeRowToObject(deserializer, child.output)
+      val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output)
       val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
-      val outputObject = serializeObjectToRow(serializer)
+      val outputObject = ObjectOperator.serializeObjectToRow(serializer)
 
       iter.map { row =>
         val newColumns = outputObject(func(getObject(row)))
@@ -246,7 +264,7 @@ case class AppendColumnsWithObjectExec(
     func: Any => Any,
     inputSerializer: Seq[NamedExpression],
     newColumnsSerializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+    child: SparkPlan) extends ObjectConsumerExec {
 
   override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)
 
@@ -255,9 +273,9 @@ case class AppendColumnsWithObjectExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getChildObject = unwrapObjectFromRow(child.output.head.dataType)
-      val outputChildObject = serializeObjectToRow(inputSerializer)
-      val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer)
+      val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+      val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer)
+      val outputNewColumnOjb = ObjectOperator.serializeObjectToRow(newColumnsSerializer)
       val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)
 
       iter.map { row =>
@@ -280,10 +298,7 @@ case class MapGroupsExec(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    child: SparkPlan) extends UnaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
 
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(groupingAttributes) :: Nil
@@ -295,9 +310,9 @@ case class MapGroupsExec(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
 
-      val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val getValue = deserializeRowToObject(valueDeserializer, dataAttributes)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+      val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
 
       grouped.flatMap { case (key, rowIter) =>
         val result = func(
@@ -325,10 +340,7 @@ case class CoGroupExec(
     rightAttr: Seq[Attribute],
     outputObjAttr: Attribute,
     left: SparkPlan,
-    right: SparkPlan) extends BinaryExecNode with ObjectOperator {
-
-  override def output: Seq[Attribute] = outputObjAttr :: Nil
-  override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+    right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
 
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
@@ -341,10 +353,10 @@ case class CoGroupExec(
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
 
-      val getKey = deserializeRowToObject(keyDeserializer, leftGroup)
-      val getLeft = deserializeRowToObject(leftDeserializer, leftAttr)
-      val getRight = deserializeRowToObject(rightDeserializer, rightAttr)
-      val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+      val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
+      val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
+      val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
+      val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
 
       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
         case (key, leftResult, rightResult) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/661c2104/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 1935e41..52e7062 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -711,6 +711,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("already exists"))
     dataset.sparkSession.catalog.dropTempView("tempView")
   }
+
+  test("SPARK-15381: physical object operator should define `reference` 
correctly") {
+    val df = Seq(1 -> 2).toDF("a", "b")
+    checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), 
Row(2, 1))
+  }
 }
 
 case class Generic[T](id: T, value: Double)

