Repository: spark
Updated Branches:
  refs/heads/master 067afb4e9 -> 995221774


[SPARK-10731] [SQL] Delegate to Scala's DataFrame.take implementation in Python DataFrame.

Python DataFrame.head/take now requires scanning all the partitions. This pull 
request changes them to delegate the actual implementation to Scala DataFrame 
(by calling DataFrame.take).

This is more of a hack for fixing this issue in 1.5.1. A more proper fix is to 
change executeCollect and executeTake to return InternalRow rather than Row, 
and thus eliminate the extra round-trip conversion.
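
For illustration only (this note and the snippet are not part of the commit), a minimal PySpark
sketch of the new call path, assuming the 1.5-era module layout; the helper name take_via_jvm is
made up, while the JVM entry point mirrors the dataframe.py change in the diff below:

    # Hedged sketch: what DataFrame.take(n) roughly does after this change.
    # Only n rows are computed on the JVM side and streamed back over a local socket.
    from pyspark.rdd import _load_from_socket
    from pyspark.serializers import BatchedSerializer, PickleSerializer
    from pyspark.traceback_utils import SCCallSiteSync

    def take_via_jvm(df, n):
        # Record the Python call site, then let Scala's DataFrame.take(n) do the work.
        with SCCallSiteSync(df._sc) as css:
            port = df._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
                df._jdf, n)
        # Read the pickled rows back from the local socket served by the JVM.
        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))

    # Usage (hypothetical): rows = take_via_jvm(sqlContext.range(10), 2)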

Author: Reynold Xin <r...@databricks.com>

Closes #8876 from rxin/SPARK-10731.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/99522177
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/99522177
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/99522177

Branch: refs/heads/master
Commit: 9952217749118ae78fe794ca11e1c4a87a4ae8ba
Parents: 067afb4
Author: Reynold Xin <r...@databricks.com>
Authored: Wed Sep 23 16:43:21 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Sep 23 16:43:21 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala |   2 +-
 python/pyspark/sql/dataframe.py                 |   5 +-
 .../org/apache/spark/sql/execution/python.scala | 417 +++++++++++++++++++
 .../apache/spark/sql/execution/pythonUDFs.scala | 405 ------------------
 .../apache/spark/sql/test/ExamplePointUDT.scala |  16 +-
 5 files changed, 429 insertions(+), 416 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/99522177/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 19be093..8464b57 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -633,7 +633,7 @@ private[spark] object PythonRDD extends Logging {
    *
   * The thread will terminate after all the data are sent or any exceptions happen.
    */
-  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 3 seconds
     serverSocket.setSoTimeout(3000)

http://git-wip-us.apache.org/repos/asf/spark/blob/99522177/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 80f8d8a..b09422a 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -300,7 +300,10 @@ class DataFrame(object):
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
-        return self.limit(num).collect()
+        with SCCallSiteSync(self._sc) as css:
+            port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe(
+                self._jdf, num)
+        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
 
     @ignore_unicode_prefix
     @since(1.3)

http://git-wip-us.apache.org/repos/asf/spark/blob/99522177/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
new file mode 100644
index 0000000..d6aaf42
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -0,0 +1,417 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution
+
+import java.io.OutputStream
+import java.util.{List => JList, Map => JMap}
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator}
+
+/**
+ * A serialized version of a Python lambda function.  Suitable for use in a [[PythonRDD]].
+ */
+private[spark] case class PythonUDF(
+    name: String,
+    command: Array[Byte],
+    envVars: JMap[String, String],
+    pythonIncludes: JList[String],
+    pythonExec: String,
+    pythonVer: String,
+    broadcastVars: JList[Broadcast[PythonBroadcast]],
+    accumulator: Accumulator[JList[Array[Byte]]],
+    dataType: DataType,
+    children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging {
+
+  override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
+
+  override def nullable: Boolean = true
+}
+
+/**
+ * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
+ * alone in a batch.
+ *
+ * This has the limitation that the input to the Python UDF is not allowed to include attributes from
+ * multiple child operators.
+ */
+private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    // Skip EvaluatePython nodes.
+    case plan: EvaluatePython => plan
+
+    case plan: LogicalPlan if plan.resolved =>
+      // Extract any PythonUDFs from the current operator.
+      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+      if (udfs.isEmpty) {
+        // If there aren't any, we are done.
+        plan
+      } else {
+        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
+        // If there is more than one, we will add another evaluation operator in a subsequent pass.
+        udfs.find(_.resolved) match {
+          case Some(udf) =>
+            var evaluation: EvaluatePython = null
+
+            // Rewrite the child that has the input required for the UDF
+            val newChildren = plan.children.map { child =>
+              // Check to make sure that the UDF can be evaluated with only the input of this child.
+              // Other cases are disallowed as they are ambiguous or would require a cartesian
+              // product.
+              if (udf.references.subsetOf(child.outputSet)) {
+                evaluation = EvaluatePython(udf, child)
+                evaluation
+              } else if (udf.references.intersect(child.outputSet).nonEmpty) {
+                sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+              } else {
+                child
+              }
+            }
+
+            assert(evaluation != null, "Unable to evaluate PythonUDF.  Missing input attributes.")
+
+            // Trim away the new UDF value if it was only used for filtering or something.
+            logical.Project(
+              plan.output,
+              plan.transformExpressions {
+                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
+              }.withNewChildren(newChildren))
+
+          case None =>
+            // If there is no Python UDF that is resolved, skip this round.
+            plan
+        }
+      }
+  }
+}
+
+object EvaluatePython {
+  def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
+    new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+
+  def takeAndServe(df: DataFrame, n: Int): Int = {
+    registerPicklers()
+    // This is an annoying hack - we should refactor the code so executeCollect and executeTake
+    // return InternalRow rather than Row.
+    val converter = CatalystTypeConverters.createToCatalystConverter(df.schema)
+    val iter = new SerDeUtil.AutoBatchedPickler(df.take(n).iterator.map { row =>
+      EvaluatePython.toJava(converter(row).asInstanceOf[InternalRow], df.schema)
+    })
+    PythonRDD.serveIterator(iter, s"serve-DataFrame")
+  }
+
+  /**
+   * Helper for converting from Catalyst type to java type suitable for Pyrolite.
+   */
+  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+    case (null, _) => null
+
+    case (row: InternalRow, struct: StructType) =>
+      val values = new Array[Any](row.numFields)
+      var i = 0
+      while (i < row.numFields) {
+        values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType)
+        i += 1
+      }
+      new GenericInternalRowWithSchema(values, struct)
+
+    case (a: ArrayData, array: ArrayType) =>
+      val values = new java.util.ArrayList[Any](a.numElements())
+      a.foreach(array.elementType, (_, e) => {
+        values.add(toJava(e, array.elementType))
+      })
+      values
+
+    case (map: MapData, mt: MapType) =>
+      val jmap = new java.util.HashMap[Any, Any](map.numElements())
+      map.foreach(mt.keyType, mt.valueType, (k, v) => {
+        jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
+      })
+      jmap
+
+    case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
+
+    case (d: Decimal, _) => d.toJavaBigDecimal
+
+    case (s: UTF8String, StringType) => s.toString
+
+    case (other, _) => other
+  }
+
+  /**
+   * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
+   * unexpected, because Python doesn't enforce the type.
+   */
+  def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+    case (null, _) => null
+
+    case (c: Boolean, BooleanType) => c
+
+    case (c: Int, ByteType) => c.toByte
+    case (c: Long, ByteType) => c.toByte
+
+    case (c: Int, ShortType) => c.toShort
+    case (c: Long, ShortType) => c.toShort
+
+    case (c: Int, IntegerType) => c
+    case (c: Long, IntegerType) => c.toInt
+
+    case (c: Int, LongType) => c.toLong
+    case (c: Long, LongType) => c
+
+    case (c: Double, FloatType) => c.toFloat
+
+    case (c: Double, DoubleType) => c
+
+    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
+
+    case (c: Int, DateType) => c
+
+    case (c: Long, TimestampType) => c
+
+    case (c: String, StringType) => UTF8String.fromString(c)
+    case (c, StringType) =>
+      // If we get here, c is not a string. Call toString on it.
+      UTF8String.fromString(c.toString)
+
+    case (c: String, BinaryType) => c.getBytes("utf-8")
+    case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+
+    case (c: java.util.List[_], ArrayType(elementType, _)) =>
+      new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray)
+
+    case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+      new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
+
+    case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+      val keyValues = c.asScala.toSeq
+      val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray
+      val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray
+      ArrayBasedMapData(keys, values)
+
+    case (c, StructType(fields)) if c.getClass.isArray =>
+      new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
+        case (e, f) => fromJava(e, f.dataType)
+      })
+
+    case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType)
+
+    // all other unexpected types should be null, or we will have a runtime exception
+    // TODO(davies): we could improve this by trying to cast the object to the expected type
+    case (c, _) => null
+  }
+
+
+  private val module = "pyspark.sql.types"
+
+  /**
+   * Pickler for StructType
+   */
+  private class StructTypePickler extends IObjectPickler {
+
+    private val cls = classOf[StructType]
+
+    def register(): Unit = {
+      Pickler.registerCustomPickler(cls, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      out.write(Opcodes.GLOBAL)
+      out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8"))
+      val schema = obj.asInstanceOf[StructType]
+      pickler.save(schema.json)
+      out.write(Opcodes.TUPLE1)
+      out.write(Opcodes.REDUCE)
+    }
+  }
+
+  /**
+   * Pickler for InternalRow
+   */
+  private class RowPickler extends IObjectPickler {
+
+    private val cls = classOf[GenericInternalRowWithSchema]
+
+    // register this to Pickler and Unpickler
+    def register(): Unit = {
+      Pickler.registerCustomPickler(this.getClass, this)
+      Pickler.registerCustomPickler(cls, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      if (obj == this) {
+        out.write(Opcodes.GLOBAL)
+        out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8"))
+      } else {
+        // it will be memoized by the Pickler to save some bytes
+        pickler.save(this)
+        val row = obj.asInstanceOf[GenericInternalRowWithSchema]
+        // schema should always be same object for memoization
+        pickler.save(row.schema)
+        out.write(Opcodes.TUPLE1)
+        out.write(Opcodes.REDUCE)
+
+        out.write(Opcodes.MARK)
+        var i = 0
+        while (i < row.values.size) {
+          pickler.save(row.values(i))
+          i += 1
+        }
+        out.write(Opcodes.TUPLE)
+        out.write(Opcodes.REDUCE)
+      }
+    }
+  }
+
+  private[this] var registered = false
+  /**
+   * This should be called before trying to serialize any of the above classes in cluster mode;
+   * it should be put in the closure.
+   */
+  def registerPicklers(): Unit = {
+    synchronized {
+      if (!registered) {
+        SerDeUtil.initialize()
+        new StructTypePickler().register()
+        new RowPickler().register()
+        registered = true
+      }
+    }
+  }
+
+  /**
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+   * PySpark.
+   */
+  def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
+    rdd.mapPartitions { iter =>
+      registerPicklers()  // let it be called in the executor
+      new SerDeUtil.AutoBatchedPickler(iter)
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
+ */
+@DeveloperApi
+case class EvaluatePython(
+    udf: PythonUDF,
+    child: LogicalPlan,
+    resultAttribute: AttributeReference)
+  extends logical.UnaryNode {
+
+  def output: Seq[Attribute] = child.output :+ resultAttribute
+
+  // References should not include the produced attribute.
+  override def references: AttributeSet = udf.references
+}
+
+/**
+ * :: DeveloperApi ::
+ * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
+ *
+ * Python evaluation works by sending the necessary (projected) input data via a socket to an
+ * external Python process, and combining the result from the Python process with the original row.
+ *
+ * For each row we send to Python, we also put it in a queue. For each output row from Python,
+ * we drain the queue to find the original input row. Note that if the Python process is way too
+ * slow, this could lead to the queue growing unbounded and eventually running out of memory.
+ */
+@DeveloperApi
+case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+  extends SparkPlan {
+
+  def children: Seq[SparkPlan] = child :: Nil
+
+  override def outputsUnsafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute().map(_.copy())
+    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+
+    inputRDD.mapPartitions { iter =>
+      EvaluatePython.registerPicklers()  // register pickler for Row
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
+
+      val pickle = new Pickler
+      val currentRow = newMutableProjection(udf.children, child.output)()
+      val fields = udf.children.map(_.dataType)
+      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+
+      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+      // For each row, add it to the queue.
+      val inputIterator = iter.grouped(100).map { inputRows =>
+        val toBePickled = inputRows.map { row =>
+          queue.add(row)
+          EvaluatePython.toJava(currentRow(row), schema)
+        }.toArray
+        pickle.dumps(toBePickled)
+      }
+
+      val context = TaskContext.get()
+
+      // Output iterator for results from Python.
+      val outputIterator = new PythonRunner(
+        udf.command,
+        udf.envVars,
+        udf.pythonIncludes,
+        udf.pythonExec,
+        udf.pythonVer,
+        udf.broadcastVars,
+        udf.accumulator,
+        bufferSize,
+        reuseWorker
+      ).compute(inputIterator, context.partitionId(), context)
+
+      val unpickle = new Unpickler
+      val row = new GenericMutableRow(1)
+      val joined = new JoinedRow
+
+      outputIterator.flatMap { pickedResult =>
+        val unpickledBatch = unpickle.loads(pickedResult)
+        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+      }.map { result =>
+        row(0) = EvaluatePython.fromJava(result, udf.dataType)
+        joined(queue.poll(), row)
+      }
+    }
+  }
+}
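
For illustration only (not part of the commit): the BatchPythonEvaluation node above queues each
input row, ships the projected inputs to the Python worker in batches of 100, and joins every
returned UDF result back onto its original row. A self-contained, hedged sketch of that pattern in
plain Python follows (the name batch_evaluate is made up and none of this is Spark API):

    from collections import deque

    def batch_evaluate(rows, udf, batch_size=100):
        # Keep every input row, in send order, so results can be re-joined later.
        queue = deque()
        batch = []
        for row in rows:
            queue.append(row)
            batch.append(row)
            if len(batch) == batch_size:
                # Stand-in for the round trip through the external Python worker.
                for result in map(udf, batch):
                    yield queue.popleft() + (result,)
                batch = []
        # Flush the final partial batch.
        for result in map(udf, batch):
            yield queue.popleft() + (result,)

    # Example: append the length of the name to each (id, name) tuple.
    rows = [(1, "Alice"), (2, "Bob")]
    print(list(batch_evaluate(rows, lambda r: len(r[1]))))
    # [(1, 'Alice', 5), (2, 'Bob', 3)]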

http://git-wip-us.apache.org/repos/asf/spark/blob/99522177/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
deleted file mode 100644
index c35c726..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ /dev/null
@@ -1,405 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements.  See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License.  You may obtain a copy of the License at
-*
-*    http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql.execution
-
-import java.io.OutputStream
-import java.util.{List => JList, Map => JMap}
-
-import scala.collection.JavaConverters._
-
-import net.razorvine.pickle._
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil}
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator}
-
-/**
- * A serialized version of a Python lambda function.  Suitable for use in a [[PythonRDD]].
- */
-private[spark] case class PythonUDF(
-    name: String,
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]],
-    dataType: DataType,
-    children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging {
-
-  override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
-
-  override def nullable: Boolean = true
-}
-
-/**
- * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
- * alone in a batch.
- *
- * This has the limitation that the input to the Python UDF is not allowed include attributes from
- * multiple child operators.
- */
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    // Skip EvaluatePython nodes.
-    case plan: EvaluatePython => plan
-
-    case plan: LogicalPlan if plan.resolved =>
-      // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
-      if (udfs.isEmpty) {
-        // If there aren't any, we are done.
-        plan
-      } else {
-        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
-        // If there is more than one, we will add another evaluation operator in a subsequent pass.
-        udfs.find(_.resolved) match {
-          case Some(udf) =>
-            var evaluation: EvaluatePython = null
-
-            // Rewrite the child that has the input required for the UDF
-            val newChildren = plan.children.map { child =>
-              // Check to make sure that the UDF can be evaluated with only the input of this child.
-              // Other cases are disallowed as they are ambiguous or would require a cartesian
-              // product.
-              if (udf.references.subsetOf(child.outputSet)) {
-                evaluation = EvaluatePython(udf, child)
-                evaluation
-              } else if (udf.references.intersect(child.outputSet).nonEmpty) {
-                sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-              } else {
-                child
-              }
-            }
-
-            assert(evaluation != null, "Unable to evaluate PythonUDF.  Missing input attributes.")
-
-            // Trim away the new UDF value if it was only used for filtering or something.
-            logical.Project(
-              plan.output,
-              plan.transformExpressions {
-                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
-              }.withNewChildren(newChildren))
-
-          case None =>
-            // If there is no Python UDF that is resolved, skip this round.
-            plan
-        }
-      }
-  }
-}
-
-object EvaluatePython {
-  def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
-    new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
-
-  /**
-   * Helper for converting from Catalyst type to java type suitable for Pyrolite.
-   */
-  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-    case (null, _) => null
-
-    case (row: InternalRow, struct: StructType) =>
-      val values = new Array[Any](row.numFields)
-      var i = 0
-      while (i < row.numFields) {
-        values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType)
-        i += 1
-      }
-      new GenericInternalRowWithSchema(values, struct)
-
-    case (a: ArrayData, array: ArrayType) =>
-      val values = new java.util.ArrayList[Any](a.numElements())
-      a.foreach(array.elementType, (_, e) => {
-        values.add(toJava(e, array.elementType))
-      })
-      values
-
-    case (map: MapData, mt: MapType) =>
-      val jmap = new java.util.HashMap[Any, Any](map.numElements())
-      map.foreach(mt.keyType, mt.valueType, (k, v) => {
-        jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
-      })
-      jmap
-
-    case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
-
-    case (d: Decimal, _) => d.toJavaBigDecimal
-
-    case (s: UTF8String, StringType) => s.toString
-
-    case (other, _) => other
-  }
-
-  /**
-   * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
-   * unexpected. Because Python doesn't enforce the type.
-   */
-  def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
-    case (null, _) => null
-
-    case (c: Boolean, BooleanType) => c
-
-    case (c: Int, ByteType) => c.toByte
-    case (c: Long, ByteType) => c.toByte
-
-    case (c: Int, ShortType) => c.toShort
-    case (c: Long, ShortType) => c.toShort
-
-    case (c: Int, IntegerType) => c
-    case (c: Long, IntegerType) => c.toInt
-
-    case (c: Int, LongType) => c.toLong
-    case (c: Long, LongType) => c
-
-    case (c: Double, FloatType) => c.toFloat
-
-    case (c: Double, DoubleType) => c
-
-    case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
-
-    case (c: Int, DateType) => c
-
-    case (c: Long, TimestampType) => c
-
-    case (c: String, StringType) => UTF8String.fromString(c)
-    case (c, StringType) =>
-      // If we get here, c is not a string. Call toString on it.
-      UTF8String.fromString(c.toString)
-
-    case (c: String, BinaryType) => c.getBytes("utf-8")
-    case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
-
-    case (c: java.util.List[_], ArrayType(elementType, _)) =>
-      new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray)
-
-    case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
-      new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
-
-    case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
-      val keyValues = c.asScala.toSeq
-      val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray
-      val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray
-      ArrayBasedMapData(keys, values)
-
-    case (c, StructType(fields)) if c.getClass.isArray =>
-      new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
-        case (e, f) => fromJava(e, f.dataType)
-      })
-
-    case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType)
-
-    // all other unexpected type should be null, or we will have runtime exception
-    // TODO(davies): we could improve this by try to cast the object to expected type
-    case (c, _) => null
-  }
-
-
-  private val module = "pyspark.sql.types"
-
-  /**
-   * Pickler for StructType
-   */
-  private class StructTypePickler extends IObjectPickler {
-
-    private val cls = classOf[StructType]
-
-    def register(): Unit = {
-      Pickler.registerCustomPickler(cls, this)
-    }
-
-    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      out.write(Opcodes.GLOBAL)
-      out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8"))
-      val schema = obj.asInstanceOf[StructType]
-      pickler.save(schema.json)
-      out.write(Opcodes.TUPLE1)
-      out.write(Opcodes.REDUCE)
-    }
-  }
-
-  /**
-   * Pickler for InternalRow
-   */
-  private class RowPickler extends IObjectPickler {
-
-    private val cls = classOf[GenericInternalRowWithSchema]
-
-    // register this to Pickler and Unpickler
-    def register(): Unit = {
-      Pickler.registerCustomPickler(this.getClass, this)
-      Pickler.registerCustomPickler(cls, this)
-    }
-
-    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      if (obj == this) {
-        out.write(Opcodes.GLOBAL)
-        out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8"))
-      } else {
-        // it will be memorized by Pickler to save some bytes
-        pickler.save(this)
-        val row = obj.asInstanceOf[GenericInternalRowWithSchema]
-        // schema should always be same object for memoization
-        pickler.save(row.schema)
-        out.write(Opcodes.TUPLE1)
-        out.write(Opcodes.REDUCE)
-
-        out.write(Opcodes.MARK)
-        var i = 0
-        while (i < row.values.size) {
-          pickler.save(row.values(i))
-          i += 1
-        }
-        out.write(Opcodes.TUPLE)
-        out.write(Opcodes.REDUCE)
-      }
-    }
-  }
-
-  private[this] var registered = false
-  /**
-   * This should be called before trying to serialize any above classes un cluster mode,
-   * this should be put in the closure
-   */
-  def registerPicklers(): Unit = {
-    synchronized {
-      if (!registered) {
-        SerDeUtil.initialize()
-        new StructTypePickler().register()
-        new RowPickler().register()
-        registered = true
-      }
-    }
-  }
-
-  /**
-   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
-   * PySpark.
-   */
-  def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
-    rdd.mapPartitions { iter =>
-      registerPicklers()  // let it called in executor
-      new SerDeUtil.AutoBatchedPickler(iter)
-    }
-  }
-}
-
-/**
- * :: DeveloperApi ::
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
- */
-@DeveloperApi
-case class EvaluatePython(
-    udf: PythonUDF,
-    child: LogicalPlan,
-    resultAttribute: AttributeReference)
-  extends logical.UnaryNode {
-
-  def output: Seq[Attribute] = child.output :+ resultAttribute
-
-  // References should not include the produced attribute.
-  override def references: AttributeSet = udf.references
-}
-
-/**
- * :: DeveloperApi ::
- * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.
- *
- * Python evaluation works by sending the necessary (projected) input data via a socket to an
- * external Python process, and combine the result from the Python process with the original row.
- *
- * For each row we send to Python, we also put it in a queue. For each output row from Python,
- * we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and eventually run out of memory.
- */
-@DeveloperApi
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
-  extends SparkPlan {
-
-  def children: Seq[SparkPlan] = child :: Nil
-
-  override def outputsUnsafeRows: Boolean = false
-  override def canProcessUnsafeRows: Boolean = true
-  override def canProcessSafeRows: Boolean = true
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val inputRDD = child.execute().map(_.copy())
-    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
-    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-
-    inputRDD.mapPartitions { iter =>
-      EvaluatePython.registerPicklers()  // register pickler for Row
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
-
-      val pickle = new Pickler
-      val currentRow = newMutableProjection(udf.children, child.output)()
-      val fields = udf.children.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
-
-      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
-      // For each row, add it to the queue.
-      val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
-        }.toArray
-        pickle.dumps(toBePickled)
-      }
-
-      val context = TaskContext.get()
-
-      // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(
-        udf.command,
-        udf.envVars,
-        udf.pythonIncludes,
-        udf.pythonExec,
-        udf.pythonVer,
-        udf.broadcastVars,
-        udf.accumulator,
-        bufferSize,
-        reuseWorker
-      ).compute(inputIterator, context.partitionId(), context)
-
-      val unpickle = new Unpickler
-      val row = new GenericMutableRow(1)
-      val joined = new JoinedRow
-
-      outputIterator.flatMap { pickedResult =>
-        val unpickledBatch = unpickle.loads(pickedResult)
-        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
-      }.map { result =>
-        row(0) = EvaluatePython.fromJava(result, udf.dataType)
-        joined(queue.poll(), row)
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/99522177/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 2fdd798..963e603 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -39,22 +39,20 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): Seq[Double] = {
+  override def serialize(obj: Any): GenericArrayData = {
     obj match {
       case p: ExamplePoint =>
-        Seq(p.x, p.y)
+        val output = new Array[Any](2)
+        output(0) = p.x
+        output(1) = p.y
+        new GenericArrayData(output)
     }
   }
 
   override def deserialize(datum: Any): ExamplePoint = {
     datum match {
-      case values: Seq[_] =>
-        val xy = values.asInstanceOf[Seq[Double]]
-        assert(xy.length == 2)
-        new ExamplePoint(xy(0), xy(1))
-      case values: util.ArrayList[_] =>
-        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
-        new ExamplePoint(xy(0), xy(1))
+      case values: ArrayData =>
+        new ExamplePoint(values.getDouble(0), values.getDouble(1))
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
