Repository: spark
Updated Branches:
  refs/heads/master f92f53fae -> a60f91284
[SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code

## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it's annoying to pass so many parameters around in so many places. This PR abstracts the Python function along with its context, to simplify some PySpark code and make the logic clearer.

## How was this patch tested?

By existing unit tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #11342 from cloud-fan/python-clean.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a60f9128
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a60f9128
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a60f9128

Branch: refs/heads/master
Commit: a60f91284ceee64de13f04559ec19c13a820a133
Parents: f92f53f
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Feb 24 12:44:54 2016 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Wed Feb 24 12:44:54 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 37 ++++++++++++--------
 python/pyspark/rdd.py                           | 23 +++++++-----
 python/pyspark/sql/context.py                   |  2 +-
 python/pyspark/sql/functions.py                 |  8 ++---
 .../org/apache/spark/sql/UDFRegistration.scala  |  8 ++---
 .../python/BatchPythonEvaluation.scala          |  8 +----
 .../spark/sql/execution/python/PythonUDF.scala  | 13 ++-----
 .../python/UserDefinedPythonFunction.scala      | 15 ++------
 8 files changed, 51 insertions(+), 63 deletions(-)
----------------------------------------------------------------------
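The change is essentially a parameter-object refactoring: the Python execution context that used to be threaded through every constructor now travels as a single PythonFunction value. A minimal, self-contained sketch of the pattern follows, with simplified types and an abbreviated field list (the real case class in the diff below uses java.util Map/List and also carries pythonIncludes, broadcastVars, and an accumulator); it is an illustration, not the actual Spark sources.

    // Sketch only: simplified from the patch below, not the actual Spark code.
    object PythonFunctionSketch {

      // Consolidated context object (field list abbreviated).
      case class PythonFunction(
          command: Array[Byte],
          envVars: Map[String, String],
          pythonExec: String,
          pythonVer: String)

      // Before: every caller spelled out the whole context.
      class OldPythonRunner(
          command: Array[Byte],
          envVars: Map[String, String],
          pythonExec: String,
          pythonVer: String,
          bufferSize: Int,
          reuseWorker: Boolean)

      // After: one parameter; fields are read off the wrapper where needed.
      class PythonRunner(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean) {
        private val pythonExec = func.pythonExec
        override def toString = s"PythonRunner($pythonExec, $bufferSize, $reuseWorker)"
      }

      def main(args: Array[String]): Unit = {
        val func = PythonFunction(Array[Byte](), Map.empty, "python", "2.7")
        println(new PythonRunner(func, bufferSize = 65536, reuseWorker = true))
      }
    }

Call sites shrink from eight arguments to one, which is exactly what the PythonRDD, PythonRunner, PythonUDF, and UserDefinedPythonFunction hunks below do.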
http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f12e2df..05d1c31 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 private[spark] class PythonRDD(
     parent: RDD[_],
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    preservePartitoning: Boolean,
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    func: PythonFunction,
+    preservePartitoning: Boolean)
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -64,29 +58,37 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(
-      command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator,
-      bufferSize, reuse_worker)
+    val runner = new PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
 
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for a Python function, contains all necessary context to run the function in Python
+ * runner.
  */
-private[spark] class PythonRunner(
+private[spark] case class PythonFunction(
     command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]],
+    accumulator: Accumulator[JList[Array[Byte]]])
+
+/**
+ * A helper class to run Python UDFs in Spark.
+ */
+private[spark] class PythonRunner(
+    func: PythonFunction,
     bufferSize: Int,
     reuse_worker: Boolean)
   extends Logging {
 
+  private val envVars = func.envVars
+  private val pythonExec = func.pythonExec
+  private val accumulator = func.accumulator
+
   def compute(
       inputIterator: Iterator[_],
       partitionIndex: Int,
@@ -225,6 +227,11 @@ private[spark] class PythonRunner(
 
   @volatile private var _exception: Exception = null
 
+  private val pythonVer = func.pythonVer
+  private val pythonIncludes = func.pythonIncludes
+  private val broadcastVars = func.broadcastVars
+  private val command = func.command
+
   setDaemon(true)
 
   /** Contains the exception thrown while writing the parent iterator to the Python process. */

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4eaf589..37574ce 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2309,7 +2309,7 @@ class RDD(object):
         yield row
 
 
-def _prepare_for_python_RDD(sc, command, obj=None):
+def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
@@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     return pickled_command, broadcast_vars, env, includes
 
 
+def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+    assert deserializer, "deserializer should not be empty"
+    assert serializer, "serializer should not be empty"
+    command = (func, profiler, deserializer, serializer)
+    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
 class PipelinedRDD(RDD):
 
     """
@@ -2390,14 +2399,10 @@ class PipelinedRDD(RDD):
         else:
             profiler = None
 
-        command = (self.func, profiler, self._prev_jrdd_deserializer,
-                   self._jrdd_deserializer)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
-        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_cmd),
-                                             env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec, self.ctx.pythonVer,
-                                             bvars, self.ctx._javaAccumulator)
+        wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
+                                      self._jrdd_deserializer, profiler)
+        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
+                                             self.preservesPartitioning)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 89bf144..87e32c0 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -29,7 +29,7 @@ else:
 from py4j.protocol import Py4JError
 
 from pyspark import since
-from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6894c27..b30cc67 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import since, SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import _wrap_function, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1645,16 +1645,14 @@ class UserDefinedFunction(object):
         f, returnType = self.func, self.returnType  # put them in closure `func`
         func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
         sc = SparkContext.getOrCreate()
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        wrapped_func = _wrap_function(sc, func, ser, ser)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
-            broadcast_vars, sc._javaAccumulator, jdt)
+            name, wrapped_func, jdt)
         return judf
 
     def __del__(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index ecfc170..de01cbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       s"""
         | Registering new PythonUDF:
         | name: $name
-        | command: ${udf.command.toSeq}
-        | envVars: ${udf.envVars}
-        | pythonIncludes: ${udf.pythonIncludes}
-        | pythonExec: ${udf.pythonExec}
+        | command: ${udf.func.command.toSeq}
+        | envVars: ${udf.func.envVars}
+        | pythonIncludes: ${udf.func.pythonIncludes}
+        | pythonExec: ${udf.func.pythonExec}
        | dataType: ${udf.dataType}
      """.stripMargin)

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 00df019..c65a7bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
 
     // Output iterator for results from Python.
     val outputIterator = new PythonRunner(
-      udf.command,
-      udf.envVars,
-      udf.pythonIncludes,
-      udf.pythonExec,
-      udf.pythonVer,
-      udf.broadcastVars,
-      udf.accumulator,
+      udf.func,
       bufferSize,
       reuseWorker
     ).compute(inputIterator, context.partitionId(), context)

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 9aff0be..0aa2785 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -17,9 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType
 
@@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType
  */
 case class PythonUDF(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
   extends Expression with Unevaluable with NonSQLExpression with Logging {

http://git-wip-us.apache.org/repos/asf/spark/blob/a60f9128/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 79ac1c8..d301874 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.Accumulator
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.types.DataType
@@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType
 */
 case class UserDefinedPythonFunction(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType) {
 
   def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
-      accumulator, dataType, e)
+    PythonUDF(name, func, dataType, e)
   }
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org