This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3725b13  [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
3725b13 is described below

commit 3725b1324f731d57dc776c256bc1a100ec9e6cd0
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Mar 12 08:45:29 2019 +0900

[SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner

## What changes were proposed in this pull request?

This PR proposes to have one base R runner. At a high level, `ArrowRRunner` previously inherited from `RRunner`:

```
└── RRunner
    └── ArrowRRunner
```

After this PR, there is a `BaseRRunner`, and both `ArrowRRunner` and `RRunner` inherit from it:

```
└── BaseRRunner
    ├── ArrowRRunner
    └── RRunner
```

This is consistent with the Python side. In more detail, see below:

```scala
abstract class BaseRRunner[IN, OUT] {
  def compute: Iterator[OUT] = {
    ...
    newWriterThread(...).start()
    ...
    newReaderIterator(...)
    ...
  }

  // Makes a thread that writes data from the JVM to the R process.
  abstract protected def newWriterThread(..., iter: Iterator[IN], ...): WriterThread

  // Makes an iterator that reads data from the R process back into the JVM.
  abstract protected def newReaderIterator(...): ReaderIterator

  abstract class WriterThread(..., iter: Iterator[IN], ...) extends Thread {
    override def run(): Unit = {
      ...
      writeIteratorToStream(...)
      ...
    }

    // The actual logic for writing to the socket stream.
    abstract protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
  }

  abstract class ReaderIterator extends Iterator[OUT] {
    override def hasNext(): Boolean = {
      ...
      read(...)
      ...
    }

    override def next(): OUT = {
      ...
      hasNext()
      ...
    }

    // The actual logic for reading from the socket stream.
    abstract protected def read(...): OUT
  }
}
```

```scala
class [Arrow]RRunner extends BaseRRunner {
  override def newWriterThread(...) {
    new WriterThread(...) {
      override def writeIteratorToStream(...) {
        ...
      }
    }
  }

  override def newReaderIterator(...) {
    new ReaderIterator(...) {
      override def read(...) {
        ...
      }
    }
  }
}
```

## How was this patch tested?

Manually tested; existing tests should cover this.

Closes #23977 from HyukjinKwon/SPARK-26923.

Authored-by: Hyukjin Kwon <gurwls...@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../api/r/{RRunner.scala => BaseRRunner.scala}     | 302 +++++--------
 .../main/scala/org/apache/spark/api/r/RRDD.scala   |   2 +-
 .../scala/org/apache/spark/api/r/RRunner.scala     | 478 +++++----------------
 .../org/apache/spark/sql/execution/objects.scala   |  24 +-
 .../spark/sql/execution/r/ArrowRRunner.scala       | 140 +++---
 .../sql/execution/r/MapPartitionsRWrapper.scala    |   2 +-
 6 files changed, 309 insertions(+), 639 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
similarity index 55%
copy from core/src/main/scala/org/apache/spark/api/r/RRunner.scala
copy to core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
index 971d11f..f96c521 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
@@ -34,31 +34,23 @@ import org.apache.spark.util.Utils
 /**
  * A helper class to run R UDFs in Spark.
*/ -private[spark] class RRunner[U]( +private[spark] abstract class BaseRRunner[IN, OUT]( func: Array[Byte], deserializer: String, serializer: String, packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]], - numPartitions: Int = -1, - isDataFrame: Boolean = false, - colNames: Array[String] = null, - mode: Int = RRunnerModes.RDD) + numPartitions: Int, + isDataFrame: Boolean, + colNames: Array[String], + mode: Int) extends Logging { protected var bootTime: Double = _ - private var dataStream: DataInputStream = _ - val readData = numPartitions match { - case -1 => - serializer match { - case SerializationFormats.STRING => readStringData _ - case _ => readByteArrayData _ - } - case _ => readShuffledData _ - } + protected var dataStream: DataInputStream = _ def compute( - inputIterator: Iterator[_], - partitionIndex: Int): Iterator[U] = { + inputIterator: Iterator[IN], + partitionIndex: Int): Iterator[OUT] = { // Timing start bootTime = System.currentTimeMillis / 1000.0 @@ -68,7 +60,7 @@ private[spark] class RRunner[U]( // The stdout/stderr is shared by multiple tasks, because we use one daemon // to launch child process as worker. - val errThread = RRunner.createRWorker(listenPort) + val errThread = BaseRRunner.createRWorker(listenPort) // We use two sockets to separate input and output, then it's easy to manage // the lifecycle of them to avoid deadlock. @@ -78,12 +70,12 @@ private[spark] class RRunner[U]( serverSocket.setSoTimeout(10000) dataStream = try { val inSocket = serverSocket.accept() - RRunner.authHelper.authClient(inSocket) - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + BaseRRunner.authHelper.authClient(inSocket) + newWriterThread(inSocket.getOutputStream(), inputIterator, partitionIndex).start() // the socket used to receive the output of task val outSocket = serverSocket.accept() - RRunner.authHelper.authClient(outSocket) + BaseRRunner.authHelper.authClient(outSocket) val inputStream = new BufferedInputStream(outSocket.getInputStream) new DataInputStream(inputStream) } finally { @@ -98,197 +90,127 @@ private[spark] class RRunner[U]( } } + /** + * Creates an iterator that reads data from R process. + */ protected def newReaderIterator( - dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = { - new Iterator[U] { - def next(): U = { - val obj = _nextObj - if (hasNext()) { - _nextObj = read() - } - obj - } - - private var _nextObj = read() + dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator - def hasNext(): Boolean = { - val hasMore = _nextObj != null - if (!hasMore) { - dataStream.close() - } - hasMore + /** + * Start a thread to write RDD data to the R process. + */ + protected def newWriterThread( + output: OutputStream, + iter: Iterator[IN], + partitionIndex: Int): WriterThread + + abstract class ReaderIterator( + stream: DataInputStream, + errThread: BufferedStreamThread) + extends Iterator[OUT] { + + private var nextObj: OUT = _ + // eos should be marked as true when the stream is ended. 
+ protected var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false } } - } - protected def writeData( - dataOut: DataOutputStream, - printOut: PrintStream, - iter: Iterator[_]): Unit = { - def writeElem(elem: Any): Unit = { - if (deserializer == SerializationFormats.BYTE) { - val elemArr = elem.asInstanceOf[Array[Byte]] - dataOut.writeInt(elemArr.length) - dataOut.write(elemArr) - } else if (deserializer == SerializationFormats.ROW) { - dataOut.write(elem.asInstanceOf[Array[Byte]]) - } else if (deserializer == SerializationFormats.STRING) { - // write string(for StringRRDD) - // scalastyle:off println - printOut.println(elem) - // scalastyle:on println + override def next(): OUT = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[OUT] + obj + } else { + Iterator.empty.next() } } - for (elem <- iter) { - elem match { - case (key, innerIter: Iterator[_]) => - for (innerElem <- innerIter) { - writeElem(innerElem) - } - // Writes key which can be used as a boundary in group-aggregate - dataOut.writeByte('r') - writeElem(key) - case (key, value) => - writeElem(key) - writeElem(value) - case _ => - writeElem(elem) - } - } + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. + */ + protected def read(): OUT } /** - * Start a thread to write RDD data to the R process. + * The thread responsible for writing the iterator to the R process. */ - private def startStdinThread( + abstract class WriterThread( output: OutputStream, - iter: Iterator[_], - partitionIndex: Int): Unit = { - val env = SparkEnv.get - val taskContext = TaskContext.get() - val bufferSize = System.getProperty(BUFFER_SIZE.key, - BUFFER_SIZE.defaultValueString).toInt - val stream = new BufferedOutputStream(output, bufferSize) - - new Thread("writer for R") { - override def run(): Unit = { - try { - SparkEnv.set(env) - TaskContext.setTaskContext(taskContext) - val dataOut = new DataOutputStream(stream) - dataOut.writeInt(partitionIndex) - - SerDe.writeString(dataOut, deserializer) - SerDe.writeString(dataOut, serializer) - - dataOut.writeInt(packageNames.length) - dataOut.write(packageNames) - - dataOut.writeInt(func.length) - dataOut.write(func) - - dataOut.writeInt(broadcastVars.length) - broadcastVars.foreach { broadcast => - // TODO(shivaram): Read a Long in R to avoid this cast - dataOut.writeInt(broadcast.id.toInt) - // TODO: Pass a byte array from R to avoid this cast ? 
- val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] - dataOut.writeInt(broadcastByteArr.length) - dataOut.write(broadcastByteArr) - } + iter: Iterator[IN], + partitionIndex: Int) + extends Thread("writer for R") { - dataOut.writeInt(numPartitions) - dataOut.writeInt(mode) - - if (isDataFrame) { - SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null) - } - - if (!iter.hasNext) { - dataOut.writeInt(0) - } else { - dataOut.writeInt(1) - } - - val printOut = new PrintStream(stream) + private val env = SparkEnv.get + private val taskContext = TaskContext.get() + private val bufferSize = System.getProperty(BUFFER_SIZE.key, + BUFFER_SIZE.defaultValueString).toInt + private val stream = new BufferedOutputStream(output, bufferSize) + protected lazy val dataOut = new DataOutputStream(stream) + protected lazy val printOut = new PrintStream(stream) + + override def run(): Unit = { + try { + SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) + dataOut.writeInt(partitionIndex) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } - writeData(dataOut, printOut, iter) + dataOut.writeInt(numPartitions) + dataOut.writeInt(mode) - stream.flush() - } catch { - // TODO: We should propagate this error to the task thread - case e: Exception => - logError("R Writer thread got an exception", e) - } finally { - Try(output.close()) + if (isDataFrame) { + SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null) } - } - }.start() - } - private def read(): U = { - try { - val length = dataStream.readInt() - - length match { - case SpecialLengths.TIMING_DATA => - // Timing data from R worker - val boot = dataStream.readDouble - bootTime - val init = dataStream.readDouble - val broadcast = dataStream.readDouble - val input = dataStream.readDouble - val compute = dataStream.readDouble - val output = dataStream.readDouble - logInfo( - ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + - "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + - "total = %.3f s").format( - boot, - init, - broadcast, - input, - compute, - output, - boot + init + broadcast + input + compute + output)) - read() - case length if length >= 0 => - readData(length).asInstanceOf[U] - } - } catch { - case eof: EOFException => - throw new SparkException("R worker exited unexpectedly (cranshed)", eof) - } - } + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } - private def readShuffledData(length: Int): (Int, Array[Byte]) = { - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null - } - } + writeIteratorToStream(dataOut) - protected def readByteArrayData(length: Int): Array[Byte] = { - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj) - obj - case _ => 
null + stream.flush() + } catch { + // TODO: We should propagate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } } - } - private def readStringData(length: Int): String = { - length match { - case length if length > 0 => - SerDe.readStringBytes(dataStream, length) - case _ => null - } + /** + * Writes input data to the stream connected to the R worker. + */ + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit } } @@ -327,7 +249,7 @@ private[spark] class BufferedStreamThread( } } -private[r] object RRunner { +private[r] object BaseRRunner { // Because forking processes from Java is expensive, we prefer to launch // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. // This daemon currently only works on UNIX-based systems now, so we should diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 4a59c3e..07f8405 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -43,7 +43,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( override def getPartitions: Array[Partition] = parent.partitions override def compute(partition: Partition, context: TaskContext): Iterator[U] = { - val runner = new RRunner[U]( + val runner = new RRunner[T, U]( func, deserializer, serializer, packageNames, broadcastVars, numPartitions) // The parent may be also an RRDD, so we should launch it first. diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 971d11f..0327386 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -18,23 +18,14 @@ package org.apache.spark.api.r import java.io._ -import java.net.{InetAddress, ServerSocket} -import java.util.Arrays - -import scala.io.Source -import scala.util.Try import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.BUFFER_SIZE -import org.apache.spark.internal.config.R._ -import org.apache.spark.util.Utils /** * A helper class to run R UDFs in Spark. */ -private[spark] class RRunner[U]( +private[spark] class RRunner[IN, OUT]( func: Array[Byte], deserializer: String, serializer: String, @@ -44,380 +35,149 @@ private[spark] class RRunner[U]( isDataFrame: Boolean = false, colNames: Array[String] = null, mode: Int = RRunnerModes.RDD) - extends Logging { - protected var bootTime: Double = _ - private var dataStream: DataInputStream = _ - val readData = numPartitions match { - case -1 => - serializer match { - case SerializationFormats.STRING => readStringData _ - case _ => readByteArrayData _ - } - case _ => readShuffledData _ - } - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int): Iterator[U] = { - // Timing start - bootTime = System.currentTimeMillis / 1000.0 - - // we expect two connections - val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) - val listenPort = serverSocket.getLocalPort() - - // The stdout/stderr is shared by multiple tasks, because we use one daemon - // to launch child process as worker. - val errThread = RRunner.createRWorker(listenPort) - - // We use two sockets to separate input and output, then it's easy to manage - // the lifecycle of them to avoid deadlock. 
- // TODO: optimize it to use one socket - - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - dataStream = try { - val inSocket = serverSocket.accept() - RRunner.authHelper.authClient(inSocket) - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - RRunner.authHelper.authClient(outSocket) - val inputStream = new BufferedInputStream(outSocket.getInputStream) - new DataInputStream(inputStream) - } finally { - serverSocket.close() - } - - try { - newReaderIterator(dataStream, errThread) - } catch { - case e: Exception => - throw new SparkException("R computation failed with\n " + errThread.getLines(), e) - } - } + extends BaseRRunner[IN, OUT]( + func, + deserializer, + serializer, + packageNames, + broadcastVars, + numPartitions, + isDataFrame, + colNames, + mode) { protected def newReaderIterator( - dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = { - new Iterator[U] { - def next(): U = { - val obj = _nextObj - if (hasNext()) { - _nextObj = read() - } - obj + dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = { + new ReaderIterator(dataStream, errThread) { + private val readData = numPartitions match { + case -1 => + serializer match { + case SerializationFormats.STRING => readStringData _ + case _ => readByteArrayData _ + } + case _ => readShuffledData _ } - private var _nextObj = read() - - def hasNext(): Boolean = { - val hasMore = _nextObj != null - if (!hasMore) { - dataStream.close() + private def readShuffledData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null } - hasMore } - } - } - protected def writeData( - dataOut: DataOutputStream, - printOut: PrintStream, - iter: Iterator[_]): Unit = { - def writeElem(elem: Any): Unit = { - if (deserializer == SerializationFormats.BYTE) { - val elemArr = elem.asInstanceOf[Array[Byte]] - dataOut.writeInt(elemArr.length) - dataOut.write(elemArr) - } else if (deserializer == SerializationFormats.ROW) { - dataOut.write(elem.asInstanceOf[Array[Byte]]) - } else if (deserializer == SerializationFormats.STRING) { - // write string(for StringRRDD) - // scalastyle:off println - printOut.println(elem) - // scalastyle:on println + private def readByteArrayData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null + } } - } - for (elem <- iter) { - elem match { - case (key, innerIter: Iterator[_]) => - for (innerElem <- innerIter) { - writeElem(innerElem) - } - // Writes key which can be used as a boundary in group-aggregate - dataOut.writeByte('r') - writeElem(key) - case (key, value) => - writeElem(key) - writeElem(value) - case _ => - writeElem(elem) + private def readStringData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null + } } - } - } - - /** - * Start a thread to write RDD data to the R process. 
- */ - private def startStdinThread( - output: OutputStream, - iter: Iterator[_], - partitionIndex: Int): Unit = { - val env = SparkEnv.get - val taskContext = TaskContext.get() - val bufferSize = System.getProperty(BUFFER_SIZE.key, - BUFFER_SIZE.defaultValueString).toInt - val stream = new BufferedOutputStream(output, bufferSize) - new Thread("writer for R") { - override def run(): Unit = { + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. + */ + override protected def read(): OUT = { try { - SparkEnv.set(env) - TaskContext.setTaskContext(taskContext) - val dataOut = new DataOutputStream(stream) - dataOut.writeInt(partitionIndex) - - SerDe.writeString(dataOut, deserializer) - SerDe.writeString(dataOut, serializer) - - dataOut.writeInt(packageNames.length) - dataOut.write(packageNames) - - dataOut.writeInt(func.length) - dataOut.write(func) - - dataOut.writeInt(broadcastVars.length) - broadcastVars.foreach { broadcast => - // TODO(shivaram): Read a Long in R to avoid this cast - dataOut.writeInt(broadcast.id.toInt) - // TODO: Pass a byte array from R to avoid this cast ? - val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] - dataOut.writeInt(broadcastByteArr.length) - dataOut.write(broadcastByteArr) - } - - dataOut.writeInt(numPartitions) - dataOut.writeInt(mode) - - if (isDataFrame) { - SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null) + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length > 0 => + readData(length).asInstanceOf[OUT] + case length if length == 0 => + // End of stream + eos = true + null.asInstanceOf[OUT] } - - if (!iter.hasNext) { - dataOut.writeInt(0) - } else { - dataOut.writeInt(1) - } - - val printOut = new PrintStream(stream) - - writeData(dataOut, printOut, iter) - - stream.flush() } catch { - // TODO: We should propagate this error to the task thread - case e: Exception => - logError("R Writer thread got an exception", e) - } finally { - Try(output.close()) + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) } } - }.start() - } - - private def read(): U = { - try { - val length = dataStream.readInt() - - length match { - case SpecialLengths.TIMING_DATA => - // Timing data from R worker - val boot = dataStream.readDouble - bootTime - val init = dataStream.readDouble - val broadcast = dataStream.readDouble - val input = dataStream.readDouble - val compute = dataStream.readDouble - val output = dataStream.readDouble - logInfo( - ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + - "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + - "total = %.3f s").format( - boot, - init, - broadcast, - input, - compute, - output, - boot + init + broadcast + input + compute + output)) - read() - case length if length >= 0 => - readData(length).asInstanceOf[U] - } - } 
catch { - case eof: EOFException => - throw new SparkException("R worker exited unexpectedly (cranshed)", eof) } } - private def readShuffledData(length: Int): (Int, Array[Byte]) = { - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null - } - } - - protected def readByteArrayData(length: Int): Array[Byte] = { - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj) - obj - case _ => null - } - } - - private def readStringData(length: Int): String = { - length match { - case length if length > 0 => - SerDe.readStringBytes(dataStream, length) - case _ => null - } - } -} - -private[spark] object SpecialLengths { - val TIMING_DATA = -1 -} - -private[spark] object RRunnerModes { - val RDD = 0 - val DATAFRAME_DAPPLY = 1 - val DATAFRAME_GAPPLY = 2 -} - -private[spark] class BufferedStreamThread( - in: InputStream, - name: String, - errBufferSize: Int) extends Thread(name) with Logging { - val lines = new Array[String](errBufferSize) - var lineIdx = 0 - override def run() { - for (line <- Source.fromInputStream(in).getLines) { - synchronized { - lines(lineIdx) = line - lineIdx = (lineIdx + 1) % errBufferSize - } - logInfo(line) - } - } - - def getLines(): String = synchronized { - (0 until errBufferSize).filter { x => - lines((x + lineIdx) % errBufferSize) != null - }.map { x => - lines((x + lineIdx) % errBufferSize) - }.mkString("\n") - } -} - -private[r] object RRunner { - // Because forking processes from Java is expensive, we prefer to launch - // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. - // This daemon currently only works on UNIX-based systems now, so we should - // also fall back to launching workers (worker.R) directly. - private[this] var errThread: BufferedStreamThread = _ - private[this] var daemonChannel: DataOutputStream = _ - - private lazy val authHelper = { - val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - new RAuthHelper(conf) - } - - /** - * Start a thread to print the process's stderr to ours - */ - private def startStdoutThread(proc: Process): BufferedStreamThread = { - val BUFFER_SIZE = 100 - val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) - thread.setDaemon(true) - thread.start() - thread - } - - private def createRProcess(port: Int, script: String): BufferedStreamThread = { - // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", - // but kept here for backward compatibility. - val sparkConf = SparkEnv.get.conf - var rCommand = sparkConf.get(SPARKR_COMMAND) - rCommand = sparkConf.get(R_COMMAND).orElse(Some(rCommand)).get - - val rConnectionTimeout = sparkConf.get(R_BACKEND_CONNECTION_TIMEOUT) - val rOptions = "--vanilla" - val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir(0) + "/SparkR/worker/" + script - val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) - // Unset the R_TESTS environment variable for workers. 
- // This is set by R CMD check as startup.Rs - // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) - // and confuses worker script which tries to load a non-existent file - pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) - pb.environment().put("SPARKR_WORKER_PORT", port.toString) - pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) - pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) - pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") - pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret) - pb.redirectErrorStream(true) // redirect stderr into stdout - val proc = pb.start() - val errThread = startStdoutThread(proc) - errThread - } - /** - * ProcessBuilder used to launch worker R processes. + * Start a thread to write RDD data to the R process. */ - def createRWorker(port: Int): BufferedStreamThread = { - val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) - if (!Utils.isWindows && useDaemon) { - synchronized { - if (daemonChannel == null) { - // we expect one connections - val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(daemonPort, "daemon.R") - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val sock = serverSocket.accept() - try { - authHelper.authClient(sock) - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - } finally { - serverSocket.close() + protected override def newWriterThread( + output: OutputStream, + iter: Iterator[IN], + partitionIndex: Int): WriterThread = { + new WriterThread(output, iter, partitionIndex) { + + /** + * Writes input data to the stream connected to the R worker. 
+ */ + override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + // scalastyle:off println + printOut.println(elem) + // scalastyle:on println } } - try { - daemonChannel.writeInt(port) - daemonChannel.flush() - } catch { - case e: IOException => - // daemon process died - daemonChannel.close() - daemonChannel = null - errThread = null - // fail the current task, retry by scheduler - throw e + + for (elem <- iter) { + elem match { + case (key, innerIter: Iterator[_]) => + for (innerElem <- innerIter) { + writeElem(innerElem) + } + // Writes key which can be used as a boundary in group-aggregate + dataOut.writeByte('r') + writeElem(key) + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } } - errThread } - } else { - createRProcess(port, "worker.R") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d298245..bedfa9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.io.{ByteArrayOutputStream, DataOutputStream} + import scala.collection.JavaConverters._ import scala.language.existentials @@ -490,7 +492,7 @@ case class FlatMapGroupsInRExec( val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - val runner = new RRunner[Array[Byte]]( + val runner = new RRunner[(Array[Byte], Iterator[Array[Byte]]), Array[Byte]]( func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars, isDataFrame = true, colNames = inputSchema.fieldNames, mode = RRunnerModes.DATAFRAME_GAPPLY) @@ -548,12 +550,22 @@ case class FlatMapGroupsInRWithArrowExec( child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema, - SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY) - val groupedByRKey = grouped.map { case (key, rowIter) => - val newKey = rowToRBytes(getKey(key).asInstanceOf[Row]) - (newKey, rowIter) + val keys = collection.mutable.ArrayBuffer.empty[Array[Byte]] + val groupedByRKey: Iterator[Iterator[InternalRow]] = + grouped.map { case (key, rowIter) => + keys.append(rowToRBytes(getKey(key).asInstanceOf[Row])) + rowIter + } + + val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema, + SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY) { + protected override def bufferedWrite( + dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = { + super.bufferedWrite(dataOut)(writeFunc) + // Don't forget we're sending keys additionally. 
+ keys.foreach(dataOut.write) + } } // The communication mechanism is as follows: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala index ee1f2e3..a94cb0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala @@ -47,7 +47,7 @@ class ArrowRRunner( schema: StructType, timeZoneId: String, mode: Int) - extends RRunner[ColumnarBatch]( + extends BaseRRunner[Iterator[InternalRow], ColumnarBatch]( func, "arrow", "arrow", @@ -58,60 +58,10 @@ class ArrowRRunner( schema.fieldNames, mode) { - // TODO: it needs to refactor to share the same code with RRunner, and have separate - // ArrowRRunners. - private val getNextBatch = { - if (mode == RRunnerModes.DATAFRAME_GAPPLY) { - // gapply - (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => { - val (key, nextBatch) = inputIterator - .asInstanceOf[Iterator[(Array[Byte], Iterator[InternalRow])]].next() - keys.append(key) - nextBatch - } - } else { - // dapply - (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => { - inputIterator - .asInstanceOf[Iterator[Iterator[InternalRow]]].next() - } - } - } - - protected override def writeData( - dataOut: DataOutputStream, - printOut: PrintStream, - inputIterator: Iterator[_]): Unit = if (inputIterator.hasNext) { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - "stdout writer for R", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) + protected def bufferedWrite( + dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = { val out = new ByteArrayOutputStream() - val keys = collection.mutable.ArrayBuffer.empty[Array[Byte]] - - Utils.tryWithSafeFinally { - val arrowWriter = ArrowWriter.create(root) - val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) - writer.start() - - while (inputIterator.hasNext) { - val nextBatch: Iterator[InternalRow] = getNextBatch(inputIterator, keys) - - while (nextBatch.hasNext) { - arrowWriter.write(nextBatch.next()) - } - - arrowWriter.finish() - writer.writeBatch() - arrowWriter.reset() - } - writer.end() - } { - // Don't close root and allocator in TaskCompletionListener to prevent - // a race condition. See `ArrowPythonRunner`. - root.close() - allocator.close() - } + writeFunc(out) // Currently, there looks no way to read batch by batch by socket connection in R side, // See ARROW-4512. Therefore, it writes the whole Arrow streaming-formatted binary at @@ -119,13 +69,57 @@ class ArrowRRunner( val data = out.toByteArray dataOut.writeInt(data.length) dataOut.write(data) + } - keys.foreach(dataOut.write) + protected override def newWriterThread( + output: OutputStream, + inputIterator: Iterator[Iterator[InternalRow]], + partitionIndex: Int): WriterThread = { + new WriterThread(output, inputIterator, partitionIndex) { + + /** + * Writes input data to the stream connected to the R worker. 
+ */ + override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + if (inputIterator.hasNext) { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + "stdout writer for R", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + bufferedWrite(dataOut) { out => + Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) + writer.start() + + while (inputIterator.hasNext) { + val nextBatch: Iterator[InternalRow] = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + writer.end() + } { + // Don't close root and allocator in TaskCompletionListener to prevent + // a race condition. See `ArrowPythonRunner`. + root.close() + allocator.close() + } + } + } + } + } } protected override def newReaderIterator( - dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { + dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = { + new ReaderIterator(dataStream, errThread) { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( "stdin reader for R", 0, Long.MaxValue) @@ -141,29 +135,8 @@ class ArrowRRunner( } private var batchLoaded = true - private var nextObj: ColumnarBatch = _ - private var eos = false - - override def hasNext: Boolean = nextObj != null || { - if (!eos) { - nextObj = read() - hasNext - } else { - false - } - } - - override def next(): ColumnarBatch = { - if (hasNext) { - val obj = nextObj - nextObj = null.asInstanceOf[ColumnarBatch] - obj - } else { - Iterator.empty.next() - } - } - private def read(): ColumnarBatch = try { + protected override def read(): ColumnarBatch = try { if (reader != null && batchLoaded) { batchLoaded = reader.loadNextBatch() if (batchLoaded) { @@ -173,8 +146,8 @@ class ArrowRRunner( } else { reader.close(false) allocator.close() - eos = true - null + // Should read timing data after this. + read() } } else { dataStream.readInt() match { @@ -202,7 +175,9 @@ class ArrowRRunner( // Likewise, there looks no way to send each batch in streaming format via socket // connection. See ARROW-4512. // So, it reads the whole Arrow streaming-formatted binary at once for now. 
- val in = new ByteArrayReadableSeekableByteChannel(readByteArrayData(length)) + val buffer = new Array[Byte](length) + dataStream.readFully(buffer) + val in = new ByteArrayReadableSeekableByteChannel(buffer) reader = new ArrowStreamReader(in, allocator) root = reader.getVectorSchemaRoot vectors = root.getFieldVectors.asScala.map { vector => @@ -210,6 +185,7 @@ class ArrowRRunner( }.toArray[ColumnVector] read() case length if length == 0 => + // End of stream eos = true null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala index a62016d..a3a4088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala @@ -51,7 +51,7 @@ case class MapPartitionsRWrapper( SerializationFormats.BYTE } - val runner = new RRunner[Array[Byte]]( + val runner = new RRunner[Any, Array[Byte]]( func, deserializer, serializer, packageNames, broadcastVars, isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY) // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
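For illustration, the commit message above describes the writer-thread / reader-iterator split that `BaseRRunner` factors out. The following is a minimal, self-contained Scala sketch of that pattern only; it uses an in-process pipe where the real runner uses a pair of sockets to an R worker, and the names (`BaseRunnerSketch`, `DoublingRunnerSketch`, `RunnerSketchDemo`) are illustrative, not the actual Spark API.

```scala
import java.io.{DataInputStream, DataOutputStream, OutputStream, PipedInputStream, PipedOutputStream}

// Abstract base: subclasses plug in how input is written and how output is read,
// mirroring the newWriterThread/newReaderIterator hooks described in the commit message.
abstract class BaseRunnerSketch[IN, OUT] {
  protected def newWriterThread(out: OutputStream, iter: Iterator[IN]): Thread
  protected def newReaderIterator(in: DataInputStream): Iterator[OUT]

  def compute(input: Iterator[IN]): Iterator[OUT] = {
    // The real runner talks to an R worker over sockets; a pipe stands in for them here.
    val sink = new PipedOutputStream()
    val source = new DataInputStream(new PipedInputStream(sink))
    newWriterThread(sink, input).start() // producer: writes the input iterator
    newReaderIterator(source)            // consumer: lazily reads results back
  }
}

// A concrete runner: "processes" each Int by doubling it in the writer thread.
class DoublingRunnerSketch extends BaseRunnerSketch[Int, Int] {

  override protected def newWriterThread(out: OutputStream, iter: Iterator[Int]): Thread =
    new Thread("writer sketch") {
      override def run(): Unit = {
        val dataOut = new DataOutputStream(out)
        iter.foreach(i => dataOut.writeInt(i * 2)) // pretend the worker computed this
        dataOut.writeInt(-1)                       // end-of-stream marker
        dataOut.close()
      }
    }

  override protected def newReaderIterator(in: DataInputStream): Iterator[Int] =
    new Iterator[Int] {
      private var nextObj: Option[Int] = None
      private var eos = false

      override def hasNext: Boolean = nextObj.isDefined || (!eos && {
        val v = in.readInt()
        if (v == -1) { eos = true; in.close(); false }
        else { nextObj = Some(v); true }
      })

      override def next(): Int = {
        if (!hasNext) Iterator.empty.next()
        val v = nextObj.get
        nextObj = None
        v
      }
    }
}

object RunnerSketchDemo {
  def main(args: Array[String]): Unit = {
    val result = new DoublingRunnerSketch().compute(Iterator(1, 2, 3)).toList
    println(result) // List(2, 4, 6)
  }
}
```

The real `RRunner` and `ArrowRRunner` fill in the same two hooks, but serialize the input for and deserialize the output from the R worker over the socket pair shown in the diff.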