Repository: spark
Updated Branches:
  refs/heads/master 0e6833006 -> eb386be1e


[SPARK-21552][SQL] Add DecimalType support to ArrowWriter.

## What changes were proposed in this pull request?

Decimal type is not yet supported in `ArrowWriter`.
This change adds decimal type support.
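
With this change, decimal columns can be converted through Arrow. A minimal, hedged sketch of the round trip this enables (assuming a build containing this patch, pyarrow installed, and Arrow conversion turned on via `spark.sql.execution.arrow.enabled`):

```python
from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DecimalType

spark = SparkSession.builder.getOrCreate()
# Assumption: Arrow-based toPandas() is enabled via the Spark 2.x config flag;
# pyarrow must be installed for it to take effect.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

schema = StructType([StructField("d", DecimalType(38, 18), True)])
df = spark.createDataFrame(
    [(Decimal("2.0"),), (Decimal("4.0"),), (None,)], schema=schema)

# With this patch, the decimal column is written through ArrowWriter's DecimalWriter.
print(df.toPandas())
```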

## How was this patch tested?

Added a test to `ArrowConvertersSuite`.
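
For illustration, a hedged sketch of the kind of vectorized UDF the new Python tests exercise (an identity `pandas_udf` over a nullable decimal column, mirroring `test_vectorized_udf_null_decimal`; assumes an active `spark` session and a pyarrow version providing `decimal128`):

```python
from decimal import Decimal

from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import StructType, DecimalType

# A one-column decimal DataFrame including a null, as in the added test.
data = [(Decimal("3.0"),), (Decimal("5.0"),), (None,)]
schema = StructType().add("decimal", DecimalType(38, 18))
df = spark.createDataFrame(data, schema)

# Identity UDF: values should round-trip through Arrow unchanged.
decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
df.select(decimal_f(col("decimal"))).show()
```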

Author: Takuya UESHIN <ues...@databricks.com>

Closes #18754 from ueshin/issues/SPARK-21552.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eb386be1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eb386be1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eb386be1

Branch: refs/heads/master
Commit: eb386be1ed383323da6e757f63f3b8a7ced38cc4
Parents: 0e68330
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Tue Dec 26 21:37:25 2017 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Tue Dec 26 21:37:25 2017 +0900

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 61 ++++++++++++------
 python/pyspark/sql/types.py                     |  2 +-
 .../spark/sql/execution/arrow/ArrowWriter.scala | 21 ++++++
 .../execution/arrow/ArrowConvertersSuite.scala  | 67 +++++++++++++++++++-
 .../sql/execution/arrow/ArrowWriterSuite.scala  |  2 +
 5 files changed, 131 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b977160..b811a0f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3142,6 +3142,7 @@ class ArrowTests(ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
         from datetime import datetime
+        from decimal import Decimal
         ReusedSQLTestCase.setUpClass()
 
         # Synchronize default timezone between Python and Java
@@ -3158,11 +3159,15 @@ class ArrowTests(ReusedSQLTestCase):
             StructField("3_long_t", LongType(), True),
             StructField("4_float_t", FloatType(), True),
             StructField("5_double_t", DoubleType(), True),
-            StructField("6_date_t", DateType(), True),
-            StructField("7_timestamp_t", TimestampType(), True)])
-        cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), 
datetime(1969, 1, 1, 1, 1, 1)),
-                    (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), 
datetime(2012, 2, 2, 2, 2, 2)),
-                    (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), 
datetime(2100, 3, 3, 3, 3, 3))]
+            StructField("6_decimal_t", DecimalType(38, 18), True),
+            StructField("7_date_t", DateType(), True),
+            StructField("8_timestamp_t", TimestampType(), True)])
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
+                     datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
+                     datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
+                     datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3190,10 +3195,11 @@ class ArrowTests(ReusedSQLTestCase):
         return pd.DataFrame(data=data_dict)
 
     def test_unsupported_datatype(self):
-        schema = StructType([StructField("decimal", DecimalType(), True)])
+        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
-            self.assertRaises(Exception, lambda: df.toPandas())
+            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+                df.toPandas()
 
     def test_null_conversion(self):
         df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
@@ -3293,7 +3299,7 @@ class ArrowTests(ReusedSQLTestCase):
             self.assertNotEqual(result_ny, result_la)
 
             # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-            result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
+            result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
                                           for k, v in row.asDict().items()})
                                    for row in result_la]
             self.assertEqual(result_ny, result_la_corrected)
@@ -3317,11 +3323,11 @@ class ArrowTests(ReusedSQLTestCase):
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
         # Test that schema as a list of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        df = self.spark.createDataFrame(pdf, schema=list('abcdefgh'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
         # Test that schema as tuple of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
 
     def test_createDataFrame_column_name_encoding(self):
         import pandas as pd
@@ -3344,7 +3350,7 @@ class ArrowTests(ReusedSQLTestCase):
         # Some series get converted for Spark to consume, this makes sure input is unchanged
         pdf = self.create_pandas_data_frame()
         # Use a nanosecond value to make sure it is not truncated
-        pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1)
+        pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1)
         # Integers with nulls will get NaNs filled with 0 and will be casted
         pdf.ix[1, '2_int_t'] = None
         pdf_copy = pdf.copy(deep=True)
@@ -3514,6 +3520,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
             col('id').alias('long'),
             col('id').cast('float').alias('float'),
             col('id').cast('double').alias('double'),
+            col('id').cast('decimal').alias('decimal'),
             col('id').cast('boolean').alias('bool'))
         f = lambda x: x
         str_f = pandas_udf(f, StringType())
@@ -3521,10 +3528,12 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         long_f = pandas_udf(f, LongType())
         float_f = pandas_udf(f, FloatType())
         double_f = pandas_udf(f, DoubleType())
+        decimal_f = pandas_udf(f, DecimalType())
         bool_f = pandas_udf(f, BooleanType())
         res = df.select(str_f(col('str')), int_f(col('int')),
                         long_f(col('long')), float_f(col('float')),
-                        double_f(col('double')), bool_f(col('bool')))
+                        double_f(col('double')), decimal_f('decimal'),
+                        bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_null_boolean(self):
@@ -3590,6 +3599,16 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         res = df.select(double_f(col('double')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_null_decimal(self):
+        from decimal import Decimal
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
+        schema = StructType().add("decimal", DecimalType(38, 18))
+        df = self.spark.createDataFrame(data, schema)
+        decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
+        res = df.select(decimal_f(col('decimal')))
+        self.assertEquals(df.collect(), res.collect())
+
     def test_vectorized_udf_null_string(self):
         from pyspark.sql.functions import pandas_udf, col
         data = [("foo",), (None,), ("bar",), ("bar",)]
@@ -3607,6 +3626,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
             col('id').alias('long'),
             col('id').cast('float').alias('float'),
             col('id').cast('double').alias('double'),
+            col('id').cast('decimal').alias('decimal'),
             col('id').cast('boolean').alias('bool'))
         f = lambda x: x
         str_f = pandas_udf(f, 'string')
@@ -3614,10 +3634,12 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         long_f = pandas_udf(f, 'long')
         float_f = pandas_udf(f, 'float')
         double_f = pandas_udf(f, 'double')
+        decimal_f = pandas_udf(f, 'decimal(38, 18)')
         bool_f = pandas_udf(f, 'boolean')
         res = df.select(str_f(col('str')), int_f(col('int')),
                         long_f(col('long')), float_f(col('float')),
-                        double_f(col('double')), bool_f(col('bool')))
+                        double_f(col('double')), decimal_f('decimal'),
+                        bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_complex(self):
@@ -3713,12 +3735,12 @@ class VectorizedUDFTests(ReusedSQLTestCase):
 
     def test_vectorized_udf_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
-        schema = StructType([StructField("dt", DecimalType(), True)])
+        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
-        f = pandas_udf(lambda x: x, DecimalType())
+        f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
-                df.select(f(col('dt'))).collect()
+                df.select(f(col('map'))).collect()
 
     def test_vectorized_udf_null_date(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4012,7 +4034,8 @@ class GroupbyApplyTests(ReusedSQLTestCase):
     def test_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col, PandasUDFType
         schema = StructType(
-            [StructField("id", LongType(), True), StructField("dt", 
DecimalType(), True)])
+            [StructField("id", LongType(), True),
+             StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(1, None,)], schema=schema)
         f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
         with QuietTest(self.sc):

http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 063264a..02b2457 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1617,7 +1617,7 @@ def to_arrow_type(dt):
     elif type(dt) == DoubleType:
         arrow_type = pa.float64()
     elif type(dt) == DecimalType:
-        arrow_type = pa.decimal(dt.precision, dt.scale)
+        arrow_type = pa.decimal128(dt.precision, dt.scale)
     elif type(dt) == StringType:
         arrow_type = pa.string()
     elif type(dt) == DateType:

http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 0258056..22b6351 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -53,6 +53,8 @@ object ArrowWriter {
       case (LongType, vector: BigIntVector) => new LongWriter(vector)
       case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
       case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+        new DecimalWriter(vector, precision, scale)
       case (StringType, vector: VarCharVector) => new StringWriter(vector)
       case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
       case (DateType, vector: DateDayVector) => new DateWriter(vector)
@@ -214,6 +216,25 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFi
   }
 }
 
+private[arrow] class DecimalWriter(
+    val valueVector: DecimalVector,
+    precision: Int,
+    scale: Int) extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val decimal = input.getDecimal(ordinal, precision, scale)
+    if (decimal.changePrecision(precision, scale)) {
+      valueVector.setSafe(count, decimal.toJavaBigDecimal)
+    } else {
+      setNull()
+    }
+  }
+}
+
 private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter {
 
   override def setNull(): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index fd5a3df..261df06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 
@@ -304,6 +304,70 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     collectAndValidate(df, json, "floating_point-double_precision.json")
   }
 
+  test("decimal conversion") {
+    val json =
+      s"""
+         |{
+         |  "schema" : {
+         |    "fields" : [ {
+         |      "name" : "a_d",
+         |      "type" : {
+         |        "name" : "decimal",
+         |        "precision" : 38,
+         |        "scale" : 18
+         |      },
+         |      "nullable" : true,
+         |      "children" : [ ]
+         |    }, {
+         |      "name" : "b_d",
+         |      "type" : {
+         |        "name" : "decimal",
+         |        "precision" : 38,
+         |        "scale" : 18
+         |      },
+         |      "nullable" : true,
+         |      "children" : [ ]
+         |    } ]
+         |  },
+         |  "batches" : [ {
+         |    "count" : 7,
+         |    "columns" : [ {
+         |      "name" : "a_d",
+         |      "count" : 7,
+         |      "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ],
+         |      "DATA" : [
+         |        "1000000000000000000",
+         |        "2000000000000000000",
+         |        "10000000000000000",
+         |        "200000000000000000000",
+         |        "100000000000000",
+         |        "20000000000000000000000",
+         |        "30000000000000000000" ]
+         |    }, {
+         |      "name" : "b_d",
+         |      "count" : 7,
+         |      "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ],
+         |      "DATA" : [
+         |        "1100000000000000000",
+         |        "0",
+         |        "0",
+         |        "2200000000000000000",
+         |        "0",
+         |        "3300000000000000000",
+         |        "0" ]
+         |    } ]
+         |  } ]
+         |}
+       """.stripMargin
+
+    val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_))
+    val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)),
+      Some(Decimal("123456789012345678901234567890")))
+    val df = a_d.zip(b_d).toDF("a_d", "b_d")
+
+    collectAndValidate(df, json, "decimalData.json")
+  }
+
   test("index conversion") {
     val data = List[Int](1, 2, 3, 4, 5, 6)
     val json =
@@ -1153,7 +1217,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
       assert(msg.getCause.getClass === classOf[UnsupportedOperationException])
     }
 
-    runUnsupported { decimalData.toArrowPayload.collect() }
     runUnsupported { mapData.toDF().toArrowPayload.collect() }
     runUnsupported { complexData.toArrowPayload.collect() }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/eb386be1/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index a71e30a..508c116 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite {
             case LongType => reader.getLong(rowId)
             case FloatType => reader.getFloat(rowId)
             case DoubleType => reader.getDouble(rowId)
+            case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale)
             case StringType => reader.getUTF8String(rowId)
             case BinaryType => reader.getBinary(rowId)
             case DateType => reader.getInt(rowId)
@@ -66,6 +67,7 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(LongType, Seq(1L, 2L, null, 4L))
     check(FloatType, Seq(1.0f, 2.0f, null, 4.0f))
     check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
+    check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)))
     check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
     check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
     check(DateType, Seq(0, 1, 2, null, 4))

