Repository: spark
Updated Branches:
  refs/heads/master ff0501b0c -> dbb4d8382


[SPARK-24215][PYSPARK] Implement _repr_html_ for dataframes in PySpark

## What changes were proposed in this pull request?

Implement `_repr_html_` for PySpark DataFrames when running in a notebook, and add a config named
`spark.sql.repl.eagerEval.enabled` to control it.

The dev list thread for context: 
http://apache-spark-developers-list.1001551.n3.nabble.com/eager-execution-and-debuggability-td23928.html
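
As a rough usage sketch of the new behavior (the session construction and the DataFrame below are illustrative, not part of this patch):

```python
from pyspark.sql import SparkSession

# Illustrative setup; any session works once the new config is enabled.
spark = (SparkSession.builder
         .config("spark.sql.repl.eagerEval.enabled", "true")
         .config("spark.sql.repl.eagerEval.maxNumRows", "20")
         .getOrCreate())

df = spark.range(3).toDF("id")

# In Jupyter, evaluating `df` as the last expression of a cell now renders an
# HTML table via _repr_html_; in a plain Python REPL, repr(df) falls back to
# the same text output as df.show().
df
```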

## How was this patch tested?

New unit tests in DataFrameSuite and manual testing in Jupyter. Screenshots below.

**After:**
![image](https://user-images.githubusercontent.com/4833765/40268422-8db5bef0-5b9f-11e8-80f1-04bc654a4f2c.png)

**Before:**
![image](https://user-images.githubusercontent.com/4833765/40268431-9f92c1b8-5b9f-11e8-9db9-0611f0940b26.png)

Author: Yuanjian Li <xyliyuanj...@gmail.com>

Closes #21370 from xuanyuanking/SPARK-24215.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dbb4d838
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dbb4d838
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dbb4d838

Branch: refs/heads/master
Commit: dbb4d83829ec4b51d6e6d3a96f7a4e611d8827bc
Parents: ff0501b
Author: Yuanjian Li <xyliyuanj...@gmail.com>
Authored: Tue Jun 5 08:23:08 2018 +0700
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Tue Jun 5 08:23:08 2018 +0700

----------------------------------------------------------------------
 docs/configuration.md                           | 27 +++++++
 python/pyspark/sql/dataframe.py                 | 65 ++++++++++++++-
 python/pyspark/sql/tests.py                     | 30 +++++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 84 +++++++++++++-------
 4 files changed, 176 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dbb4d838/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 64af0e9..5588c37 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -457,6 +457,33 @@ Apart from these, the following properties are also available, and may be useful
   </td>
 </tr>
 <tr>
+  <td><code>spark.sql.repl.eagerEval.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enables eager evaluation or not. If true and the REPL you are using supports eager evaluation,
+    the Dataset will be evaluated automatically. The HTML table generated by <code>_repr_html_</code>,
+    which notebooks such as Jupyter invoke, shows the result of the queries the user has defined.
+    For a plain Python REPL, the output is shown like <code>dataframe.show()</code>
+    (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
+  <td>20</td>
+  <td>
+    The max number of rows shown in the eager evaluation output (the HTML table generated by
+    <code>_repr_html_</code> or the plain text). This only takes effect when
+    <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.truncate</code></td>
+  <td>20</td>
+  <td>
+    The max number of characters each cell is truncated to in the eager evaluation output
+    (the HTML table generated by <code>_repr_html_</code> or the plain text). This only takes
+    effect when <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
+<tr>
   <td><code>spark.files</code></td>
   <td></td>
   <td>

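For reference, a minimal runtime sketch of the truncate behavior described above (it assumes an existing SparkSession named `spark`; the data mirrors the new PySpark test further below):

```python
# Hedged sketch: toggling eager evaluation at runtime and truncating cells.
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
spark.conf.set("spark.sql.repl.eagerEval.truncate", "3")  # keep at most 3 characters per cell

df = spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

# The HTML table returned below contains <td>222</td> cells for the long values,
# since each cell is cut down to 3 characters.
print(df._repr_html_())
```
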
http://git-wip-us.apache.org/repos/asf/spark/blob/dbb4d838/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 808235a..1e6a1ac 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -78,6 +78,9 @@ class DataFrame(object):
         self.is_cached = False
         self._schema = None  # initialized lazily
         self._lazy_rdd = None
+        # Check whether _repr_html_ is supported or not; we use it to avoid calling
+        # _jdf twice by __repr__ and _repr_html_ when eager evaluation is enabled.
+        self._support_repr_html = False
 
     @property
     @since(1.3)
@@ -351,8 +354,68 @@ class DataFrame(object):
         else:
             print(self._jdf.showString(n, int(truncate), vertical))
 
+    @property
+    def _eager_eval(self):
+        """Returns true if the eager evaluation enabled.
+        """
+        return self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.enabled", "false").lower() == "true"
+
+    @property
+    def _max_num_rows(self):
+        """Returns the max row number for eager evaluation.
+        """
+        return int(self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.maxNumRows", "20"))
+
+    @property
+    def _truncate(self):
+        """Returns the truncate length for eager evaluation.
+        """
+        return int(self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.truncate", "20"))
+
     def __repr__(self):
-        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+        if not self._support_repr_html and self._eager_eval:
+            vertical = False
+            return self._jdf.showString(
+                self._max_num_rows, self._truncate, vertical)
+        else:
+            return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in 
self.dtypes))
+
+    def _repr_html_(self):
+        """Returns the DataFrame rendered as an HTML table when eager evaluation is
+        enabled via 'spark.sql.repl.eagerEval.enabled'. This is only called by REPLs
+        that support eager evaluation with HTML, such as Jupyter.
+        """
+        import cgi
+        if not self._support_repr_html:
+            self._support_repr_html = True
+        if self._eager_eval:
+            max_num_rows = max(self._max_num_rows, 0)
+            vertical = False
+            sock_info = self._jdf.getRowsToPython(
+                max_num_rows, self._truncate, vertical)
+            rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
+            head = rows[0]
+            row_data = rows[1:]
+            has_more_data = len(row_data) > max_num_rows
+            row_data = row_data[:max_num_rows]
+
+            html = "<table border='1'>\n"
+            # generate table head
+            html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: 
cgi.escape(x), head))
+            # generate table rows
+            for row in row_data:
+                html += "<tr><td>%s</td></tr>\n" % "</td><td>".join(
+                    map(lambda x: cgi.escape(x), row))
+            html += "</table>\n"
+            if has_more_data:
+                html += "only showing top %d %s\n" % (
+                    max_num_rows, "row" if max_num_rows == 1 else "rows")
+            return html
+        else:
+            return None
 
     @since(2.1)
     def checkpoint(self, eager=True):

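A minimal pure-Python sketch (illustrative names, not PySpark code) of why the `_support_repr_html` flag above avoids duplicate work: notebooks such as Jupyter try `_repr_html_` before `__repr__`, so once the HTML path has been used, `__repr__` can return the cheap schema string instead of hitting `_jdf` again.

```python
# Illustrative stand-in for the eager-evaluation guard; not the real DataFrame class.
class FakeFrame(object):
    def __init__(self):
        self._eager_eval = True          # stands in for spark.sql.repl.eagerEval.enabled
        self._support_repr_html = False  # flips to True once a notebook asks for HTML
        self.renders = 0                 # counts the expensive "_jdf" calls

    def _render(self):
        self.renders += 1                # stands in for _jdf.showString / getRowsToPython
        return "<table>...</table>"

    def _repr_html_(self):               # notebooks try this first
        self._support_repr_html = True
        return self._render() if self._eager_eval else None

    def __repr__(self):                  # plain REPL fallback
        if self._eager_eval and not self._support_repr_html:
            return self._render()
        return "DataFrame[key: bigint, value: string]"

f = FakeFrame()
f._repr_html_()     # notebook path: one expensive render
repr(f)             # stays cheap because the HTML path has already been used
assert f.renders == 1
```
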
http://git-wip-us.apache.org/repos/asf/spark/blob/dbb4d838/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ea2dd76..487eb19 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3074,6 +3074,36 @@ class SQLTests(ReusedSQLTestCase):
         finally:
             shutil.rmtree(path)
 
+    def test_repr_html(self):
+        import re
+        pattern = re.compile(r'^ *\|', re.MULTILINE)
+        df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
+        self.assertEquals(None, df._repr_html_())
+        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+            expected1 = """<table border='1'>
+                |<tr><th>key</th><th>value</th></tr>
+                |<tr><td>1</td><td>1</td></tr>
+                |<tr><td>22222</td><td>22222</td></tr>
+                |</table>
+                |"""
+            self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
+            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+                expected2 = """<table border='1'>
+                    |<tr><th>key</th><th>value</th></tr>
+                    |<tr><td>1</td><td>1</td></tr>
+                    |<tr><td>222</td><td>222</td></tr>
+                    |</table>
+                    |"""
+                self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
+                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+                    expected3 = """<table border='1'>
+                        |<tr><th>key</th><th>value</th></tr>
+                        |<tr><td>1</td><td>1</td></tr>
+                        |</table>
+                        |only showing top 1 row
+                        |"""
+                    self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/dbb4d838/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index abb5ae5..f552610 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -231,16 +231,17 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Compose the string representing rows for output
+   * Get rows represented in a sequence, applying the given truncate and vertical requirements.
    *
-   * @param _numRows Number of rows to show
+   * @param numRows Number of rows to return
    * @param truncate If set to more than 0, truncates strings to `truncate` characters and
    *                   all cells will be aligned right.
-   * @param vertical If set to true, prints output rows vertically (one line per column value).
+   * @param vertical If set to true, the rows to return do not need to be truncated.
    */
-  private[sql] def showString(
-      _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
-    val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+  private[sql] def getRows(
+      numRows: Int,
+      truncate: Int,
+      vertical: Boolean): Seq[Seq[String]] = {
     val newDf = toDF()
     val castCols = newDf.logicalPlan.output.map { col =>
       // Since binary types in top-level schema fields have a specific format to print,
@@ -251,14 +252,12 @@ class Dataset[T] private[sql](
         Column(col).cast(StringType)
       }
     }
-    val takeResult = newDf.select(castCols: _*).take(numRows + 1)
-    val hasMoreData = takeResult.length > numRows
-    val data = takeResult.take(numRows)
+    val data = newDf.select(castCols: _*).take(numRows + 1)
 
     // For array values, replace Seq and Array with square brackets
     // For cells that are beyond `truncate` characters, replace it with the
     // first `truncate-3` and "..."
-    val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+    schema.fieldNames.toSeq +: data.map { row =>
       row.toSeq.map { cell =>
         val str = cell match {
           case null => "null"
@@ -274,6 +273,26 @@ class Dataset[T] private[sql](
         }
       }: Seq[String]
     }
+  }
+
+  /**
+   * Compose the string representing rows for output
+   *
+   * @param _numRows Number of rows to show
+   * @param truncate If set to more than 0, truncates strings to `truncate` characters and
+   *                   all cells will be aligned right.
+   * @param vertical If set to true, prints output rows vertically (one line per column value).
+   */
+  private[sql] def showString(
+      _numRows: Int,
+      truncate: Int = 20,
+      vertical: Boolean = false): String = {
+    val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+    // Get rows represented by Seq[Seq[String]]; we may get one more row if there is more data.
+    val tmpRows = getRows(numRows, truncate, vertical)
+
+    val hasMoreData = tmpRows.length - 1 > numRows
+    val rows = tmpRows.take(numRows + 1)
 
     val sb = new StringBuilder
     val numCols = schema.fieldNames.length
@@ -291,31 +310,25 @@ class Dataset[T] private[sql](
         }
       }
 
+      val paddedRows = rows.map { row =>
+        row.zipWithIndex.map { case (cell, i) =>
+          if (truncate > 0) {
+            StringUtils.leftPad(cell, colWidths(i))
+          } else {
+            StringUtils.rightPad(cell, colWidths(i))
+          }
+        }
+      }
+
       // Create SeparateLine
       val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
 
       // column names
-      rows.head.zipWithIndex.map { case (cell, i) =>
-        if (truncate > 0) {
-          StringUtils.leftPad(cell, colWidths(i))
-        } else {
-          StringUtils.rightPad(cell, colWidths(i))
-        }
-      }.addString(sb, "|", "|", "|\n")
-
+      paddedRows.head.addString(sb, "|", "|", "|\n")
       sb.append(sep)
 
       // data
-      rows.tail.foreach {
-        _.zipWithIndex.map { case (cell, i) =>
-          if (truncate > 0) {
-            StringUtils.leftPad(cell.toString, colWidths(i))
-          } else {
-            StringUtils.rightPad(cell.toString, colWidths(i))
-          }
-        }.addString(sb, "|", "|", "|\n")
-      }
-
+      paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n"))
       sb.append(sep)
     } else {
       // Extended display mode enabled
@@ -346,7 +359,7 @@ class Dataset[T] private[sql](
     }
 
     // Print a footer
-    if (vertical && data.isEmpty) {
+    if (vertical && rows.tail.isEmpty) {
       // In a vertical mode, print an empty row set explicitly
       sb.append("(0 rows)\n")
     } else if (hasMoreData) {
@@ -3209,6 +3222,19 @@ class Dataset[T] private[sql](
     }
   }
 
+  private[sql] def getRowsToPython(
+      _numRows: Int,
+      truncate: Int,
+      vertical: Boolean): Array[Any] = {
+    EvaluatePython.registerPicklers()
+    val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+    val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+    val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
+    val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+      rows.iterator.map(toJava))
+    PythonRDD.serveIterator(iter, "serve-GetRows")
+  }
+
   /**
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */

