Repository: spark
Updated Branches:
  refs/heads/master dd077abf2 -> f84c799ea


[SPARK-5996][SQL] Fix specialized outbound conversions

Author: Michael Armbrust <mich...@databricks.com>

Closes #4757 from marmbrus/udtConversions and squashes the following commits:

3714aad [Michael Armbrust] [SPARK-5996][SQL] Fix specialized outbound conversions
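
For context: the "outbound conversion" is the Catalyst-to-Scala step that must run before rows leave the execution layer, e.g. turning a UDT's internal representation back into the user-facing class. LocalTableScan and TakeOrdered have specialized executeCollect()/executeTake() paths that were skipping this step. A minimal sketch of the conversion these operators now apply (assuming the 1.3-era ScalaReflection API; the helper name toScalaRows is illustrative, not from the commit):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types.StructType

    // Convert each Catalyst-internal row back to user-facing Scala values;
    // the schema tells the converter which columns carry UDTs.
    def toScalaRows(rows: Seq[Row], schema: StructType): Array[Row] =
      rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray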


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f84c799e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f84c799e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f84c799e

Branch: refs/heads/master
Commit: f84c799ea0b82abca6a4fad39532c2515743b632
Parents: dd077ab
Author: Michael Armbrust <mich...@databricks.com>
Authored: Wed Feb 25 10:13:40 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Feb 25 10:13:40 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/execution/LocalTableScan.scala   |  7 +++++--
 .../org/apache/spark/sql/execution/basicOperators.scala   |  8 +++++---
 .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 10 ++++++++++
 3 files changed, 20 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f84c799e/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
index d6d8258..d3a18b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
 
@@ -30,7 +31,9 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo
 
   override def execute() = rdd
 
-  override def executeCollect() = rows.toArray
+  override def executeCollect() =
+    rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray
 
-  override def executeTake(limit: Int) = rows.take(limit).toArray
+  override def executeTake(limit: Int) =
+    rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray
 }
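
With this change the local-relation path converts on the way out, matching the RDD path. A hedged usage sketch (MyDenseVector is the example UDT from UserDefinedTypeSuite; the column names are illustrative):

    // Seq(...).toDF() builds a LocalRelation, so collect()/take() hit the
    // specialized paths above. The UDT column now comes back as a
    // MyDenseVector instead of its internal array representation.
    val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
    val vec: MyDenseVector = df.take(1)(0).getAs[MyDenseVector](1)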

http://git-wip-us.apache.org/repos/asf/spark/blob/f84c799e/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4dc506c..7102685 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -134,13 +134,15 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   val ord = new RowOrdering(sortOrder, child.output)
 
+  private def collectData() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+
   // TODO: Is this copying for no reason?
-  override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    .map(ScalaReflection.convertRowToScala(_, this.schema))
+  override def executeCollect() =
+    collectData().map(ScalaReflection.convertRowToScala(_, this.schema))
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  override def execute() = sparkContext.makeRDD(executeCollect(), 1)
+  override def execute() = sparkContext.makeRDD(collectData(), 1)
 }
 
 /**
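
Note the shape of the fix here: execute() feeds rows to downstream operators, which expect Catalyst's internal representation, while executeCollect() hands rows to user code and must convert. Factoring the shared work into collectData() lets execute() skip the conversion rather than reusing executeCollect(), which would now push converted Scala rows back into the plan.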

http://git-wip-us.apache.org/repos/asf/spark/blob/f84c799e/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9c098df..47fdb55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
 import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -105,4 +106,13 @@ class UserDefinedTypeSuite extends QueryTest {
     tempDir.delete()
     pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath)
   }
+
+  // Tests to make sure that all operators correctly convert types on the way out.
+  test("Local UDTs") {
+    val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
+    df.collect()(0).getAs[MyDenseVector](1)
+    df.take(1)(0).getAs[MyDenseVector](1)
+    df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+    df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+  }
 }
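
For readers without the suite open, a condensed sketch of the example UDT the test exercises (paraphrased from the definitions earlier in UserDefinedTypeSuite; details may differ slightly):

    @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
    class MyDenseVector(val data: Array[Double]) extends Serializable

    class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
      // Internal Catalyst representation: a non-nullable array of doubles.
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
      override def serialize(obj: Any): Seq[Double] = obj match {
        case v: MyDenseVector => v.data.toSeq
      }
      override def deserialize(datum: Any): MyDenseVector = datum match {
        case data: Seq[_] => new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
      }
      override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
    }

Each getAs[MyDenseVector] in the test only succeeds if the operator applied the outbound conversion; before the fix these paths surfaced the serialized representation and the cast failed.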

