Repository: spark
Updated Branches:
  refs/heads/master 71d65a321 -> 3e40eb3f1


[SPARK-22566][PYTHON] Better error message for `_merge_type` in Pandas to Spark DF conversion

## What changes were proposed in this pull request?

This change provides a better error message when `spark_session.createDataFrame(pandas_df)` 
is called without a schema and schema inference fails due to incompatible types.

The pandas column names are propagated down to the type-merging code, so the 
error message now names the column on which the merge failed.

https://issues.apache.org/jira/browse/SPARK-22566
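
For illustration, a minimal repro of the failure this improves (a sketch, not part of the patch; it assumes a running `SparkSession` bound to `spark` and the hypothetical frame `pdf`):

```python
import pandas as pd

# Column "a" mixes an int and a string, so per-row schema inference
# yields LongType for the first row and StringType for the second.
pdf = pd.DataFrame({"a": [1, "x"], "b": [1, 2]})

# Before this patch the TypeError only said the two types could not be
# merged; with it, the offending column is named, e.g.:
#   TypeError: field a: Can not merge type <LongType> and <StringType>
spark.createDataFrame(pdf)
```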

## How was this patch tested?

Manually in the `./bin/pyspark` console, and with new tests: 
`./python/run-tests`

<img width="873" alt="screen shot 2017-11-21 at 13 29 49" src="https://user-images.githubusercontent.com/3977115/33080121-382274e0-cecf-11e7-808f-057a65bb7b00.png">
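
The nested naming can also be exercised by calling the internal helper directly (a sketch mirroring the new tests below; `_merge_type` is a private API and may change):

```python
from pyspark.sql.types import (ArrayType, DoubleType, LongType,
                               StructField, StructType, _merge_type)

# Conflicting element types inside an array field now raise, e.g.:
#   TypeError: element in array field f1: Can not merge type
#   <LongType> and <DoubleType>
_merge_type(
    StructType([StructField("f1", ArrayType(LongType()))]),
    StructType([StructField("f1", ArrayType(DoubleType()))]))
```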

I state that the contribution is my original work and that I license the work 
to the Apache Spark project under the project’s open source license.

Author: Guilherme Berger <gber...@palantir.com>

Closes #19792 from gberger/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e40eb3f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e40eb3f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e40eb3f

Branch: refs/heads/master
Commit: 3e40eb3f1ffac3d2f49459a801e3ce171ed34091
Parents: 71d65a3
Author: Guilherme Berger <gber...@palantir.com>
Authored: Mon Jan 8 14:32:05 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Mon Jan 8 14:32:05 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/session.py |  17 ++++---
 python/pyspark/sql/tests.py   | 100 +++++++++++++++++++++++++++++++++++++
 python/pyspark/sql/types.py   |  28 ++++++++---
 3 files changed, 129 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e40eb3f/python/pyspark/sql/session.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 6e5eec4..6052fa9 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -325,11 +325,12 @@ class SparkSession(object):
 
         return DataFrame(jdf, self._wrapped)
 
-    def _inferSchemaFromList(self, data):
+    def _inferSchemaFromList(self, data, names=None):
         """
         Infer schema from list of Row or tuple.
 
         :param data: list of Row or tuple
+        :param names: list of column names
         :return: :class:`pyspark.sql.types.StructType`
         """
         if not data:
@@ -338,12 +339,12 @@ class SparkSession(object):
         if type(first) is dict:
             warnings.warn("inferring schema from dict is deprecated,"
                           "please use pyspark.sql.Row instead")
-        schema = reduce(_merge_type, map(_infer_schema, data))
+        schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
         return schema
 
-    def _inferSchema(self, rdd, samplingRatio=None):
+    def _inferSchema(self, rdd, samplingRatio=None, names=None):
         """
         Infer schema from an RDD of Row or tuple.
 
@@ -360,10 +361,10 @@ class SparkSession(object):
                           "Use pyspark.sql.Row instead")
 
         if samplingRatio is None:
-            schema = _infer_schema(first)
+            schema = _infer_schema(first, names=names)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row))
+                    schema = _merge_type(schema, _infer_schema(row, names=names))
                     if not _has_nulltype(schema):
                         break
                 else:
@@ -372,7 +373,7 @@ class SparkSession(object):
         else:
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(_infer_schema).reduce(_merge_type)
+            schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
         return schema
 
     def _createFromRDD(self, rdd, schema, samplingRatio):
@@ -380,7 +381,7 @@ class SparkSession(object):
         Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
         """
         if schema is None or isinstance(schema, (list, tuple)):
-            struct = self._inferSchema(rdd, samplingRatio)
+            struct = self._inferSchema(rdd, samplingRatio, names=schema)
             converter = _create_converter(struct)
             rdd = rdd.map(converter)
             if isinstance(schema, (list, tuple)):
@@ -406,7 +407,7 @@ class SparkSession(object):
             data = list(data)
 
         if schema is None or isinstance(schema, (list, tuple)):
-            struct = self._inferSchemaFromList(data)
+            struct = self._inferSchemaFromList(data, names=schema)
             converter = _create_converter(struct)
             data = map(converter, data)
             if isinstance(schema, (list, tuple)):

http://git-wip-us.apache.org/repos/asf/spark/blob/3e40eb3f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 122a65b..13576ff 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -68,6 +68,7 @@ from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
 from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
 from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
+from pyspark.sql.types import _merge_type
 from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
@@ -898,6 +899,15 @@ class SQLTests(ReusedSQLTestCase):
         result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
+    def test_infer_schema_not_enough_names(self):
+        df = self.spark.createDataFrame([["a", "b"]], ["col1"])
+        self.assertEqual(df.columns, ['col1', '_2'])
+
+    def test_infer_schema_fails(self):
+        with self.assertRaisesRegexp(TypeError, 'field a'):
+            self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
+                                       schema=["a", "b"], samplingRatio=0.99)
+
     def test_infer_nested_schema(self):
         NestedRow = Row("f1", "f2")
         nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
@@ -918,6 +928,10 @@ class SQLTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
 
+    def test_create_dataframe_from_dict_respects_schema(self):
+        df = self.spark.createDataFrame([{'a': 1}], ["b"])
+        self.assertEqual(df.columns, ['b'])
+
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
         df = self.spark.createDataFrame(data)
@@ -1772,6 +1786,92 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(_infer_type(2**61), LongType())
         self.assertEqual(_infer_type(2**71), LongType())
 
+    def test_merge_type(self):
+        self.assertEqual(_merge_type(LongType(), NullType()), LongType())
+        self.assertEqual(_merge_type(NullType(), LongType()), LongType())
+
+        self.assertEqual(_merge_type(LongType(), LongType()), LongType())
+
+        self.assertEqual(_merge_type(
+            ArrayType(LongType()),
+            ArrayType(LongType())
+        ), ArrayType(LongType()))
+        with self.assertRaisesRegexp(TypeError, 'element in array'):
+            _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
+
+        self.assertEqual(_merge_type(
+            MapType(StringType(), LongType()),
+            MapType(StringType(), LongType())
+        ), MapType(StringType(), LongType()))
+        with self.assertRaisesRegexp(TypeError, 'key of map'):
+            _merge_type(
+                MapType(StringType(), LongType()),
+                MapType(DoubleType(), LongType()))
+        with self.assertRaisesRegexp(TypeError, 'value of map'):
+            _merge_type(
+                MapType(StringType(), LongType()),
+                MapType(StringType(), DoubleType()))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]),
+            StructType([StructField("f1", LongType()), StructField("f2", 
StringType())])
+        ), StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'field f1'):
+            _merge_type(
+                StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]),
+                StructType([StructField("f1", DoubleType()), StructField("f2", 
StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))]),
+            StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))])
+        ), StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))]))
+        with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
+            _merge_type(
+                StructType([StructField("f1", StructType([StructField("f2", 
LongType())]))]),
+                StructType([StructField("f1", StructType([StructField("f2", 
StringType())]))]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())]),
+            StructType([StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())])
+        ), StructType([StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
+            _merge_type(
+                StructType([
+                    StructField("f1", ArrayType(LongType())),
+                    StructField("f2", StringType())]),
+                StructType([
+                    StructField("f1", ArrayType(DoubleType())),
+                    StructField("f2", StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([
+                StructField("f1", MapType(StringType(), LongType())),
+                StructField("f2", StringType())]),
+            StructType([
+                StructField("f1", MapType(StringType(), LongType())),
+                StructField("f2", StringType())])
+        ), StructType([
+            StructField("f1", MapType(StringType(), LongType())),
+            StructField("f2", StringType())]))
+        with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
+            _merge_type(
+                StructType([
+                    StructField("f1", MapType(StringType(), LongType())),
+                    StructField("f2", StringType())]),
+                StructType([
+                    StructField("f1", MapType(StringType(), DoubleType())),
+                    StructField("f2", StringType())]))
+
+        self.assertEqual(_merge_type(
+            StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))]),
+            StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))])
+        ), StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))]))
+        with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
+            _merge_type(
+                StructType([StructField("f1", ArrayType(MapType(StringType(), 
LongType())))]),
+                StructType([StructField("f1", ArrayType(MapType(DoubleType(), 
LongType())))])
+            )
+
     def test_filter_with_datetime(self):
         time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
         date = time.date()

http://git-wip-us.apache.org/repos/asf/spark/blob/3e40eb3f/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 146e673..0dc5823 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1073,7 +1073,7 @@ def _infer_type(obj):
             raise TypeError("not supported type: %s" % type(obj))
 
 
-def _infer_schema(row):
+def _infer_schema(row, names=None):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1084,7 +1084,10 @@ def _infer_schema(row):
         elif hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
         else:
-            names = ['_%d' % i for i in range(1, len(row) + 1)]
+            if names is None:
+                names = ['_%d' % i for i in range(1, len(row) + 1)]
+            elif len(names) < len(row):
+                names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
             items = zip(names, row)
 
     elif hasattr(row, "__dict__"):  # object
@@ -1109,19 +1112,27 @@ def _has_nulltype(dt):
         return isinstance(dt, NullType)
 
 
-def _merge_type(a, b):
+def _merge_type(a, b, name=None):
+    if name is None:
+        new_msg = lambda msg: msg
+        new_name = lambda n: "field %s" % n
+    else:
+        new_msg = lambda msg: "%s: %s" % (name, msg)
+        new_name = lambda n: "field %s in %s" % (n, name)
+
     if isinstance(a, NullType):
         return b
     elif isinstance(b, NullType):
         return a
     elif type(a) is not type(b):
         # TODO: type cast (such as int -> long)
-        raise TypeError("Can not merge type %s and %s" % (type(a), type(b)))
+        raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
 
     # same type
     if isinstance(a, StructType):
         nfs = dict((f.name, f.dataType) for f in b.fields)
-        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()),
+                                                  name=new_name(f.name)))
                   for f in a.fields]
         names = set([f.name for f in fields])
         for n in nfs:
@@ -1130,11 +1141,12 @@ def _merge_type(a, b):
         return StructType(fields)
 
     elif isinstance(a, ArrayType):
-        return ArrayType(_merge_type(a.elementType, b.elementType), True)
+        return ArrayType(_merge_type(a.elementType, b.elementType,
+                                     name='element in array %s' % name), True)
 
     elif isinstance(a, MapType):
-        return MapType(_merge_type(a.keyType, b.keyType),
-                       _merge_type(a.valueType, b.valueType),
+        return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name),
+                       _merge_type(a.valueType, b.valueType, name='value of map %s' % name),
                        True)
     else:
         return a

