Repository: spark
Updated Branches:
  refs/heads/master 0f80990bf -> 445647a1a


[SPARK-8021] [SQL] [PYSPARK] make Python read/write API consistent with Scala

Add schema()/format()/options() to DataFrameReader, and
mode()/format()/options()/partitionBy() to DataFrameWriter.

cc rxin yhuai pwendell

Author: Davies Liu <dav...@databricks.com>

Closes #6578 from davies/readwrite and squashes the following commits:

720d293 [Davies Liu] address comments
b65dfa2 [Davies Liu] Update readwriter.py
1299ab6 [Davies Liu] make Python API consistent with Scala
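
For context, a minimal usage sketch of the builder-style chaining these
methods enable (the paths, schema, and option values below are illustrative,
not taken from the patch):

    from pyspark.sql.types import StructType, StructField, StringType

    # Reader: format()/schema()/options() each return self, so they chain.
    schema = StructType([StructField("name", StringType(), True)])
    df = sqlContext.read.format("json").schema(schema) \
        .options(samplingRatio="0.5").load("/tmp/people.json")

    # Writer: mode()/format()/options()/partitionBy() chain the same way.
    df.write.format("parquet").mode("overwrite") \
        .partitionBy("name").save("/tmp/people.parquet")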


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/445647a1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/445647a1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/445647a1

Branch: refs/heads/master
Commit: 445647a1a36e1e24076a9fe506492fac462c66ad
Parents: 0f80990
Author: Davies Liu <dav...@databricks.com>
Authored: Tue Jun 2 08:37:18 2015 -0700
Committer: Patrick Wendell <patr...@databricks.com>
Committed: Tue Jun 2 08:37:18 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/readwriter.py | 121 ++++++++++++++++++++++++++--------
 1 file changed, 94 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/445647a1/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b6fd413..d17d874 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -44,6 +44,39 @@ class DataFrameReader(object):
         return DataFrame(jdf, self._sqlContext)
 
     @since(1.4)
+    def format(self, source):
+        """
+        Specifies the input data source format.
+        """
+        self._jreader = self._jreader.format(source)
+        return self
+
+    @since(1.4)
+    def schema(self, schema):
+        """
+        Specifies the input schema. Some data sources (e.g. JSON) can
+        infer the input schema automatically from data. By specifying
+        the schema here, the underlying data source can skip the schema
+        inference step, and thus speed up data loading.
+
+        :param schema: a StructType object
+        """
+        if not isinstance(schema, StructType):
+            raise TypeError("schema should be StructType")
+        jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+        self._jreader = self._jreader.schema(jschema)
+        return self
+
+    @since(1.4)
+    def options(self, **options):
+        """
+        Adds input options for the underlying data source.
+        """
+        for k in options:
+            self._jreader = self._jreader.option(k, options[k])
+        return self
+
+    @since(1.4)
     def load(self, path=None, format=None, schema=None, **options):
         """Loads data from a data source and returns it as a :class`DataFrame`.
 
@@ -52,20 +85,15 @@ class DataFrameReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
         """
-        jreader = self._jreader
         if format is not None:
-            jreader = jreader.format(format)
+            self.format(format)
         if schema is not None:
-            if not isinstance(schema, StructType):
-                raise TypeError("schema should be StructType")
-            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
-            jreader = jreader.schema(jschema)
-        for k in options:
-            jreader = jreader.option(k, options[k])
+            self.schema(schema)
+        self.options(**options)
         if path is not None:
-            return self._df(jreader.load(path))
+            return self._df(self._jreader.load(path))
         else:
-            return self._df(jreader.load())
+            return self._df(self._jreader.load())
 
     @since(1.4)
     def json(self, path, schema=None):
@@ -105,12 +133,9 @@ class DataFrameReader(object):
          |    |-- field5: array (nullable = true)
          |    |    |-- element: integer (containsNull = true)
         """
-        if schema is None:
-            jdf = self._jreader.json(path)
-        else:
-            jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
-            jdf = self._jreader.schema(jschema).json(path)
-        return self._df(jdf)
+        if schema is not None:
+            self.schema(schema)
+        return self._df(self._jreader.json(path))
 
     @since(1.4)
     def table(self, tableName):
@@ -195,6 +220,51 @@ class DataFrameWriter(object):
         self._jwrite = df._jdf.write()
 
     @since(1.4)
+    def mode(self, saveMode):
+        """
+        Specifies the behavior when data or table already exists. Options include:
+
+        * `append`: Append contents of this :class:`DataFrame` to existing data.
+        * `overwrite`: Overwrite existing data.
+        * `error`: Throw an exception if data already exists.
+        * `ignore`: Silently ignore this operation if data already exists.
+        """
+        self._jwrite = self._jwrite.mode(saveMode)
+        return self
+
+    @since(1.4)
+    def format(self, source):
+        """
+        Specifies the underlying output data source. Built-in options include
+        "parquet", "json", etc.
+        """
+        self._jwrite = self._jwrite.format(source)
+        return self
+
+    @since(1.4)
+    def options(self, **options):
+        """
+        Adds output options for the underlying data source.
+        """
+        for k in options:
+            self._jwrite = self._jwrite.option(k, options[k])
+        return self
+
+    @since(1.4)
+    def partitionBy(self, *cols):
+        """
+        Partitions the output by the given columns on the file system.
+        If specified, the output is laid out on the file system similar
+        to Hive's partitioning scheme.
+
+        :param cols: names of the columns
+        """
+        if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+            cols = cols[0]
+        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        return self
+
+    @since(1.4)
     def save(self, path=None, format=None, mode="error", **options):
         """
         Saves the contents of the :class:`DataFrame` to a data source.
@@ -216,16 +286,15 @@ class DataFrameWriter(object):
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
         :param options: all other string options
         """
-        jwrite = self._jwrite.mode(mode)
+        self.mode(mode).options(**options)
         if format is not None:
-            jwrite = jwrite.format(format)
-        for k in options:
-            jwrite = jwrite.option(k, options[k])
+            self.format(format)
         if path is None:
-            jwrite.save()
+            self._jwrite.save()
         else:
-            jwrite.save(path)
+            self._jwrite.save(path)
 
+    @since(1.4)
     def insertInto(self, tableName, overwrite=False):
         """
         Inserts the content of the :class:`DataFrame` to the specified table.
@@ -256,12 +325,10 @@ class DataFrameWriter(object):
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
         :param options: all other string options
         """
-        jwrite = self._jwrite.mode(mode)
+        self.mode(mode).options(**options)
         if format is not None:
-            jwrite = jwrite.format(format)
-        for k in options:
-            jwrite = jwrite.option(k, options[k])
-        return jwrite.saveAsTable(name)
+            self.format(format)
+        return self._jwrite.saveAsTable(name)
 
     @since(1.4)
     def json(self, path, mode="error"):

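Worth noting: after this refactoring, load()/save() simply delegate to the
new builder methods, so the keyword form and the chained form are equivalent
(the path and the option key below are illustrative, not from the patch):

    # These two writes do the same thing:
    df.write.save("/tmp/out", format="json", mode="overwrite", dateFormat="yyyy-MM-dd")
    df.write.mode("overwrite").options(dateFormat="yyyy-MM-dd") \
        .format("json").save("/tmp/out")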
