Repository: spark
Updated Branches:
  refs/heads/master be7ef0676 -> afae9766f


[SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame

Avoid unnecessary Spark jobs when inferring the schema from a list.

cc yhuai mengxr
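
For context, the jobs being avoided here are the rdd.first() and rdd.take(10)
calls in the old code path, each of which launched a Spark job even when the
input was a plain Python list. A minimal sketch of the affected call, assuming
a local SparkContext (the setup lines are illustrative, not part of the patch):

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[1]", "schema-inference-demo")
    sqlContext = SQLContext(sc)

    # With this patch, the schema of a local list is inferred on the driver
    # via _inferSchemaFromList; no rdd.first()/rdd.take(10) job is launched
    # just to determine or verify the schema.
    df = sqlContext.createDataFrame([Row(name="Alice", age=1),
                                     Row(name="Bob", age=2)])
    df.printSchema()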

Author: Davies Liu <dav...@databricks.com>

Closes #6606 from davies/improve_create and squashes the following commits:

a5928bf [Davies Liu] Update MimaExcludes.scala
62da911 [Davies Liu] fix mima
bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create
eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
8d9292d [Davies Liu] Update context.py
eb24531 [Davies Liu] Update context.py
c969997 [Davies Liu] bug fix
d5a8ab0 [Davies Liu] fix tests
8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
6ea5925 [Davies Liu] address comments
6ceaeff [Davies Liu] avoid spark jobs in createDataFrame


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/afae9766
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/afae9766
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/afae9766

Branch: refs/heads/master
Commit: afae9766f28d2e58297405c39862d20a04267b62
Parents: be7ef06
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Jun 29 13:20:55 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Jun 29 13:20:55 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/context.py | 64 ++++++++++++++++++++++++++++----------
 python/pyspark/sql/types.py   | 48 ++++++++++++++++------------
 2 files changed, 75 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/afae9766/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index dc23922..4dda3b4 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -203,7 +203,37 @@ class SQLContext(object):
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
+    def _inferSchemaFromList(self, data):
+        """
+        Infer schema from list of Row or tuple.
+
+        :param data: list of Row or tuple
+        :return: StructType
+        """
+        if not data:
+            raise ValueError("can not infer schema from empty dataset")
+        first = data[0]
+        if type(first) is dict:
+            warnings.warn("inferring schema from dict is deprecated,"
+                          "please use pyspark.sql.Row instead")
+        schema = _infer_schema(first)
+        if _has_nulltype(schema):
+            for r in data:
+                schema = _merge_type(schema, _infer_schema(r))
+                if not _has_nulltype(schema):
+                    break
+            else:
+                raise ValueError("Some of types cannot be determined after 
inferring")
+        return schema
+
     def _inferSchema(self, rdd, samplingRatio=None):
+        """
+        Infer schema from an RDD of Row or tuple.
+
+        :param rdd: an RDD of Row or tuple
+        :param samplingRatio: sampling ratio, or no sampling (default)
+        :return: StructType
+        """
         first = rdd.first()
         if not first:
             raise ValueError("The first row in RDD is empty, "
@@ -322,6 +352,8 @@ class SQLContext(object):
             data = [r.tolist() for r in data.to_records(index=False)]
 
         if not isinstance(data, RDD):
+            if not isinstance(data, list):
+                data = list(data)
             try:
                 # data could be list, tuple, generator ...
                 rdd = self._sc.parallelize(data)
@@ -330,28 +362,26 @@ class SQLContext(object):
         else:
             rdd = data
 
-        if schema is None:
-            schema = self._inferSchema(rdd, samplingRatio)
+        if schema is None or isinstance(schema, (list, tuple)):
+            if isinstance(data, RDD):
+                struct = self._inferSchema(rdd, samplingRatio)
+            else:
+                struct = self._inferSchemaFromList(data)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+            schema = struct
             converter = _create_converter(schema)
             rdd = rdd.map(converter)
 
-        if isinstance(schema, (list, tuple)):
-            first = rdd.first()
-            if not isinstance(first, (list, tuple)):
-                raise TypeError("each row in `rdd` should be list or tuple, "
-                                "but got %r" % type(first))
-            row_cls = Row(*schema)
-            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-
-        # take the first few rows to verify schema
-        rows = rdd.take(10)
-        # Row() cannot been deserialized by Pyrolite
-        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
-            rdd = rdd.map(tuple)
+        elif isinstance(schema, StructType):
+            # take the first few rows to verify schema
             rows = rdd.take(10)
+            for row in rows:
+                _verify_type(row, schema)
 
-        for row in rows:
-            _verify_type(row, schema)
+        else:
+            raise TypeError("schema should be StructType or list or None")
 
         # convert python objects to sql data
         converter = _python_to_sql_converter(schema)
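
A note on the new inference path in context.py: _inferSchemaFromList scans
past the first element only while some field is still NullType, using the
for/else to fail if any type remains undetermined. A rough standalone
illustration of that loop, reusing the private helpers from pyspark.sql.types
just as the method does (the sample rows are made up):

    from pyspark.sql import Row
    from pyspark.sql.types import _has_nulltype, _infer_schema, _merge_type

    data = [Row(a=None, b=1), Row(a="x", b=2), Row(a="y", b=3)]

    schema = _infer_schema(data[0])        # field a starts out as NullType
    for r in data:
        schema = _merge_type(schema, _infer_schema(r))
        if not _has_nulltype(schema):      # every field resolved: stop early,
            break                          # so the third row is never touched
    else:
        raise ValueError("Some of types cannot be determined after inferring")
    print(schema)                          # field a is resolved to StringType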

http://git-wip-us.apache.org/repos/asf/spark/blob/afae9766/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 23d9adb..932686e 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -635,7 +635,7 @@ def _need_python_to_sql_conversion(dataType):
    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
     >>> _need_python_to_sql_conversion(schema0)
-    False
+    True
     >>> _need_python_to_sql_conversion(ExamplePointUDT())
     True
     >>> schema1 = ArrayType(ExamplePointUDT(), False)
@@ -647,7 +647,8 @@ def _need_python_to_sql_conversion(dataType):
     True
     """
     if isinstance(dataType, StructType):
-        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+        # convert namedtuple or Row into tuple
+        return True
     elif isinstance(dataType, ArrayType):
         return _need_python_to_sql_conversion(dataType.elementType)
     elif isinstance(dataType, MapType):
@@ -688,21 +689,25 @@ def _python_to_sql_converter(dataType):
 
     if isinstance(dataType, StructType):
         names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
-        converters = [_python_to_sql_converter(t) for t in types]
-
-        def converter(obj):
-            if isinstance(obj, dict):
-                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
-            elif isinstance(obj, tuple):
-                if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
-                    return tuple(c(v) for c, v in zip(converters, obj))
-                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
-                    d = dict(obj)
-                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
+        if any(_need_python_to_sql_conversion(t) for t in types):
+            converters = [_python_to_sql_converter(t) for t in types]
+
+            def converter(obj):
+                if isinstance(obj, dict):
+                    return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+                elif isinstance(obj, tuple):
+                    if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
+                        return tuple(c(v) for c, v in zip(converters, obj))
+                    else:
+                        return tuple(c(v) for c, v in zip(converters, obj))
+                elif obj is not None:
+                    raise ValueError("Unexpected tuple %r with type %r" % 
(obj, dataType))
+        else:
+            def converter(obj):
+                if isinstance(obj, dict):
+                    return tuple(obj.get(n) for n in names)
                 else:
-                    return tuple(c(v) for c, v in zip(converters, obj))
-            elif obj is not None:
-                raise ValueError("Unexpected tuple %r with type %r" % (obj, 
dataType))
+                    return tuple(obj)
         return converter
     elif isinstance(dataType, ArrayType):
         element_converter = _python_to_sql_converter(dataType.elementType)
@@ -1027,10 +1032,13 @@ def _verify_type(obj, dataType):
     _type = type(dataType)
     assert _type in _acceptable_types, "unknown datatype: %s" % dataType
 
-    # subclass of them can not be deserialized in JVM
-    if type(obj) not in _acceptable_types[_type]:
-        raise TypeError("%s can not accept object in type %s"
-                        % (dataType, type(obj)))
+    if _type is StructType:
+        if not isinstance(obj, (tuple, list)):
+            raise TypeError("StructType can not accept object in type %s" % 
type(obj))
+    else:
+        # subclass of them can not be deserialized in JVM
+        if type(obj) not in _acceptable_types[_type]:
+            raise TypeError("%s can not accept object in type %s" % (dataType, 
type(obj)))
 
     if isinstance(dataType, ArrayType):
         for i in obj:
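
Two behavioral points in the types.py half of the diff are worth spelling
out. First, _need_python_to_sql_conversion now always returns True for a
StructType, because Row and namedtuple values must be flattened into plain
tuples that Pyrolite can serialize (hence the doctest flip from False to
True); note that with the k-v pairs special case removed, the two tuple
branches of the new converter body are in fact identical. Second,
_verify_type now accepts any tuple or list subclass for a StructType, where
the old exact-type check rejected Row and namedtuples. A hedged illustration
of the latter, calling the private helper directly (the namedtuple is made
up for the example):

    from collections import namedtuple
    from pyspark.sql import Row
    from pyspark.sql.types import (LongType, StringType, StructField,
                                   StructType, _verify_type)

    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", LongType(), True)])
    Person = namedtuple("Person", ["name", "age"])
    PersonRow = Row("name", "age")          # Row class with fixed field order

    _verify_type(("Alice", 1), schema)      # plain tuple: fine before and after
    _verify_type(Person("Bob", 2), schema)  # tuple subclass: passes only after
                                            # this patch (exact-type check before)
    _verify_type(PersonRow("Carol", 3), schema)  # Row: likewise accepted now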


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
