Repository: spark Updated Branches: refs/heads/master 4a2b15f0a -> d610d2a3f
[SPARK-24259][SQL] ArrayWriter for Arrow produces wrong output ## What changes were proposed in this pull request? Right now `ArrayWriter` used to output Arrow data for array type, doesn't do `clear` or `reset` after each batch. It produces wrong output. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #21312 from viirya/SPARK-24259. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d610d2a3 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d610d2a3 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d610d2a3 Branch: refs/heads/master Commit: d610d2a3f57ca551f72cb4e5dfed78f27be62eec Parents: 4a2b15f Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Tue May 15 22:06:58 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue May 15 22:06:58 2018 +0800 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 20 ++++++++++++++++++++ .../spark/sql/execution/arrow/ArrowWriter.scala | 8 ++++++++ 2 files changed, 28 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d610d2a3/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16aa937..a1b6db7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4680,6 +4680,26 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType http://git-wip-us.apache.org/repos/asf/spark/blob/d610d2a3/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 22b6351..66888fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter { valueVector match { case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case listVector: ListVector => + // Manual "reset" the underlying buffer. + // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call + // `listVector.reset()`. + val buffers = listVector.getBuffers(false) + buffers.foreach(buf => buf.setZero(0, buf.capacity())) + listVector.setValueCount(0) + listVector.setLastSet(0) case _ => } count = 0 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org