Repository: spark
Updated Branches:
  refs/heads/master 0e6833006 -> eb386be1e
[SPARK-21552][SQL] Add DecimalType support to ArrowWriter.

## What changes were proposed in this pull request?

Decimal type is not yet supported in `ArrowWriter`. This change adds decimal type support.

## How was this patch tested?

Added a test to `ArrowConvertersSuite`.

Author: Takuya UESHIN <ues...@databricks.com>

Closes #18754 from ueshin/issues/SPARK-21552.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eb386be1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eb386be1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eb386be1

Branch: refs/heads/master
Commit: eb386be1ed383323da6e757f63f3b8a7ced38cc4
Parents: 0e68330
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Tue Dec 26 21:37:25 2017 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Tue Dec 26 21:37:25 2017 +0900

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 61 ++++++++++++------
 python/pyspark/sql/types.py                     |  2 +-
 .../spark/sql/execution/arrow/ArrowWriter.scala | 21 ++++++
 .../execution/arrow/ArrowConvertersSuite.scala  | 67 +++++++++++++++++++-
 .../sql/execution/arrow/ArrowWriterSuite.scala  |  2 +
 5 files changed, 131 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
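For orientation before the diff: a minimal sketch of what this change enables on the PySpark side. It assumes an existing SparkSession named `spark` and the Arrow conversion flag of this era (`spark.sql.execution.arrow.enabled`); before this patch, a decimal column on the Arrow path raised an unsupported-type error.

```python
from decimal import Decimal

from pyspark.sql.types import DecimalType, StructField, StructType

# Assumes `spark` is an existing SparkSession with Arrow conversion enabled.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

schema = StructType([StructField("6_decimal_t", DecimalType(38, 18), True)])
df = spark.createDataFrame([(Decimal("2.0"),), (Decimal("4.0"),), (None,)], schema)

# With this patch, toPandas() can take the Arrow path for decimal columns;
# the values come back as Python decimal.Decimal objects (dtype: object).
pdf = df.toPandas()
```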
http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b977160..b811a0f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3142,6 +3142,7 @@ class ArrowTests(ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
         from datetime import datetime
+        from decimal import Decimal
         ReusedSQLTestCase.setUpClass()
 
         # Synchronize default timezone between Python and Java
@@ -3158,11 +3159,15 @@ class ArrowTests(ReusedSQLTestCase):
             StructField("3_long_t", LongType(), True),
             StructField("4_float_t", FloatType(), True),
             StructField("5_double_t", DoubleType(), True),
-            StructField("6_date_t", DateType(), True),
-            StructField("7_timestamp_t", TimestampType(), True)])
-        cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-                    (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
-                    (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+            StructField("6_decimal_t", DecimalType(38, 18), True),
+            StructField("7_date_t", DateType(), True),
+            StructField("8_timestamp_t", TimestampType(), True)])
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
+                     datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
+                     datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
+                     datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3190,10 +3195,11 @@ class ArrowTests(ReusedSQLTestCase):
         return pd.DataFrame(data=data_dict)
 
     def test_unsupported_datatype(self):
-        schema = StructType([StructField("decimal", DecimalType(), True)])
+        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
-            self.assertRaises(Exception, lambda: df.toPandas())
+            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+                df.toPandas()
 
     def test_null_conversion(self):
         df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
@@ -3293,7 +3299,7 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertNotEqual(result_ny, result_la)
 
         # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-        result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
+        result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
                                       for k, v in row.asDict().items()}) for row in result_la]
         self.assertEqual(result_ny, result_la_corrected)
 
@@ -3317,11 +3323,11 @@ class ArrowTests(ReusedSQLTestCase):
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
         # Test that schema as a list of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        df = self.spark.createDataFrame(pdf, schema=list('abcdefgh'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
         # Test that schema as tuple of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
 
     def test_createDataFrame_column_name_encoding(self):
         import pandas as pd
@@ -3344,7 +3350,7 @@ class ArrowTests(ReusedSQLTestCase):
         # Some series get converted for Spark to consume, this makes sure input is unchanged
         pdf = self.create_pandas_data_frame()
         # Use a nanosecond value to make sure it is not truncated
-        pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1)
+        pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1)
         # Integers with nulls will get NaNs filled with 0 and will be casted
         pdf.ix[1, '2_int_t'] = None
         pdf_copy = pdf.copy(deep=True)
@@ -3514,6 +3520,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
             col('id').alias('long'),
             col('id').cast('float').alias('float'),
             col('id').cast('double').alias('double'),
+            col('id').cast('decimal').alias('decimal'),
             col('id').cast('boolean').alias('bool'))
         f = lambda x: x
         str_f = pandas_udf(f, StringType())
@@ -3521,10 +3528,12 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         long_f = pandas_udf(f, LongType())
         float_f = pandas_udf(f, FloatType())
         double_f = pandas_udf(f, DoubleType())
+        decimal_f = pandas_udf(f, DecimalType())
         bool_f = pandas_udf(f, BooleanType())
         res = df.select(str_f(col('str')), int_f(col('int')),
                         long_f(col('long')), float_f(col('float')),
-                        double_f(col('double')), bool_f(col('bool')))
+                        double_f(col('double')), decimal_f('decimal'),
+                        bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_null_boolean(self):
@@ -3590,6 +3599,16 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         res = df.select(double_f(col('double')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_null_decimal(self):
+        from decimal import Decimal
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
+        schema = StructType().add("decimal", DecimalType(38, 18))
+        df = self.spark.createDataFrame(data, schema)
+        decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
+        res = df.select(decimal_f(col('decimal')))
+        self.assertEquals(df.collect(), res.collect())
+
     def test_vectorized_udf_null_string(self):
         from pyspark.sql.functions import pandas_udf, col
         data = [("foo",), (None,), ("bar",), ("bar",)]
@@ -3607,6 +3626,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
             col('id').alias('long'),
             col('id').cast('float').alias('float'),
             col('id').cast('double').alias('double'),
+            col('id').cast('decimal').alias('decimal'),
             col('id').cast('boolean').alias('bool'))
         f = lambda x: x
         str_f = pandas_udf(f, 'string')
@@ -3614,10 +3634,12 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         long_f = pandas_udf(f, 'long')
         float_f = pandas_udf(f, 'float')
         double_f = pandas_udf(f, 'double')
+        decimal_f = pandas_udf(f, 'decimal(38, 18)')
         bool_f = pandas_udf(f, 'boolean')
         res = df.select(str_f(col('str')), int_f(col('int')),
                         long_f(col('long')), float_f(col('float')),
-                        double_f(col('double')), bool_f(col('bool')))
+                        double_f(col('double')), decimal_f('decimal'),
+                        bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_complex(self):
@@ -3713,12 +3735,12 @@ class VectorizedUDFTests(ReusedSQLTestCase):
 
     def test_vectorized_udf_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
-        schema = StructType([StructField("dt", DecimalType(), True)])
+        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
-        f = pandas_udf(lambda x: x, DecimalType())
+        f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
-                df.select(f(col('dt'))).collect()
+                df.select(f(col('map'))).collect()
 
     def test_vectorized_udf_null_date(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4012,7 +4034,8 @@ class GroupbyApplyTests(ReusedSQLTestCase):
     def test_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col, PandasUDFType
         schema = StructType(
-            [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
+            [StructField("id", LongType(), True),
+             StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(1, None,)], schema=schema)
         f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
         with QuietTest(self.sc):
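The new `test_vectorized_udf_null_decimal` case above boils down to the following usage pattern, sketched here standalone (again assuming an existing SparkSession named `spark`):

```python
from pyspark.sql.functions import col, pandas_udf

# Identity pandas_udf over a decimal column; each batch arrives in the
# worker as a pandas.Series of decimal.Decimal values and now round-trips
# through Arrow unchanged.
df = spark.range(3).select(col("id").cast("decimal(38, 18)").alias("d"))
identity = pandas_udf(lambda s: s, "decimal(38, 18)")
result = df.select(identity(col("d")))
assert df.collect() == result.collect()
```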
http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 063264a..02b2457 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1617,7 +1617,7 @@ def to_arrow_type(dt):
     elif type(dt) == DoubleType:
         arrow_type = pa.float64()
     elif type(dt) == DecimalType:
-        arrow_type = pa.decimal(dt.precision, dt.scale)
+        arrow_type = pa.decimal128(dt.precision, dt.scale)
     elif type(dt) == StringType:
         arrow_type = pa.string()
     elif type(dt) == DateType:

http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 0258056..22b6351 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -53,6 +53,8 @@ object ArrowWriter {
       case (LongType, vector: BigIntVector) => new LongWriter(vector)
       case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
       case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+        new DecimalWriter(vector, precision, scale)
       case (StringType, vector: VarCharVector) => new StringWriter(vector)
       case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
       case (DateType, vector: DateDayVector) => new DateWriter(vector)
@@ -214,6 +216,25 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFi
   }
 }
 
+private[arrow] class DecimalWriter(
+    val valueVector: DecimalVector,
+    precision: Int,
+    scale: Int) extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val decimal = input.getDecimal(ordinal, precision, scale)
+    if (decimal.changePrecision(precision, scale)) {
+      valueVector.setSafe(count, decimal.toJavaBigDecimal)
+    } else {
+      setNull()
+    }
+  }
+}
+
 private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter {
 
   override def setNull(): Unit = {
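Two details worth noting in the hunks above. First, `pa.decimal(...)` becomes `pa.decimal128(...)`, pyarrow's factory for the 128-bit decimal layout, whose 38 significant digits line up with Spark's maximum `DecimalType` precision. Second, `DecimalWriter.setValue` calls `Decimal.changePrecision` before writing and falls back to `setNull()` when the value cannot be represented at the target precision and scale. A small sketch of the type-mapping side:

```python
import pyarrow as pa

# pyarrow's 128-bit decimal type; precision 38 / scale 18 matches
# Spark's DecimalType.SYSTEM_DEFAULT, so to_arrow_type maps it directly.
arrow_type = pa.decimal128(38, 18)
print(arrow_type.precision, arrow_type.scale)  # 38 18
```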
http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index fd5a3df..261df06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 @@ -304,6 +304,70 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     collectAndValidate(df, json, "floating_point-double_precision.json")
   }
 
+  test("decimal conversion") {
+    val json =
+      s"""
+         |{
+         |  "schema" : {
+         |    "fields" : [ {
+         |      "name" : "a_d",
+         |      "type" : {
+         |        "name" : "decimal",
+         |        "precision" : 38,
+         |        "scale" : 18
+         |      },
+         |      "nullable" : true,
+         |      "children" : [ ]
+         |    }, {
+         |      "name" : "b_d",
+         |      "type" : {
+         |        "name" : "decimal",
+         |        "precision" : 38,
+         |        "scale" : 18
+         |      },
+         |      "nullable" : true,
+         |      "children" : [ ]
+         |    } ]
+         |  },
+         |  "batches" : [ {
+         |    "count" : 7,
+         |    "columns" : [ {
+         |      "name" : "a_d",
+         |      "count" : 7,
+         |      "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ],
+         |      "DATA" : [
+         |        "1000000000000000000",
+         |        "2000000000000000000",
+         |        "10000000000000000",
+         |        "200000000000000000000",
+         |        "100000000000000",
+         |        "20000000000000000000000",
+         |        "30000000000000000000" ]
+         |    }, {
+         |      "name" : "b_d",
+         |      "count" : 7,
+         |      "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ],
+         |      "DATA" : [
+         |        "1100000000000000000",
+         |        "0",
+         |        "0",
+         |        "2200000000000000000",
+         |        "0",
+         |        "3300000000000000000",
+         |        "0" ]
+         |    } ]
+         |  } ]
+         |}
+       """.stripMargin
+
+    val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_))
+    val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)),
+      Some(Decimal("123456789012345678901234567890")))
+    val df = a_d.zip(b_d).toDF("a_d", "b_d")
+
+    collectAndValidate(df, json, "decimalData.json")
+  }
+
   test("index conversion") {
     val data = List[Int](1, 2, 3, 4, 5, 6)
     val json =
@@ -1153,7 +1217,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
       assert(msg.getCause.getClass === classOf[UnsupportedOperationException])
     }
 
-    runUnsupported { decimalData.toArrowPayload.collect() }
     runUnsupported { mapData.toDF().toArrowPayload.collect() }
     runUnsupported { complexData.toArrowPayload.collect() }
   }
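Note the edge case in the expected JSON above: the last `b_d` value, `Decimal("123456789012345678901234567890")`, has 30 integer digits, more than the 38 - 18 = 20 that `DecimalType(38, 18)` can hold, so `DecimalWriter` writes it as null rather than truncating it — hence the trailing `0` in that column's `VALIDITY` array. A quick check of that arithmetic:

```python
from decimal import Decimal

precision, scale = 38, 18
too_big = Decimal("123456789012345678901234567890")

# Total digits required = integer digits + `scale` fractional digits.
integer_digits = len(too_big.as_tuple().digits)
assert integer_digits + scale > precision  # 30 + 18 = 48 > 38, so written as null
```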
http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index a71e30a..508c116 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite {
         case LongType => reader.getLong(rowId)
         case FloatType => reader.getFloat(rowId)
         case DoubleType => reader.getDouble(rowId)
+        case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale)
         case StringType => reader.getUTF8String(rowId)
         case BinaryType => reader.getBinary(rowId)
         case DateType => reader.getInt(rowId)
@@ -66,6 +67,7 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(LongType, Seq(1L, 2L, null, 4L))
     check(FloatType, Seq(1.0f, 2.0f, null, 4.0f))
     check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
+    check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)))
     check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
     check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
     check(DateType, Seq(0, 1, 2, null, 4))