Repository: spark
Updated Branches:
  refs/heads/master 248067adb -> c5cbc4923
[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF After this patch, broadcast can be used in Python UDF. Author: Davies Liu <davies....@gmail.com> Closes #2243 from davies/udf_broadcast and squashes the following commits: 7b88861 [Davies Liu] support broadcast in UDF Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5cbc492 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5cbc492 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5cbc492 Branch: refs/heads/master Commit: c5cbc49233193836b321cb6b77ce69dae798570b Parents: 248067a Author: Davies Liu <davies....@gmail.com> Authored: Wed Sep 3 19:08:39 2014 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Wed Sep 3 19:08:39 2014 -0700 ---------------------------------------------------------------------- python/pyspark/sql.py | 17 ++++++++------- python/pyspark/tests.py | 22 ++++++++++++++++++++ .../org/apache/spark/sql/UdfRegistration.scala | 3 +++ .../apache/spark/sql/execution/pythonUdfs.scala | 3 ++- 4 files changed, 36 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/python/pyspark/sql.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 4431692..aaa35da 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -942,9 +942,7 @@ class SQLContext: self._jsc = self._sc._jsc self._jvm = self._sc._jvm self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray - - if sqlContext: - self._scala_SQLContext = sqlContext + self._scala_SQLContext = sqlContext @property def _ssql_ctx(self): @@ -953,7 +951,7 @@ class SQLContext: Subclasses can override this property to provide their own JVM Contexts. 
""" - if not hasattr(self, '_scala_SQLContext'): + if self._scala_SQLContext is None: self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext @@ -970,23 +968,26 @@ class SQLContext: >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] - >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - [Row(c0=5)] """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) + pickled_command = CloudPickleSerializer().dumps(command) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self._sc._pickled_broadcast_vars], + self._sc._gateway._gateway_client) + self._sc._pickled_broadcast_vars.clear() env = MapConverter().convert(self._sc.environment, self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, self._sc._gateway._gateway_client) self._ssql_ctx.registerPython(name, - bytearray(CloudPickleSerializer().dumps(command)), + bytearray(pickled_command), env, includes, self._sc.pythonExec, + broadcast_vars, self._sc._javaAccumulator, str(returnType)) http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/python/pyspark/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f1a75cb..3e74799 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -43,6 +43,7 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.sql import SQLContext, IntegerType _have_scipy = False _have_numpy = False 
@@ -525,6 +526,27 @@ class TestRDDFunctions(PySparkTestCase): self.assertRaises(TypeError, lambda: rdd.histogram(2)) +class TestSQL(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.sqlCtx = SQLContext(self.sc) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + class TestIO(PySparkTestCase): def test_stdout_redirection(self): http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 0b48e9e..0ea1105 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.{List => JList, Map => JMap} import org.apache.spark.Accumulator +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} import org.apache.spark.sql.execution.PythonUDF @@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration { envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], stringDataType: 
String): Unit = { log.debug( @@ -61,6 +63,7 @@ protected[sql] trait UDFRegistration { envVars, pythonIncludes, pythonExec, + broadcastVars, accumulator, dataType, e) http://git-wip-us.apache.org/repos/asf/spark/blob/c5cbc492/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 3dc8be2..0977da3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -42,6 +42,7 @@ private[spark] case class PythonUDF( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, children: Seq[Expression]) extends Expression with SparkLogging { @@ -145,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: udf.pythonIncludes, false, udf.pythonExec, - Seq[Broadcast[Array[Byte]]](), + udf.broadcastVars, udf.accumulator ).mapPartitions { iter => val pickle = new Unpickler --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org