Repository: spark Updated Branches: refs/heads/master a1a64e358 -> 6a0b77a55
[SPARK-24215][PYSPARK][FOLLOW UP] Implement eager evaluation for DataFrame APIs in PySpark ## What changes were proposed in this pull request? Address comments in #21370 and add more tests. ## How was this patch tested? Enhance tests in pyspark/sql/tests.py and DataFrameSuite Author: Yuanjian Li <xyliyuanj...@gmail.com> Closes #21553 from xuanyuanking/SPARK-24215-follow. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a0b77a5 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a0b77a5 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a0b77a5 Branch: refs/heads/master Commit: 6a0b77a55d53e74ac0a0892556c3a7a933474948 Parents: a1a64e3 Author: Yuanjian Li <xyliyuanj...@gmail.com> Authored: Wed Jun 27 10:43:06 2018 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Wed Jun 27 10:43:06 2018 -0700 ---------------------------------------------------------------------- docs/configuration.md | 27 --------- python/pyspark/sql/dataframe.py | 3 +- python/pyspark/sql/tests.py | 46 ++++++++++++++- .../org/apache/spark/sql/internal/SQLConf.scala | 23 ++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 11 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 59 ++++++++++++++++++++ 6 files changed, 131 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/docs/configuration.md ---------------------------------------------------------------------- diff --git a/docs/configuration.md b/docs/configuration.md index 6aa7878..0c7c447 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -457,33 +457,6 @@ Apart from these, the following properties are also available, and may be useful </td> </tr> <tr> - <td><code>spark.sql.repl.eagerEval.enabled</code></td> - <td>false</td> - <td> - Enable eager evaluation or not. 
If true and the REPL you are using supports eager evaluation, - Dataset will be ran automatically. The HTML table which generated by <code>_repl_html_</code> - called by notebooks like Jupyter will feedback the queries user have defined. For plain Python - REPL, the output will be shown like <code>dataframe.show()</code> - (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details). - </td> -</tr> -<tr> - <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td> - <td>20</td> - <td> - Default number of rows in eager evaluation output HTML table generated by <code>_repr_html_</code> or plain text, - this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> is set to true. - </td> -</tr> -<tr> - <td><code>spark.sql.repl.eagerEval.truncate</code></td> - <td>20</td> - <td> - Default number of truncate in eager evaluation output HTML table generated by <code>_repr_html_</code> or - plain text, this only take effect when <code>spark.sql.repl.eagerEval.enabled</code> set to true. 
- </td> -</tr> -<tr> <td><code>spark.files</code></td> <td></td> <td> http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e6a1ac..cb3fe44 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -393,9 +393,8 @@ class DataFrame(object): self._support_repr_html = True if self._eager_eval: max_num_rows = max(self._max_num_rows, 0) - vertical = False sock_info = self._jdf.getRowsToPython( - max_num_rows, self._truncate, vertical) + max_num_rows, self._truncate) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 35a0636..8d73806 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3351,11 +3351,41 @@ class SQLTests(ReusedSQLTestCase): finally: shutil.rmtree(path) - def test_repr_html(self): + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) - self.assertEquals(None, df._repr_html_()) + + # test when eager evaluation is enabled and _repr_html_ will not be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """+-----+-----+ + || key|value| + |+-----+-----+ + || 1| 1| + ||22222|22222| + |+-----+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + ||222| 222| + |+---+-----+ + |""" + self.assertEquals(re.sub(pattern, '', 
expected2), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) + + # test when eager evaluation is enabled and _repr_html_ will be called with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): expected1 = """<table border='1'> |<tr><th>key</th><th>value</th></tr> @@ -3381,6 +3411,18 @@ class SQLTests(ReusedSQLTestCase): |""" self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + # test when eager evaluation is disabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): + expected = "DataFrame[key: bigint, value: string]" + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + class HiveSparkSubmitTests(SparkSubmitTests): http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 239c826..e1752ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1330,6 +1330,29 @@ object SQLConf { "The size function returns null for null input if the flag is disabled.") .booleanConf .createWithDefault(true) + + val REPL_EAGER_EVAL_ENABLED = 
buildConf("spark.sql.repl.eagerEval.enabled") + .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + + "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " + + "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " + + "the returned outputs are formatted like dataframe.show().") + .booleanConf + .createWithDefault(false) + + val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") + .doc("The max number of rows that are returned by eager evaluation. This only takes " + + "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " + + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") + .intConf + .createWithDefault(20) + + val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") + .doc("The max number of characters for each cell that is returned by eager evaluation. " + + "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") + .intConf + .createWithDefault(20) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 57f1e17..2ec236f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -236,12 +236,10 @@ class Dataset[T] private[sql]( * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. 
- * @param vertical If set to true, the rows to return do not need truncate. */ private[sql] def getRows( numRows: Int, - truncate: Int, - vertical: Boolean): Seq[Seq[String]] = { + truncate: Int): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -289,7 +287,7 @@ class Dataset[T] private[sql]( vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. - val tmpRows = getRows(numRows, truncate, vertical) + val tmpRows = getRows(numRows, truncate) val hasMoreData = tmpRows.length - 1 > numRows val rows = tmpRows.take(numRows + 1) @@ -3226,11 +3224,10 @@ class Dataset[T] private[sql]( private[sql] def getRowsToPython( _numRows: Int, - truncate: Int, - vertical: Boolean): Array[Any] = { + truncate: Int): Array[Any] = { EvaluatePython.registerPicklers() val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray + val rows = getRows(numRows, truncate).map(_.toArray).toArray val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( rows.iterator.map(toJava)) http://git-wip-us.apache.org/repos/asf/spark/blob/6a0b77a5/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1cc8cb3..ea00d22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { 
testData.select($"*").show(1000) } + test("getRows: truncate = [0, 20]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111111111111111111111")) + assert(df.getRows(10, 0) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111111...")) + assert(df.getRows(10, 20) === expectedAnswerForTrue) + } + + test("getRows: truncate = [3, 17]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111")) + assert(df.getRows(10, 3) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111...")) + assert(df.getRows(10, 17) === expectedAnswerForTrue) + } + + test("getRows: numRows = 0") { + val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1")) + assert(testData.select($"*").getRows(0, 20) === expectedAnswer) + } + + test("getRows: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[1, 2, 3]", "[1, 2, 3]"), + Seq("[2, 3, 4]", "[2, 3, 4]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + + test("getRows: binary") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[31 32]", "[41 42 43 2E]"), + Seq("[33 34]", "[31 32 33 34 36]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + test("showString: truncate = [0, 20]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org