Repository: spark Updated Branches: refs/heads/master dd077abf2 -> f84c799ea
[SPARK-5996][SQL] Fix specialized outbound conversions Author: Michael Armbrust <mich...@databricks.com> Closes #4757 from marmbrus/udtConversions and squashes the following commits: 3714aad [Michael Armbrust] [SPARK-5996][SQL] Fix specialized outbound conversions Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f84c799e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f84c799e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f84c799e Branch: refs/heads/master Commit: f84c799ea0b82abca6a4fad39532c2515743b632 Parents: dd077ab Author: Michael Armbrust <mich...@databricks.com> Authored: Wed Feb 25 10:13:40 2015 -0800 Committer: Xiangrui Meng <m...@databricks.com> Committed: Wed Feb 25 10:13:40 2015 -0800 ---------------------------------------------------------------------- .../org/apache/spark/sql/execution/LocalTableScan.scala | 7 +++++-- .../org/apache/spark/sql/execution/basicOperators.scala | 8 +++++--- .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 10 ++++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f84c799e/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index d6d8258..d3a18b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Attribute @@ -30,7 +31,9 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo override def execute() = rdd - override def executeCollect() = rows.toArray + override def executeCollect() = + rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray - override def executeTake(limit: Int) = rows.take(limit).toArray + override def executeTake(limit: Int) = + rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray } http://git-wip-us.apache.org/repos/asf/spark/blob/f84c799e/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4dc506c..7102685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -134,13 +134,15 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) val ord = new RowOrdering(sortOrder, child.output) + private def collectData() = child.execute().map(_.copy()).takeOrdered(limit)(ord) + // TODO: Is this copying for no reason? - override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) - .map(ScalaReflection.convertRowToScala(_, this.schema)) + override def executeCollect() = + collectData().map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(collectData(), 1) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/f84c799e/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 9c098df..47fdb55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -22,6 +22,7 @@ import java.io.File import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -105,4 +106,13 @@ class UserDefinedTypeSuite extends QueryTest { tempDir.delete() pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) } + + // Tests to make sure that all operators correctly convert types on the way out. + test("Local UDTs") { + val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") + df.collect()(0).getAs[MyDenseVector](1) + df.take(1)(0).getAs[MyDenseVector](1) + df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org