Repository: spark
Updated Branches:
  refs/heads/branch-2.0 0a2291cd1 -> e11c27918


[SPARK-15981][SQL][STREAMING] Fixed bug and added tests in DataStreamReader Python API

## What changes were proposed in this pull request?

- Fixed a bug in the Python API of DataStreamReader. Because a single path was being
converted to an array before calling the Java DataStreamReader method (which accepts
only a string), the following error was raised:
```
File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 947, in pyspark.sql.readwriter.DataStreamReader.json
Failed example:
    json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'),  
               schema = sdf_schema)
Exception raised:
    Traceback (most recent call last):
      File "/System/Library/Frameworks/Python.framework/Versions/2.6/lib/python2.6/doctest.py", line 1253, in __run
        compileflags, 1) in test.globs
      File "<doctest pyspark.sql.readwriter.DataStreamReader.json[0]>", line 1, in <module>
        json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'),                 schema = sdf_schema)
      File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/readwriter.py", line 963, in json
        return self._df(self._jreader.json(path))
      File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__
        answer, self.gateway_client, self.target_id, self.name)
      File "/Users/tdas/Projects/Spark/spark/python/pyspark/sql/utils.py", line 63, in deco
        return f(*a, **kw)
      File "/Users/tdas/Projects/Spark/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 316, in get_return_value
        format(target_id, ".", name, value))
    Py4JError: An error occurred while calling o121.json. Trace:
    py4j.Py4JException: Method json([class java.util.ArrayList]) does not exist
        at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
        at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
        at py4j.Gateway.invoke(Gateway.java:272)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:211)
        at java.lang.Thread.run(Thread.java:744)
```

- Reduced code duplication between DataStreamReader and DataFrameReader by moving the shared JSON/CSV option handling into a common ReaderUtils base class
- Added missing Python doctests

## How was this patch tested?
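New doctests for the DataStreamReader Python API (format, schema, option, options, load, json, parquet, text, csv).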

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #13703 from tdas/SPARK-15981.

(cherry picked from commit 084dca770f5c26f906e7555707c7894cf05fb86b)
Signed-off-by: Shixiong Zhu <shixi...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e11c2791
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e11c2791
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e11c2791

Branch: refs/heads/branch-2.0
Commit: e11c279188b34d410f6ecf17cb1773c95f24a19e
Parents: 0a2291c
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Thu Jun 16 13:17:41 2016 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Thu Jun 16 13:17:50 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/readwriter.py | 258 ++++++++++++++++++----------------
 1 file changed, 136 insertions(+), 122 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e11c2791/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index c982de6..72fd184 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -44,7 +44,82 @@ def to_str(value):
         return str(value)
 
 
-class DataFrameReader(object):
+class ReaderUtils(object):
+
+    def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
+                       allowComments, allowUnquotedFieldNames, 
allowSingleQuotes,
+                       allowNumericLeadingZero, 
allowBackslashEscapingAnyCharacter,
+                       mode, columnNameOfCorruptRecord):
+        """
+        Set options based on the Json optional parameters
+        """
+        if schema is not None:
+            self.schema(schema)
+        if primitivesAsString is not None:
+            self.option("primitivesAsString", primitivesAsString)
+        if prefersDecimal is not None:
+            self.option("prefersDecimal", prefersDecimal)
+        if allowComments is not None:
+            self.option("allowComments", allowComments)
+        if allowUnquotedFieldNames is not None:
+            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
+        if allowSingleQuotes is not None:
+            self.option("allowSingleQuotes", allowSingleQuotes)
+        if allowNumericLeadingZero is not None:
+            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
+        if allowBackslashEscapingAnyCharacter is not None:
+            self.option("allowBackslashEscapingAnyCharacter", 
allowBackslashEscapingAnyCharacter)
+        if mode is not None:
+            self.option("mode", mode)
+        if columnNameOfCorruptRecord is not None:
+            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+
+    def _set_csv_opts(self, schema, sep, encoding, quote, escape,
+                      comment, header, inferSchema, ignoreLeadingWhiteSpace,
+                      ignoreTrailingWhiteSpace, nullValue, nanValue, 
positiveInf, negativeInf,
+                      dateFormat, maxColumns, maxCharsPerColumn, mode):
+        """
+        Set options based on the CSV optional parameters
+        """
+        if schema is not None:
+            self.schema(schema)
+        if sep is not None:
+            self.option("sep", sep)
+        if encoding is not None:
+            self.option("encoding", encoding)
+        if quote is not None:
+            self.option("quote", quote)
+        if escape is not None:
+            self.option("escape", escape)
+        if comment is not None:
+            self.option("comment", comment)
+        if header is not None:
+            self.option("header", header)
+        if inferSchema is not None:
+            self.option("inferSchema", inferSchema)
+        if ignoreLeadingWhiteSpace is not None:
+            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
+        if ignoreTrailingWhiteSpace is not None:
+            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
+        if nullValue is not None:
+            self.option("nullValue", nullValue)
+        if nanValue is not None:
+            self.option("nanValue", nanValue)
+        if positiveInf is not None:
+            self.option("positiveInf", positiveInf)
+        if negativeInf is not None:
+            self.option("negativeInf", negativeInf)
+        if dateFormat is not None:
+            self.option("dateFormat", dateFormat)
+        if maxColumns is not None:
+            self.option("maxColumns", maxColumns)
+        if maxCharsPerColumn is not None:
+            self.option("maxCharsPerColumn", maxCharsPerColumn)
+        if mode is not None:
+            self.option("mode", mode)
+
+
+class DataFrameReader(ReaderUtils):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
@@ -193,26 +268,10 @@ class DataFrameReader(object):
         [('age', 'bigint'), ('name', 'string')]
 
         """
-        if schema is not None:
-            self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", 
allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+                            allowComments, allowUnquotedFieldNames, 
allowSingleQuotes,
+                            allowNumericLeadingZero, 
allowBackslashEscapingAnyCharacter,
+                            mode, columnNameOfCorruptRecord)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
@@ -345,42 +404,11 @@ class DataFrameReader(object):
         >>> df.dtypes
         [('_c0', 'string'), ('_c1', 'string')]
         """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if mode is not None:
-            self.option("mode", mode)
+
+        self._set_csv_opts(schema, sep, encoding, quote, escape,
+                           comment, header, inferSchema, 
ignoreLeadingWhiteSpace,
+                           ignoreTrailingWhiteSpace, nullValue, nanValue, 
positiveInf, negativeInf,
+                           dateFormat, maxColumns, maxCharsPerColumn, mode)
         if isinstance(path, basestring):
             path = [path]
         return 
self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@@ -764,7 +792,7 @@ class DataFrameWriter(object):
         self._jwrite.mode(mode).jdbc(url, table, jprop)
 
 
-class DataStreamReader(object):
+class DataStreamReader(ReaderUtils):
     """
     Interface used to load a streaming :class:`DataFrame` from external 
storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
@@ -791,6 +819,7 @@ class DataStreamReader(object):
 
         :param source: string, name of the data source, e.g. 'json', 'parquet'.
 
+        >>> s = spark.readStream.format("text")
         """
         self._jreader = self._jreader.format(source)
         return self
@@ -806,6 +835,8 @@ class DataStreamReader(object):
         .. note:: Experimental.
 
         :param schema: a StructType object
+
+        >>> s = spark.readStream.schema(sdf_schema)
         """
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType")
@@ -818,6 +849,8 @@ class DataStreamReader(object):
         """Adds an input option for the underlying data source.
 
         .. note:: Experimental.
+
+        >>> s = spark.readStream.option("x", 1)
         """
         self._jreader = self._jreader.option(key, to_str(value))
         return self
@@ -827,6 +860,8 @@ class DataStreamReader(object):
         """Adds input options for the underlying data source.
 
         .. note:: Experimental.
+
+        >>> s = spark.readStream.options(x="1", y=2)
         """
         for k in options:
             self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -843,6 +878,13 @@ class DataStreamReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
 
+        >>> json_sdf = spark.readStream.format("json")\
+                                       .schema(sdf_schema)\
+                                       
.load(os.path.join(tempfile.mkdtemp(),'data'))
+        >>> json_sdf.isStreaming
+        True
+        >>> json_sdf.schema == sdf_schema
+        True
         """
         if format is not None:
             self.format(format)
@@ -905,29 +947,18 @@ class DataStreamReader(object):
                                           it uses the value specified in
                                           
``spark.sql.columnNameOfCorruptRecord``.
 
+        >>> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 
'data'), \
+                schema = sdf_schema)
+        >>> json_sdf.isStreaming
+        True
+        >>> json_sdf.schema == sdf_schema
+        True
         """
-        if schema is not None:
-            self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", 
allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        self._set_json_opts(schema, primitivesAsString, prefersDecimal,
+                            allowComments, allowUnquotedFieldNames, 
allowSingleQuotes,
+                            allowNumericLeadingZero, 
allowBackslashEscapingAnyCharacter,
+                            mode, columnNameOfCorruptRecord)
         if isinstance(path, basestring):
-            path = [path]
             return self._df(self._jreader.json(path))
         else:
             raise TypeError("path can be only a single string")
@@ -943,10 +974,15 @@ class DataStreamReader(object):
 
         .. note:: Experimental.
 
+        >>> parquet_sdf = spark.readStream.schema(sdf_schema)\
+                .parquet(os.path.join(tempfile.mkdtemp()))
+        >>> parquet_sdf.isStreaming
+        True
+        >>> parquet_sdf.schema == sdf_schema
+        True
         """
         if isinstance(path, basestring):
-            path = [path]
-            return 
self._df(self._jreader.parquet(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.parquet(path))
         else:
             raise TypeError("path can be only a single string")
 
@@ -964,10 +1000,14 @@ class DataStreamReader(object):
 
         :param paths: string, or list of strings, for input path(s).
 
+        >>> text_sdf = spark.readStream.text(os.path.join(tempfile.mkdtemp(), 
'data'))
+        >>> text_sdf.isStreaming
+        True
+        >>> "value" in str(text_sdf.schema)
+        True
         """
         if isinstance(path, basestring):
-            path = [path]
-            return 
self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.text(path))
         else:
             raise TypeError("path can be only a single string")
 
@@ -1034,46 +1074,20 @@ class DataStreamReader(object):
                 * ``DROPMALFORMED`` : ignores the whole corrupted records.
                 * ``FAILFAST`` : throws an exception when it meets corrupted 
records.
 
+        >>> csv_sdf = spark.readStream.csv(os.path.join(tempfile.mkdtemp(), 
'data'), \
+                schema = sdf_schema)
+        >>> csv_sdf.isStreaming
+        True
+        >>> csv_sdf.schema == sdf_schema
+        True
         """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if mode is not None:
-            self.option("mode", mode)
+
+        self._set_csv_opts(schema, sep, encoding, quote, escape,
+                           comment, header, inferSchema, 
ignoreLeadingWhiteSpace,
+                           ignoreTrailingWhiteSpace, nullValue, nanValue, 
positiveInf, negativeInf,
+                           dateFormat, maxColumns, maxCharsPerColumn, mode)
         if isinstance(path, basestring):
-            path = [path]
-            return 
self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.csv(path))
         else:
             raise TypeError("path can be only a single string")
 
@@ -1286,7 +1300,7 @@ def _test():
     globs['df'] = 
spark.read.parquet('python/test_support/sql/parquet_partitioned')
     globs['sdf'] = \
         
spark.readStream.format('text').load('python/test_support/sql/streaming')
-
+    globs['sdf_schema'] = StructType([StructField("data", StringType(), 
False)])
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.readwriter, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | 
doctest.REPORT_NDIFF)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to