Repository: spark
Updated Branches:
  refs/heads/master 770d8153a -> ec5c4279e
[SPARK-4866] support StructType as key in MapType

This PR adds support for using StructType (and other hashable types) as keys in MapType.

Author: Davies Liu <dav...@databricks.com>

Closes #3714 from davies/fix_struct_in_map and squashes the following commits:

68585d7 [Davies Liu] fix primitive types in MapType
9601534 [Davies Liu] support StructType as key in MapType


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec5c4279
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec5c4279
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec5c4279

Branch: refs/heads/master
Commit: ec5c4279edabd5ea2b187aff6662ac07ed825b08
Parents: 770d815
Author: Davies Liu <dav...@databricks.com>
Authored: Tue Dec 16 21:23:28 2014 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Dec 16 21:23:28 2014 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py                           | 17 ++++++++++-------
 python/pyspark/tests.py                         |  8 ++++++++
 .../apache/spark/sql/execution/pythonUdfs.scala |  2 +-
 3 files changed, 19 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index ae28847..1ee0b28 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -788,8 +788,9 @@ def _create_converter(dataType):
         return lambda row: map(conv, row)
 
     elif isinstance(dataType, MapType):
-        conv = _create_converter(dataType.valueType)
-        return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+        kconv = _create_converter(dataType.keyType)
+        vconv = _create_converter(dataType.valueType)
+        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
 
     elif isinstance(dataType, NullType):
         return lambda x: None
@@ -944,7 +945,7 @@ def _infer_schema_type(obj, dataType):
 
     elif isinstance(dataType, MapType):
         k, v = obj.iteritems().next()
-        return MapType(_infer_type(k),
+        return MapType(_infer_schema_type(k, dataType.keyType),
                        _infer_schema_type(v, dataType.valueType))
 
     elif isinstance(dataType, StructType):
@@ -1085,7 +1086,7 @@ def _has_struct_or_date(dt):
     elif isinstance(dt, ArrayType):
         return _has_struct_or_date(dt.elementType)
     elif isinstance(dt, MapType):
-        return _has_struct_or_date(dt.valueType)
+        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
     elif isinstance(dt, DateType):
         return True
     elif isinstance(dt, UserDefinedType):
@@ -1148,12 +1149,13 @@ def _create_cls(dataType):
         return List
 
     elif isinstance(dataType, MapType):
-        cls = _create_cls(dataType.valueType)
+        kcls = _create_cls(dataType.keyType)
+        vcls = _create_cls(dataType.valueType)
 
         def Dict(d):
             if d is None:
                 return
-            return dict((k, _create_object(cls, v)) for k, v in d.items())
+            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
 
         return Dict
 
@@ -1164,7 +1166,8 @@ def _create_cls(dataType):
         return lambda datum: dataType.deserialize(datum)
 
     elif not isinstance(dataType, StructType):
-        raise Exception("unexpected data type: %s" % dataType)
+        # no wrapper for primitive types
+        return lambda x: x
 
     class Row(tuple):
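Below is a minimal, standalone sketch of the conversion pattern the patch applies in _create_converter and _create_cls. The names make_map_converter and KeyStruct are illustrative stand-ins, not PySpark internals: each side of a map entry gets its own converter, and struct keys are rebuilt as tuple-based objects so they stay hashable and can serve as dict keys.

    # Sketch only: make_map_converter and KeyStruct are hypothetical
    # stand-ins for the pattern above, not PySpark code.
    from collections import namedtuple

    # A tuple subclass is hashable, so it can be a dict key while still
    # offering field access by name -- the same property that lets
    # PySpark's Row act as a MapType key after this patch.
    KeyStruct = namedtuple("KeyStruct", ["i"])

    def make_map_converter(kconv, vconv):
        """Return a converter that transforms both keys and values of a dict."""
        def convert(d):
            if d is None:
                return None
            return dict((kconv(k), vconv(v)) for k, v in d.items())
        return convert

    # keys arrive as plain tuples; wrap them, and uppercase the values
    conv = make_map_converter(lambda t: KeyStruct(*t), lambda s: s.upper())

    m = conv({(1,): "one", (2,): "two"})
    assert m[KeyStruct(i=1)] == "ONE"

The design point mirrors the diff: the pre-patch code built a single converter for valueType and passed keys through untouched, which broke as soon as the key needed conversion itself.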
http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bca52a7..b474fcf 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -923,6 +923,14 @@ class SQLTests(ReusedPySparkTestCase):
         result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.first()[0])
 
+    def test_struct_in_map(self):
+        d = [Row(m={Row(i=1): Row(s="")})]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd)
+        k, v = srdd.first().m.items()[0]
+        self.assertEqual(1, k.i)
+        self.assertEqual("", v.s)
+
     def test_convert_row_to_dict(self):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)


http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 2b4a88d..5a41399 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -132,7 +132,7 @@ object EvaluatePython {
       arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
 
     case (obj: Map[_, _], mt: MapType) => obj.map {
-      case (k, v) => (k, toJava(v, mt.valueType))  // key should be primitive type
+      case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
    }.asJava
 
     case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
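For reference, a usage sketch of the feature from the Python side, mirroring the new test_struct_in_map test. It assumes a live SparkContext sc and SQLContext sqlCtx; inferSchema is the Spark 1.2-era API exercised in this commit.

    from pyspark.sql import Row

    # a map column whose key is itself a struct; before this patch the
    # key was passed through unconverted and came back as a raw tuple
    d = [Row(m={Row(i=1): Row(s="")})]
    srdd = sqlCtx.inferSchema(sc.parallelize(d))

    k, v = list(srdd.first().m.items())[0]
    assert k.i == 1 and v.s == ""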