Repository: spark
Updated Branches:
  refs/heads/branch-2.3 dfdf1bb9b -> d4f204c53


[SPARK-23942][PYTHON][SQL][BRANCH-2.3] Makes collect in PySpark as action for a query executor listener

## What changes were proposed in this pull request?

This PR proposes to make `collect` in PySpark run through the query execution machinery as an action.

Currently, `collect` in PySpark, and `collect` with Arrow (used by `toPandas` when Arrow is enabled), are not recognised by `QueryExecutionListener` as actions. For example, suppose we have a custom listener as below:

```scala
package org.apache.spark.sql

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

class TestQueryExecutionListener extends QueryExecutionListener with Logging {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    logError("Look at me! I'm 'onSuccess'")
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
}
```
and register it by setting `spark.sql.queryExecutionListeners` to `org.apache.spark.sql.TestQueryExecutionListener`.
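For instance, the listener can be registered when the session is created (a minimal sketch, assuming the compiled `TestQueryExecutionListener` class is on the driver classpath):

```scala
import org.apache.spark.sql.SparkSession

// 'spark.sql.queryExecutionListeners' is a static SQL configuration: the listener
// list is fixed when the session starts and cannot be changed later via spark.conf.
val spark = SparkSession.builder()
  .master("local[4]")
  .config("spark.sql.queryExecutionListeners",
    "org.apache.spark.sql.TestQueryExecutionListener")
  .getOrCreate()
```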

Other operations on the PySpark or Scala side seem fine; the listener is notified as expected:

```python
>>> sql("SELECT * FROM range(1)").show()
```
```
18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
+---+
| id|
+---+
|  0|
+---+
```

```scala
scala> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
res1: Array[org.apache.spark.sql.Row] = Array([0])
```

but `collect` and Arrow-based `toPandas` in PySpark did not notify the listener:

**Before**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
   id
0   0
```

**After**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
   id
0   0
```
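The fix routes the Python-facing `collectToPython` and `collectAsArrowToPython` paths through `Dataset.withAction` (see the `Dataset.scala` diff below), which is where `QueryExecutionListener` gets notified. Roughly, the pattern looks like the sketch below; the signature follows the diff, while the body (including the `sparkSession.listenerManager` calls) is illustrative rather than the exact implementation:

```scala
// Simplified sketch of the withAction pattern inside Dataset.scala (illustrative only).
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U): U = {
  val start = System.nanoTime()
  try {
    val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
      action(qe.executedPlan)
    }
    // Notifying the listener manager here is what makes onSuccess fire for 'collect'.
    sparkSession.listenerManager.onSuccess(name, qe, System.nanoTime() - start)
    result
  } catch {
    case e: Exception =>
      sparkSession.listenerManager.onFailure(name, qe, e)
      throw e
  }
}
```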

## How was this patch tested?

Manually tested as described above, and a unit test was added.

Author: hyukjinkwon <gurwls...@apache.org>

Closes #21060 from HyukjinKwon/PR_TOOL_PICK_PR_21007_BRANCH-2.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4f204c5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4f204c5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4f204c5

Branch: refs/heads/branch-2.3
Commit: d4f204c5321cdc3955a48e9717ba06aaebbc2ab4
Parents: dfdf1bb
Author: hyukjinkwon <gurwls...@apache.org>
Authored: Sat Apr 14 21:44:06 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Sat Apr 14 21:44:06 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 99 ++++++++++++++++----
 .../scala/org/apache/spark/sql/Dataset.scala    | 20 ++--
 .../spark/sql/TestQueryExecutionListener.scala  | 44 +++++++++
 3 files changed, 140 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4f204c5/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d806e5d..daa97e1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -185,22 +185,12 @@ class MyObject(object):
         self.value = value
 
 
-class ReusedSQLTestCase(ReusedPySparkTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
-
-    def assertPandasEqual(self, expected, result):
-        msg = ("DataFrames are not equal: " +
-               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
-               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
-        self.assertTrue(expected.equals(result), msg=msg)
+class SQLTestUtils(object):
+    """
+    This util assumes the instance of this to have 'spark' attribute, having a spark session.
+    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
+    the implementation of this class has 'spark' attribute.
+    """
 
     @contextmanager
     def sql_conf(self, pairs):
@@ -209,6 +199,7 @@ class ReusedSQLTestCase(ReusedPySparkTestCase):
         `value` to the configuration `key` and then restores it back when it exits.
         """
         assert isinstance(pairs, dict), "pairs should be a dictionary."
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
 
         keys = pairs.keys()
         new_values = pairs.values()
@@ -225,6 +216,24 @@ class ReusedSQLTestCase(ReusedPySparkTestCase):
                     self.spark.conf.set(key, old_value)
 
 
+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+    def assertPandasEqual(self, expected, result):
+        msg = ("DataFrames are not equal: " +
+               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
+               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
+        self.assertTrue(expected.equals(result), msg=msg)
+
+
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):
@@ -2980,6 +2989,64 @@ class SQLTests2(ReusedSQLTestCase):
             sc.stop()
 
 
+class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
+    # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
+    # static and immutable. This can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "TestQueryExecutionListener.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.queryExecutionListeners",
+                "org.apache.spark.sql.TestQueryExecutionListener") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def tearDown(self):
+        self.spark._jvm.OnSuccessCall.clear()
+
+    def test_query_execution_listener_on_collect(self):
+        self.assertFalse(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should not be 
called before 'collect'")
+        self.spark.sql("SELECT * FROM range(1)").collect()
+        self.assertTrue(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should be called 
after 'collect'")
+
+    @unittest.skipIf(
+        not _have_pandas or not _have_pyarrow,
+        _pandas_requirement_message or _pyarrow_requirement_message)
+    def test_query_execution_listener_on_collect_with_arrow(self):
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
+            self.assertFalse(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should not be "
+                "called before 'toPandas'")
+            self.spark.sql("SELECT * FROM range(1)").toPandas()
+            self.assertTrue(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should be 
called after 'toPandas'")
+
+
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
         if SparkSession._instantiatedSession is not None:

http://git-wip-us.apache.org/repos/asf/spark/blob/d4f204c5/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0aee1d7..9171681 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3189,10 +3189,10 @@ class Dataset[T] private[sql](
 
   private[sql] def collectToPython(): Int = {
     EvaluatePython.registerPicklers()
-    withNewExecutionId {
+    withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter = new SerDeUtil.AutoBatchedPickler(
-        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+        plan.executeCollect().iterator.map(toJava))
       PythonRDD.serveIterator(iter, "serve-DataFrame")
     }
   }
@@ -3201,8 +3201,9 @@ class Dataset[T] private[sql](
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Int = {
-    withNewExecutionId {
-      val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
+    withAction("collectAsArrowToPython", queryExecution) { plan =>
+      val iter: Iterator[Array[Byte]] =
+        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
       PythonRDD.serveIterator(iter, "serve-Arrow")
     }
   }
@@ -3311,14 +3312,19 @@ class Dataset[T] private[sql](
   }
 
   /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
-    queryExecution.toRdd.mapPartitionsInternal { iter =>
+    plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
       ArrowConverters.toPayloadIterator(
         iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
     }
   }
+
+  // This is only used in tests, for now.
+  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+    toArrowPayload(queryExecution.executedPlan)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d4f204c5/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
new file mode 100644
index 0000000..d2a6358
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
+
+class TestQueryExecutionListener extends QueryExecutionListener {
+  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+    OnSuccessCall.isOnSuccessCalled.set(true)
+  }
+
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
+}
+
+/**
+ * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for
+ * the test case in PySpark. See SPARK-23942.
+ */
+object OnSuccessCall {
+  val isOnSuccessCalled = new AtomicBoolean(false)
+
+  def isCalled(): Boolean = isOnSuccessCalled.get()
+
+  def clear(): Unit = isOnSuccessCalled.set(false)
+}

