Repository: spark
Updated Branches:
  refs/heads/master 4a2b15f0a -> d610d2a3f


[SPARK-24259][SQL] ArrayWriter for Arrow produces wrong output

## What changes were proposed in this pull request?

Right now, the `ArrayWriter` used to output Arrow data for array type doesn't do
`clear` or `reset` after each batch, so it produces wrong output.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #21312 from viirya/SPARK-24259.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d610d2a3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d610d2a3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d610d2a3

Branch: refs/heads/master
Commit: d610d2a3f57ca551f72cb4e5dfed78f27be62eec
Parents: 4a2b15f
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Tue May 15 22:06:58 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue May 15 22:06:58 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 20 ++++++++++++++++++++
 .../spark/sql/execution/arrow/ArrowWriter.scala |  8 ++++++++
 2 files changed, 28 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d610d2a3/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 16aa937..a1b6db7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4680,6 +4680,26 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         self.assertPandasEqual(expected2, result2)
         self.assertPandasEqual(expected3, result3)
 
+    def test_array_type_correct(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
+
+        df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
+
+        output_schema = StructType(
+            [StructField('id', LongType()),
+             StructField('v', IntegerType()),
+             StructField('arr', ArrayType(LongType()))])
+
+        udf = pandas_udf(
+            lambda pdf: pdf,
+            output_schema,
+            PandasUDFType.GROUPED_MAP
+        )
+
+        result = df.groupby('id').apply(udf).sort('id').toPandas()
+        expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
+        self.assertPandasEqual(expected, result)
+
     def test_register_grouped_map_udf(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d610d2a3/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 22b6351..66888fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter {
     valueVector match {
       case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset()
      case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset()
+      case listVector: ListVector =>
+        // Manually "reset" the underlying buffer.
+        // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call
+        // `listVector.reset()`.
+        val buffers = listVector.getBuffers(false)
+        buffers.foreach(buf => buf.setZero(0, buf.capacity()))
+        listVector.setValueCount(0)
+        listVector.setLastSet(0)
       case _ =>
     }
     count = 0


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to