Repository: spark Updated Branches: refs/heads/master aa48164a4 -> d1ca634db
[SPARK-12300] [SQL] [PYSPARK] fix schema inference on local collections Current schema inference for local python collections halts as soon as there are no NullTypes. This is different from what happens when we specify a sampling ratio of 1.0 on a distributed collection. This could result in incomplete schema information. Author: Holden Karau <hol...@us.ibm.com> Closes #10275 from holdenk/SPARK-12300-fix-schmea-inferance-on-local-collections. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d1ca634d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d1ca634d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d1ca634d Branch: refs/heads/master Commit: d1ca634db4ca9db7f0ba7ca38a0e03bcbfec23c9 Parents: aa48164 Author: Holden Karau <hol...@us.ibm.com> Authored: Wed Dec 30 11:14:47 2015 -0800 Committer: Davies Liu <davies....@gmail.com> Committed: Wed Dec 30 11:14:47 2015 -0800 ---------------------------------------------------------------------- python/pyspark/sql/context.py | 10 +++------- python/pyspark/sql/tests.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d1ca634d/python/pyspark/sql/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b05aa2f..ba6915a 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -18,6 +18,7 @@ import sys import warnings import json +from functools import reduce if sys.version >= '3': basestring = unicode = str @@ -236,14 +237,9 @@ class SQLContext(object): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) + schema = reduce(_merge_type, map(_infer_schema, data)) if _has_nulltype(schema): - for r 
in data: - schema = _merge_type(schema, _infer_schema(r)) - if not _has_nulltype(schema): - break - else: - raise ValueError("Some of types cannot be determined after inferring") + raise ValueError("Some of types cannot be determined after inferring") return schema def _inferSchema(self, rdd, samplingRatio=None): http://git-wip-us.apache.org/repos/asf/spark/blob/d1ca634d/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9f5f7cf..10b9917 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -353,6 +353,17 @@ class SQLTests(ReusedPySparkTestCase): df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_infer_schema_to_local(self): + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + df = self.sqlCtx.createDataFrame(input) + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org