Repository: spark
Updated Branches:
  refs/heads/master a1a64e358 -> 6a0b77a55


[SPARK-24215][PYSPARK][FOLLOW UP] Implement eager evaluation for DataFrame APIs in PySpark

## What changes were proposed in this pull request?

Address comments in #21370 and add more tests.
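
Below is a minimal sketch of the behavior the new tests exercise, assuming a running PySpark shell with the usual `spark` session (illustration only, not part of the patch):

```python
# Enable eager evaluation so a DataFrame is rendered without an explicit .show().
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")

df = spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

# In a plain Python REPL, repr(df) now prints the rows like df.show().
print(repr(df))

# In notebooks such as Jupyter, _repr_html_ returns an HTML table instead.
print(df._repr_html_())
```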

## How was this patch tested?

Enhanced tests in python/pyspark/sql/tests.py and DataFrameSuite.

Author: Yuanjian Li <xyliyuanj...@gmail.com>

Closes #21553 from xuanyuanking/SPARK-24215-follow.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a0b77a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a0b77a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a0b77a5

Branch: refs/heads/master
Commit: 6a0b77a55d53e74ac0a0892556c3a7a933474948
Parents: a1a64e3
Author: Yuanjian Li <xyliyuanj...@gmail.com>
Authored: Wed Jun 27 10:43:06 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Wed Jun 27 10:43:06 2018 -0700

----------------------------------------------------------------------
 docs/configuration.md                           | 27 ---------
 python/pyspark/sql/dataframe.py                 |  3 +-
 python/pyspark/sql/tests.py                     | 46 ++++++++++++++-
 .../org/apache/spark/sql/internal/SQLConf.scala | 23 ++++++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 11 ++--
 .../org/apache/spark/sql/DataFrameSuite.scala   | 59 ++++++++++++++++++++
 6 files changed, 131 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 6aa7878..0c7c447 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -457,33 +457,6 @@ Apart from these, the following properties are also available, and may be useful
   </td>
 </tr>
 <tr>
-  <td><code>spark.sql.repl.eagerEval.enabled</code></td>
-  <td>false</td>
-  <td>
-    Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation,
-    Dataset will be ran automatically. The HTML table which generated by <code>_repl_html_</code>
-    called by notebooks like Jupyter will feedback the queries user have defined. For plain Python
-    REPL, the output will be shown like <code>dataframe.show()</code>
-    (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
-  </td>
-</tr>
-<tr>
-  <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
-  <td>20</td>
-  <td>
-    Default number of rows in eager evaluation output HTML table generated by <code>_repr_html_</code> or plain text,
-    this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
-  </td>
-</tr>
-<tr>
-  <td><code>spark.sql.repl.eagerEval.truncate</code></td>
-  <td>20</td>
-  <td>
-    Default number of truncate in eager evaluation output HTML table generated by <code>_repr_html_</code> or
-    plain text, this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> set to true.
-  </td>
-</tr>
-<tr>
   <td><code>spark.files</code></td>
   <td></td>
   <td>

http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1e6a1ac..cb3fe44 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -393,9 +393,8 @@ class DataFrame(object):
             self._support_repr_html = True
         if self._eager_eval:
             max_num_rows = max(self._max_num_rows, 0)
-            vertical = False
             sock_info = self._jdf.getRowsToPython(
-                max_num_rows, self._truncate, vertical)
+                max_num_rows, self._truncate)
             rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
             head = rows[0]
             row_data = rows[1:]
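
For context, a rough, purely illustrative sketch (hypothetical helper, not the actual dataframe.py code) of how the header row and data rows fetched above could be assembled into the HTML table that `_repr_html_` hands to notebooks:

```python
from html import escape

def rows_to_html_table(head, row_data, has_more_data=False, max_num_rows=20):
    # head: list of column names; row_data: rows already stringified on the JVM side.
    out = "<table border='1'>\n"
    out += "<tr><th>%s</th></tr>\n" % "</th><th>".join(escape(c) for c in head)
    for row in row_data:
        out += "<tr><td>%s</td></tr>\n" % "</td><td>".join(escape(c) for c in row)
    out += "</table>\n"
    if has_more_data:
        out += "only showing top %d %s\n" % (max_num_rows, "row" if max_num_rows == 1 else "rows")
    return out

print(rows_to_html_table(["key", "value"], [["1", "1"], ["22222", "22222"]]))
```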

http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 35a0636..8d73806 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3351,11 +3351,41 @@ class SQLTests(ReusedSQLTestCase):
         finally:
             shutil.rmtree(path)
 
-    def test_repr_html(self):
+    def test_repr_behaviors(self):
         import re
         pattern = re.compile(r'^ *\|', re.MULTILINE)
         df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
-        self.assertEquals(None, df._repr_html_())
+
+        # test when eager evaluation is enabled and _repr_html_ will not be called
+        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+            expected1 = """+-----+-----+
+                ||  key|value|
+                |+-----+-----+
+                ||    1|    1|
+                ||22222|22222|
+                |+-----+-----+
+                |"""
+            self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
+            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+                expected2 = """+---+-----+
+                ||key|value|
+                |+---+-----+
+                ||  1|    1|
+                ||222|  222|
+                |+---+-----+
+                |"""
+                self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
+                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+                    expected3 = """+---+-----+
+                    ||key|value|
+                    |+---+-----+
+                    ||  1|    1|
+                    |+---+-----+
+                    |only showing top 1 row
+                    |"""
+                    self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())
+
+        # test when eager evaluation is enabled and _repr_html_ will be called
         with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
             expected1 = """<table border='1'>
                 |<tr><th>key</th><th>value</th></tr>
@@ -3381,6 +3411,18 @@ class SQLTests(ReusedSQLTestCase):
                         |"""
                     self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
 
+        # test when eager evaluation is disabled and _repr_html_ will be called
+        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
+            expected = "DataFrame[key: bigint, value: string]"
+            self.assertEquals(None, df._repr_html_())
+            self.assertEquals(expected, df.__repr__())
+            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+                self.assertEquals(None, df._repr_html_())
+                self.assertEquals(expected, df.__repr__())
+                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+                    self.assertEquals(None, df._repr_html_())
+                    self.assertEquals(expected, df.__repr__())
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 239c826..e1752ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1330,6 +1330,29 @@ object SQLConf {
       "The size function returns null for null input if the flag is disabled.")
     .booleanConf
     .createWithDefault(true)
+
+  val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled")
+    .doc("Enables eager evaluation or not. When true, the top K rows of 
Dataset will be " +
+      "displayed if and only if the REPL supports the eager evaluation. 
Currently, the " +
+      "eager evaluation is only supported in PySpark. For the notebooks like 
Jupyter, " +
+      "the HTML table (generated by _repr_html_) will be returned. For plain 
Python REPL, " +
+      "the returned outputs are formatted like dataframe.show().")
+    .booleanConf
+    .createWithDefault(false)
+
+  val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows")
+    .doc("The max number of rows that are returned by eager evaluation. This only takes " +
+      "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " +
+      "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " +
+      "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).")
+    .intConf
+    .createWithDefault(20)
+
+  val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate")
+    .doc("The max number of characters for each cell that is returned by eager 
evaluation. " +
+      "This only takes effect when spark.sql.repl.eagerEval.enabled is set to 
true.")
+    .intConf
+    .createWithDefault(20)
 }
 
 /**
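
As a usage note, a small sketch (assuming a live `spark` session; not part of the patch) of tuning the output-size configs added above from PySpark:

```python
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", "1")  # show at most one data row
spark.conf.set("spark.sql.repl.eagerEval.truncate", "3")    # cut each cell to 3 characters

df = spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
# repr(df) now renders a single truncated row followed by "only showing top 1 row".
print(repr(df))

# Values outside [0, Int.MaxValue - 1] are normalized into that range,
# e.g. a negative maxNumRows behaves like 0.
```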

http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 57f1e17..2ec236f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -236,12 +236,10 @@ class Dataset[T] private[sql](
    * @param numRows Number of rows to return
   * @param truncate If set to more than 0, truncates strings to `truncate` characters and
    *                   all cells will be aligned right.
-   * @param vertical If set to true, the rows to return do not need truncate.
    */
   private[sql] def getRows(
       numRows: Int,
-      truncate: Int,
-      vertical: Boolean): Seq[Seq[String]] = {
+      truncate: Int): Seq[Seq[String]] = {
     val newDf = toDF()
     val castCols = newDf.logicalPlan.output.map { col =>
       // Since binary types in top-level schema fields have a specific format to print,
@@ -289,7 +287,7 @@ class Dataset[T] private[sql](
       vertical: Boolean = false): String = {
     val numRows = _numRows.max(0).min(Int.MaxValue - 1)
     // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data.
-    val tmpRows = getRows(numRows, truncate, vertical)
+    val tmpRows = getRows(numRows, truncate)
 
     val hasMoreData = tmpRows.length - 1 > numRows
     val rows = tmpRows.take(numRows + 1)
@@ -3226,11 +3224,10 @@ class Dataset[T] private[sql](
 
   private[sql] def getRowsToPython(
       _numRows: Int,
-      truncate: Int,
-      vertical: Boolean): Array[Any] = {
+      truncate: Int): Array[Any] = {
     EvaluatePython.registerPicklers()
     val numRows = _numRows.max(0).min(Int.MaxValue - 1)
-    val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+    val rows = getRows(numRows, truncate).map(_.toArray).toArray
     val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
     val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
       rows.iterator.map(toJava))

http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1cc8cb3..ea00d22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     testData.select($"*").show(1000)
   }
 
+  test("getRows: truncate = [0, 20]") {
+    val longString = Array.fill(21)("1").mkString
+    val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+    val expectedAnswerForFalse = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("111111111111111111111"))
+    assert(df.getRows(10, 0) === expectedAnswerForFalse)
+    val expectedAnswerForTrue = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("11111111111111111..."))
+    assert(df.getRows(10, 20) === expectedAnswerForTrue)
+  }
+
+  test("getRows: truncate = [3, 17]") {
+    val longString = Array.fill(21)("1").mkString
+    val df = sparkContext.parallelize(Seq("1", longString)).toDF()
+    val expectedAnswerForFalse = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("111"))
+    assert(df.getRows(10, 3) === expectedAnswerForFalse)
+    val expectedAnswerForTrue = Seq(
+      Seq("value"),
+      Seq("1"),
+      Seq("11111111111111..."))
+    assert(df.getRows(10, 17) === expectedAnswerForTrue)
+  }
+
+  test("getRows: numRows = 0") {
+    val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1"))
+    assert(testData.select($"*").getRows(0, 20) === expectedAnswer)
+  }
+
+  test("getRows: array") {
+    val df = Seq(
+      (Array(1, 2, 3), Array(1, 2, 3)),
+      (Array(2, 3, 4), Array(2, 3, 4))
+    ).toDF()
+    val expectedAnswer = Seq(
+      Seq("_1", "_2"),
+      Seq("[1, 2, 3]", "[1, 2, 3]"),
+      Seq("[2, 3, 4]", "[2, 3, 4]"))
+    assert(df.getRows(10, 20) === expectedAnswer)
+  }
+
+  test("getRows: binary") {
+    val df = Seq(
+      ("12".getBytes(StandardCharsets.UTF_8), 
"ABC.".getBytes(StandardCharsets.UTF_8)),
+      ("34".getBytes(StandardCharsets.UTF_8), 
"12346".getBytes(StandardCharsets.UTF_8))
+    ).toDF()
+    val expectedAnswer = Seq(
+      Seq("_1", "_2"),
+      Seq("[31 32]", "[41 42 43 2E]"),
+      Seq("[33 34]", "[31 32 33 34 36]"))
+    assert(df.getRows(10, 20) === expectedAnswer)
+  }
+
   test("showString: truncate = [0, 20]") {
     val longString = Array.fill(21)("1").mkString
     val df = sparkContext.parallelize(Seq("1", longString)).toDF()

