Github user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19349#discussion_r141244651
  
    --- Diff: core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala ---
    @@ -0,0 +1,429 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.api.python
    +
    +import java.io._
    +import java.net._
    +import java.nio.charset.StandardCharsets
    +import java.util.concurrent.atomic.AtomicBoolean
    +
    +import scala.collection.JavaConverters._
    +
    +import org.apache.spark._
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.util._
    +
    +
    +/**
    + * Enumerate the type of command that will be sent to the Python worker
    + */
    +private[spark] object PythonEvalType {
    +  // Integer tags identifying the kind of command sent to the Python worker.
    +  // NOTE(review): these are presumably the values passed as BasePythonRunner's
    +  // `evalType` and mirrored on the Python worker side; keep both in sync (confirm).
    +  val NON_UDF = 0
    +  val SQL_BATCHED_UDF = 1
    +  val SQL_PANDAS_UDF = 2
    +}
    +
    +/**
    + * A helper class to run Python mapPartition/UDFs in Spark.
    + *
    + * funcs is a list of independent Python functions, each one of them is a list of chained Python
    + * functions (from bottom to top).
    + */
    +private[spark] abstract class BasePythonRunner[IN, OUT](
    +    funcs: Seq[ChainedPythonFunctions],
    +    bufferSize: Int,
    +    reuseWorker: Boolean,
    +    evalType: Int,
    +    argOffsets: Array[Array[Int]])
    +  extends Logging {
    +
    +  require(funcs.length == argOffsets.length, "argOffsets should have the 
same length as funcs")
    +
    +  // All the Python functions should have the same exec, version and envvars.
    +  protected val envVars = funcs.head.funcs.head.envVars
    +  protected val pythonExec = funcs.head.funcs.head.pythonExec
    +  protected val pythonVer = funcs.head.funcs.head.pythonVer
    +
    +  // TODO: support accumulator in multiple UDF
    +  protected val accumulator = funcs.head.funcs.head.accumulator
    +
    +  /**
    +   * Runs this runner's Python functions over one partition and returns their output.
    +   *
    +   * Side effects, in order: puts SPARK_LOCAL_DIRS (and SPARK_REUSE_WORKER when
    +   * `reuseWorker` is set) into `envVars`, obtains a worker socket from
    +   * `SparkEnv.createPythonWorker`, registers a task-completion listener for cleanup,
    +   * then starts the writer thread from [[newWriterThread]] and a `MonitorThread`.
    +   *
    +   * @param inputIterator input for this partition; handed to [[newWriterThread]]
    +   * @param partitionIndex partition index; handed to [[newWriterThread]]
    +   * @param context the task context; used for the completion listener, the monitor
    +   *                thread and the returned `InterruptibleIterator`
    +   * @return the iterator built by [[newReaderIterator]], wrapped in an
    +   *         `InterruptibleIterator`
    +   */
    +  def compute(
    +      inputIterator: Iterator[IN],
    +      partitionIndex: Int,
    +      context: TaskContext): Iterator[OUT] = {
    +    // Handed to newReaderIterator, presumably for timing/metrics.
    +    val startTime = System.currentTimeMillis
    +    val env = SparkEnv.get
    +    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
    +    // NOTE(review): `envVars` is the first function's own env map (see the vals above),
    +    // so these puts mutate that shared map in place on every call.
    +    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor 
thread
    +    if (reuseWorker) {
    +      envVars.put("SPARK_REUSE_WORKER", "1")
    +    }
    +    val worker: Socket = env.createPythonWorker(pythonExec, 
envVars.asScala.toMap)
    +    // Whether the worker has been released into the idle pool. It is handed to
    +    // newReaderIterator (presumably the code that sets it); once it is set and
    +    // reuseWorker is enabled, the completion listener below leaves the socket open.
    +    val released = new AtomicBoolean(false)
    +
    +    // Start a thread to feed the process input from our parent's iterator
    +    val writerThread = newWriterThread(env, worker, inputIterator, 
partitionIndex, context)
    +
    +    // On task completion: stop the writer thread, then close the worker socket unless
    +    // the worker is reused and was released to the idle pool. Close failures are only
    +    // logged. NOTE(review): the lambda parameter shadows the outer `context`.
    +    context.addTaskCompletionListener { context =>
    +      writerThread.shutdownOnTaskCompletion()
    +      if (!reuseWorker || !released.get) {
    +        try {
    +          worker.close()
    +        } catch {
    +          case e: Exception =>
    +            logWarning("Failed to close worker socket", e)
    +        }
    +      }
    +    }
    +
    +    writerThread.start()
    +    new MonitorThread(env, worker, context).start()
    +
    +    // Buffered stream (size `bufferSize`) over the worker socket's input; the
    +    // subclass-provided reader iterator reads the Python output from it.
    +    val stream = new DataInputStream(new 
BufferedInputStream(worker.getInputStream, bufferSize))
    +
    +    val stdoutIterator = newReaderIterator(
    +      stream, writerThread, startTime, env, worker, released, context)
    +    // Presumably wrapped so iteration honors task kill/interrupt — confirm.
    +    new InterruptibleIterator(context, stdoutIterator)
    +  }
    +
    +  /**
    +   * Creates the thread that feeds `inputIterator` to the Python worker.
    +   *
    +   * Must return an unstarted thread: [[compute]] calls `start()` on it, and later calls
    +   * its `shutdownOnTaskCompletion()` from a task-completion listener.
    +   *
    +   * @param env the `SparkEnv` obtained in [[compute]]
    +   * @param worker the worker socket from `SparkEnv.createPythonWorker`
    +   * @param inputIterator the partition's input, as passed to [[compute]]
    +   * @param partitionIndex the partition index, as passed to [[compute]]
    +   * @param context the task context
    +   */
    +  protected def newWriterThread(
    +      env: SparkEnv,
    +      worker: Socket,
    +      inputIterator: Iterator[IN],
    +      partitionIndex: Int,
    +      context: TaskContext): WriterThread
    +
    +  /**
    +   * Builds the iterator over the Python worker's output, read from `stream`.
    +   * [[compute]] wraps the result in an `InterruptibleIterator` and returns it.
    +   *
    +   * @param stream buffered `DataInputStream` over the worker socket's input stream
    +   * @param writerThread the (already started) thread feeding this worker
    +   * @param startTime `System.currentTimeMillis` taken when [[compute]] began
    +   * @param env the `SparkEnv` obtained in [[compute]]
    +   * @param worker the worker socket
    +   * @param released flag shared with [[compute]]'s completion listener; when it is set
    +   *                 and `reuseWorker` is enabled, the listener does not close `worker`.
    +   *                 Presumably implementations set it when they return the worker to
    +   *                 the idle pool (confirm).
    +   * @param context the task context
    +   */
    +  protected def newReaderIterator(
    +      stream: DataInputStream,
    +      writerThread: WriterThread,
    +      startTime: Long,
    +      env: SparkEnv,
    +      worker: Socket,
    +      released: AtomicBoolean,
    +      context: TaskContext): Iterator[OUT]
    +
    +  /**
    +   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
    +   * Python process.
    +   */
    +  abstract class WriterThread(
    +      env: SparkEnv,
    +      worker: Socket,
    +      inputIterator: Iterator[IN],
    +      partitionIndex: Int,
    +      context: TaskContext)
    +    extends Thread(s"stdout writer for $pythonExec") {
    +
    +    @volatile private var _exception: Exception = null
    +
    +    private val pythonIncludes = 
funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
    +    private val broadcastVars = 
funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
    +
    +    setDaemon(true)
    +
    +    /** Contains the exception thrown while writing the parent iterator to the Python process. */
    +    def exception: Option[Exception] = Option(_exception)
    +    // `exception` above reads the @volatile `_exception` field, so a value recorded by
    +    // this thread is visible to callers on other threads; None until it is set
    +    // (presumably in run(), which is not shown in this hunk).
    +
    +    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
    +    def shutdownOnTaskCompletion() {
    +      // Only valid once the task has finished: compute() calls this from its
    +      // task-completion listener.
    +      assert(context.isCompleted)
    +      // interrupt() only signals the thread; handling (ignoring) of cleanup exceptions
    +      // presumably lives in run(), which is not shown in this hunk — confirm.
    +      this.interrupt()
    +      // NOTE(review): procedure syntax `def f() {` is deprecated in newer Scala versions;
    +      // consider `def shutdownOnTaskCompletion(): Unit = {`.
    +    }
    +
    +    /**
    +     * Writes the command for the Python worker to `dataOut`.
    +     *
    +     * Abstract: each runner defines what it sends. Presumably this is the serialized
    +     * Python functions and eval type; confirm against the implementations.
    +     *
    +     * @param dataOut output stream to the Python worker (presumably over `worker`)
    +     */
    +    def writeCommand(dataOut: DataOutputStream): Unit
    +    /**
    +     * Writes the input data to `dataOut`; subclasses choose the encoding.
    +     *
    +     * NOTE(review): presumably this consumes `inputIterator` (the parent iterator
    +     * described in this class's doc); confirm in the implementations.
    +     *
    +     * @param dataOut output stream to the Python worker
    +     */
    +    def writeIteratorToStream(dataOut: DataOutputStream): Unit
    --- End diff --
    
    I'd leave a few comments for the methods that should be implemented here.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to