Repository: spark Updated Branches: refs/heads/master c284c4e1f -> 1c9f95cb7
[SPARK-22530][PYTHON][SQL] Adding Arrow support for ArrayType ## What changes were proposed in this pull request? This change adds `ArrayType` support for working with Arrow in pyspark when creating a DataFrame, calling `toPandas()`, and using vectorized `pandas_udf`. ## How was this patch tested? Added new Python unit tests using Array data. Author: Bryan Cutler <cutl...@gmail.com> Closes #20114 from BryanCutler/arrow-ArrayType-support-SPARK-22530. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c9f95cb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c9f95cb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c9f95cb Branch: refs/heads/master Commit: 1c9f95cb771ac78775a77edd1abfeb2d8ae2a124 Parents: c284c4e Author: Bryan Cutler <cutl...@gmail.com> Authored: Tue Jan 2 07:13:27 2018 +0900 Committer: hyukjinkwon <gurwls...@gmail.com> Committed: Tue Jan 2 07:13:27 2018 +0900 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 47 +++++++++++++++++++- python/pyspark/sql/types.py | 4 ++ .../execution/vectorized/ArrowColumnVector.java | 13 +++++- 3 files changed, 61 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1c9f95cb/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1c34c89..67bdb3d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3372,6 +3372,31 @@ class ArrowTests(ReusedSQLTestCase): schema_rt = from_arrow_schema(arrow_schema) self.assertEquals(self.schema, schema_rt) + def test_createDataFrame_with_array_type(self): + import pandas as pd + pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) + df, df_arrow = self._createDataFrame_toggle(pdf) + result = df.collect() + result_arrow = df_arrow.collect() + expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_toPandas_with_array_type(self): + expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] + array_schema = StructType([StructField("a", ArrayType(IntegerType())), + StructField("b", ArrayType(StringType()))]) + df = self.spark.createDataFrame(expected, schema=array_schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): @@ -3651,6 +3676,24 @@ class VectorizedUDFTests(ReusedSQLTestCase): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), ([3, 4],)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + + def test_vectorized_udf_null_array(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + def test_vectorized_udf_complex(self): from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( @@ -3705,7 +3748,7 @@ class VectorizedUDFTests(ReusedSQLTestCase): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType())) + f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): df.select(f(col('id'))).collect() @@ -4009,7 +4052,7 @@ class GroupbyApplyTests(ReusedSQLTestCase): foo = pandas_udf( lambda pdf: pdf, - 'id long, v array<int>', + 'id long, v map<int, int>', PandasUDFType.GROUP_MAP ) http://git-wip-us.apache.org/repos/asf/spark/blob/1c9f95cb/python/pyspark/sql/types.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 02b2457..146e673 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1625,6 +1625,8 @@ def to_arrow_type(dt): elif type(dt) == TimestampType: # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') + elif type(dt) == ArrayType: + arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type @@ -1665,6 +1667,8 @@ def from_arrow_type(at): spark_type = DateType() elif types.is_timestamp(at): spark_type = TimestampType() + elif types.is_list(at): + spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type http://git-wip-us.apache.org/repos/asf/spark/blob/1c9f95cb/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 528f66f..af5673e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -326,7 +326,8 @@ public final class ArrowColumnVector extends ColumnVector { this.vector = vector; } - final boolean isNullAt(int rowId) { + // TODO: should be final after removing ArrayAccessor workaround + boolean isNullAt(int rowId) { return vector.isNull(rowId); } @@ -590,6 +591,16 @@ public final class ArrowColumnVector extends ColumnVector { } @Override + final boolean isNullAt(int rowId) { + // TODO: Workaround if vector has all non-null values, see ARROW-1948 + if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { + return false; + } else { + return super.isNullAt(rowId); + } + } + + @Override final int getArrayLength(int rowId) { return accessor.getInnerValueCountAt(rowId); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org