Repository: spark
Updated Branches:
  refs/heads/branch-1.6 486db8789 -> 68bcb9b33


[SPARK-11140][CORE] Transfer files using network lib when using NettyRpcEnv - 1.6 version.

This patch contains the same code as SPARK-11140 in master, plus some added 
code so that NettyRpcEnv still uses the HTTP file server by default. This is 
mostly to avoid conflicts when backporting patches to 1.6.
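
For readers of this backport, a minimal, illustrative sketch of opting into the
netty-based file server (it uses the `spark.rpc.useNettyFileServer` flag added
in the NettyRpcEnv changes below; the app name and file path are made up):

    import org.apache.spark.{SparkConf, SparkContext}

    // Leave the flag at its default (false) to keep the HTTP file server;
    // set it to "true" to serve files over the RPC transport instead.
    val conf = new SparkConf()
      .setAppName("netty-file-server-demo")
      .set("spark.rpc.useNettyFileServer", "true")
    val sc = new SparkContext(conf)

    // On a non-local master, addFile/addJar hand local files to
    // rpcEnv.fileServer; with the netty server the returned URIs look like
    // "spark://host:port/files/<name>" instead of "http://...".
    sc.addFile("/tmp/lookup-table.txt")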

Author: Marcelo Vanzin <van...@cloudera.com>

Closes #9947 from vanzin/SPARK-11140-branch-1.6.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/68bcb9b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/68bcb9b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/68bcb9b3

Branch: refs/heads/branch-1.6
Commit: 68bcb9b33b731af33f6c9444c8c2fc54e50b3202
Parents: 486db87
Author: Marcelo Vanzin <van...@cloudera.com>
Authored: Tue Nov 24 21:48:51 2015 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Tue Nov 24 21:48:51 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/SparkContext.scala   |   8 +-
 .../main/scala/org/apache/spark/SparkEnv.scala  |  14 --
 .../scala/org/apache/spark/rpc/RpcEnv.scala     |  46 ++++++
 .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala  |  60 +++++++-
 .../spark/rpc/netty/HttpBasedFileServer.scala   |  59 ++++++++
 .../apache/spark/rpc/netty/NettyRpcEnv.scala    | 147 +++++++++++++++++--
 .../spark/rpc/netty/NettyStreamManager.scala    |  63 ++++++++
 .../scala/org/apache/spark/util/Utils.scala     |   9 ++
 .../org/apache/spark/rpc/RpcEnvSuite.scala      |  39 ++++-
 .../spark/rpc/netty/NettyRpcHandlerSuite.scala  |  10 +-
 .../spark/launcher/AbstractCommandBuilder.java  |   2 +-
 .../network/client/TransportClientFactory.java  |   6 +-
 .../network/server/TransportChannelHandler.java |   1 +
 13 files changed, 418 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
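
For orientation, a rough end-to-end sketch of the file-serving path these
changes introduce, mirroring the new "file server" test in RpcEnvSuite (that
test runs inside the org.apache.spark package; the paths below are
illustrative):

    import java.io.File
    import org.apache.spark.{SecurityManager, SparkConf, SparkEnv}
    import org.apache.spark.deploy.SparkHadoopUtil
    import org.apache.spark.util.Utils

    // Driver side: register a local file with the RpcEnv's file server.
    // With NettyStreamManager this returns a "spark://host:port/files/<name>" URI.
    val rpcEnv = SparkEnv.get.rpcEnv
    val fileUri = rpcEnv.fileServer.addFile(new File("/tmp/data.bin"))

    // Executor side: Utils.fetchFile now understands the "spark" scheme and
    // routes the download through rpcEnv.openChannel() over the RPC transport.
    val conf = new SparkConf()
    val destDir = new File("/tmp/dest")
    Utils.fetchFile(fileUri, destDir, conf, new SecurityManager(conf),
      SparkHadoopUtil.get.conf, System.currentTimeMillis(), false)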


http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 90480e5..e19ba11 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     }
 
     val key = if (!isLocal && scheme == "file") {
-      env.httpFileServer.addFile(new File(uri.getPath))
+      env.rpcEnv.fileServer.addFile(new File(uri.getPath))
     } else {
       schemeCorrectedPath
     }
@@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
       var key = ""
       if (path.contains("\\")) {
         // For local paths with backslashes on Windows, URI throws an exception
-        key = env.httpFileServer.addJar(new File(path))
+        key = env.rpcEnv.fileServer.addJar(new File(path))
       } else {
         val uri = new URI(path)
         key = uri.getScheme match {
@@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
               // of the AM to make it show up in the current working directory.
               val fileName = new Path(uri.getPath).getName()
               try {
-                env.httpFileServer.addJar(new File(fileName))
+                env.rpcEnv.fileServer.addJar(new File(fileName))
               } catch {
                 case e: Exception =>
                   // For now just log an error but allow to go through so spark examples work.
@@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
               }
             } else {
               try {
-                env.httpFileServer.addJar(new File(uri.getPath))
+                env.rpcEnv.fileServer.addJar(new File(uri.getPath))
               } catch {
                 case exc: FileNotFoundException =>
                   logError(s"Jar not found at $path")

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 88df27f..84230e3 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -66,7 +66,6 @@ class SparkEnv (
     val blockTransferService: BlockTransferService,
     val blockManager: BlockManager,
     val securityManager: SecurityManager,
-    val httpFileServer: HttpFileServer,
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
     val memoryManager: MemoryManager,
@@ -91,7 +90,6 @@ class SparkEnv (
     if (!isStopped) {
       isStopped = true
       pythonWorkers.values.foreach(_.stop())
-      Option(httpFileServer).foreach(_.stop())
       mapOutputTracker.stop()
       shuffleManager.stop()
       broadcastManager.stop()
@@ -367,17 +365,6 @@ object SparkEnv extends Logging {
 
     val cacheManager = new CacheManager(blockManager)
 
-    val httpFileServer =
-      if (isDriver) {
-        val fileServerPort = conf.getInt("spark.fileserver.port", 0)
-        val server = new HttpFileServer(conf, securityManager, fileServerPort)
-        server.initialize()
-        conf.set("spark.fileserver.uri", server.serverUri)
-        server
-      } else {
-        null
-      }
-
     val metricsSystem = if (isDriver) {
       // Don't start metrics system right now for Driver.
       // We need to wait for the task scheduler to give us an app ID.
@@ -422,7 +409,6 @@ object SparkEnv extends Logging {
       blockTransferService,
       blockManager,
       securityManager,
-      httpFileServer,
       sparkFilesDir,
       metricsSystem,
       memoryManager,

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index a560fd1..3d7d281 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.rpc
 
+import java.io.File
+import java.nio.channels.ReadableByteChannel
+
 import scala.concurrent.Future
 
 import org.apache.spark.{SecurityManager, SparkConf}
@@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
   * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method.
    */
   def deserialize[T](deserializationAction: () => T): T
+
+  /**
+   * Return the instance of the file server used to serve files. This may be `null` if the
+   * RpcEnv is not operating in server mode.
+   */
+  def fileServer: RpcEnvFileServer
+
+  /**
+   * Open a channel to download a file from the given URI. If the URIs returned by the
+   * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to
+   * retrieve the files.
+   *
+   * @param uri URI with location of the file.
+   */
+  def openChannel(uri: String): ReadableByteChannel
+
 }
 
+/**
+ * A server used by the RpcEnv to serve files to other processes owned by the application.
+ *
+ * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or
+ * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`.
+ */
+private[spark] trait RpcEnvFileServer {
+
+  /**
+   * Adds a file to be served by this RpcEnv. This is used to serve files from the driver
+   * to executors when they're stored on the driver's local file system.
+   *
+   * @param file Local file to serve.
+   * @return A URI for the location of the file.
+   */
+  def addFile(file: File): String
+
+  /**
+   * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using
+   * `SparkContext.addJar`.
+   *
+   * @param file Local file to serve.
+   * @return A URI for the location of the file.
+   */
+  def addJar(file: File): String
+
+}
 
 private[spark] case class RpcEnvConfig(
     conf: SparkConf,

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 059a7e1..94dbec5 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.rpc.akka
 
+import java.io.File
+import java.nio.channels.ReadableByteChannel
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.concurrent.Future
@@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk}
 import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
 import akka.serialization.JavaSerializer
 
-import org.apache.spark.{SparkException, Logging, SparkConf}
+import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException}
 import org.apache.spark.rpc._
 import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
 
@@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
  * remove Akka from the dependencies.
  */
 private[spark] class AkkaRpcEnv private[akka] (
-    val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
+    val actorSystem: ActorSystem,
+    val securityManager: SecurityManager,
+    conf: SparkConf,
+    boundPort: Int)
   extends RpcEnv(conf) with Logging {
 
   private val defaultAddress: RpcAddress = {
@@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] (
    */
   private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]()
 
+  private val _fileServer = new AkkaFileServer(conf, securityManager)
+
   private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = {
     endpointToRef.put(endpoint, endpointRef)
     refToEndpoint.put(endpointRef, endpoint)
@@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] (
 
   override def shutdown(): Unit = {
     actorSystem.shutdown()
+    _fileServer.shutdown()
   }
 
   override def stop(endpoint: RpcEndpointRef): Unit = {
@@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] (
       deserializationAction()
     }
   }
+
+  override def openChannel(uri: String): ReadableByteChannel = {
+    throw new UnsupportedOperationException(
+      "AkkaRpcEnv's files should be retrieved using an HTTP client.")
+  }
+
+  override def fileServer: RpcEnvFileServer = _fileServer
+
+}
+
+private[akka] class AkkaFileServer(
+    conf: SparkConf,
+    securityManager: SecurityManager) extends RpcEnvFileServer {
+
+  @volatile private var httpFileServer: HttpFileServer = _
+
+  override def addFile(file: File): String = {
+    getFileServer().addFile(file)
+  }
+
+  override def addJar(file: File): String = {
+    getFileServer().addJar(file)
+  }
+
+  def shutdown(): Unit = {
+    if (httpFileServer != null) {
+      httpFileServer.stop()
+    }
+  }
+
+  private def getFileServer(): HttpFileServer = {
+    if (httpFileServer == null) synchronized {
+      if (httpFileServer == null) {
+        httpFileServer = startFileServer()
+      }
+    }
+    httpFileServer
+  }
+
+  private def startFileServer(): HttpFileServer = {
+    val fileServerPort = conf.getInt("spark.fileserver.port", 0)
+    val server = new HttpFileServer(conf, securityManager, fileServerPort)
+    server.initialize()
+    server
+  }
+
 }
 
 private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
      config.name, config.host, config.port, config.conf, config.securityManager)
     actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
-    new AkkaRpcEnv(actorSystem, config.conf, boundPort)
+    new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/rpc/netty/HttpBasedFileServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/HttpBasedFileServer.scala b/core/src/main/scala/org/apache/spark/rpc/netty/HttpBasedFileServer.scala
new file mode 100644
index 0000000..8a7a409
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/HttpBasedFileServer.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io.File
+
+import org.apache.spark.{HttpFileServer, SecurityManager, SparkConf}
+import org.apache.spark.rpc.RpcEnvFileServer
+
+private[netty] class HttpBasedFileServer(conf: SparkConf, securityManager: SecurityManager)
+  extends RpcEnvFileServer {
+
+  @volatile private var httpFileServer: HttpFileServer = _
+
+  override def addFile(file: File): String = {
+    getFileServer().addFile(file)
+  }
+
+  override def addJar(file: File): String = {
+    getFileServer().addJar(file)
+  }
+
+  def shutdown(): Unit = {
+    if (httpFileServer != null) {
+      httpFileServer.stop()
+    }
+  }
+
+  private def getFileServer(): HttpFileServer = {
+    if (httpFileServer == null) synchronized {
+      if (httpFileServer == null) {
+        httpFileServer = startFileServer()
+      }
+    }
+    httpFileServer
+  }
+
+  private def startFileServer(): HttpFileServer = {
+    val fileServerPort = conf.getInt("spark.fileserver.port", 0)
+    val server = new HttpFileServer(conf, securityManager, fileServerPort)
+    server.initialize()
+    server
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 3ce3598..7495f3e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -20,6 +20,7 @@ import java.io._
 import java.lang.{Boolean => JBoolean}
 import java.net.{InetSocketAddress, URI}
 import java.nio.ByteBuffer
+import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel}
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.Nullable
@@ -29,7 +30,7 @@ import scala.reflect.ClassTag
 import scala.util.{DynamicVariable, Failure, Success}
 import scala.util.control.NonFatal
 
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{Logging, HttpFileServer, SecurityManager, SparkConf}
 import org.apache.spark.network.TransportContext
 import org.apache.spark.network.client._
 import org.apache.spark.network.netty.SparkTransportConf
@@ -45,27 +46,46 @@ private[netty] class NettyRpcEnv(
     host: String,
     securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
 
-  private val transportConf = SparkTransportConf.fromSparkConf(
+  private[netty] val transportConf = SparkTransportConf.fromSparkConf(
     conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
     "rpc",
     conf.getInt("spark.rpc.io.threads", 0))
 
   private val dispatcher: Dispatcher = new Dispatcher(this)
 
+  private val streamManager = new NettyStreamManager(this)
+
+  private val _fileServer =
+    if (conf.getBoolean("spark.rpc.useNettyFileServer", false)) {
+      streamManager
+    } else {
+      new HttpBasedFileServer(conf, securityManager)
+    }
+
   private val transportContext = new TransportContext(transportConf,
-    new NettyRpcHandler(dispatcher, this))
+    new NettyRpcHandler(dispatcher, this, streamManager))
 
-  private val clientFactory = {
-    val bootstraps: java.util.List[TransportClientBootstrap] =
-      if (securityManager.isAuthenticationEnabled()) {
-        java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
-          securityManager.isSaslEncryptionEnabled()))
-      } else {
-        java.util.Collections.emptyList[TransportClientBootstrap]
-      }
-    transportContext.createClientFactory(bootstraps)
+  private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
+    if (securityManager.isAuthenticationEnabled()) {
+      java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
+        securityManager.isSaslEncryptionEnabled()))
+    } else {
+      java.util.Collections.emptyList[TransportClientBootstrap]
+    }
   }
 
+  private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
+
+  /**
+   * A separate client factory for file downloads. This avoids using the same RPC handler as
+   * the main RPC context, so that events caused by these clients are kept isolated from the
+   * main RPC traffic.
+   *
+   * It also allows for different configuration of certain properties, such as the number of
+   * connections per peer.
+   */
+  @volatile private var fileDownloadFactory: TransportClientFactory = _
+
   val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
 
   // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
@@ -292,6 +312,9 @@ private[netty] class NettyRpcEnv(
     if (clientConnectionExecutor != null) {
       clientConnectionExecutor.shutdownNow()
     }
+    if (fileDownloadFactory != null) {
+      fileDownloadFactory.close()
+    }
   }
 
   override def deserialize[T](deserializationAction: () => T): T = {
@@ -300,6 +323,96 @@ private[netty] class NettyRpcEnv(
     }
   }
 
+  override def fileServer: RpcEnvFileServer = _fileServer
+
+  override def openChannel(uri: String): ReadableByteChannel = {
+    val parsedUri = new URI(uri)
+    require(parsedUri.getHost() != null, "Host name must be defined.")
+    require(parsedUri.getPort() > 0, "Port must be defined.")
+    require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
+
+    val pipe = Pipe.open()
+    val source = new FileDownloadChannel(pipe.source())
+    try {
+      val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
+      val callback = new FileDownloadCallback(pipe.sink(), source, client)
+      client.stream(parsedUri.getPath(), callback)
+    } catch {
+      case e: Exception =>
+        pipe.sink().close()
+        source.close()
+        throw e
+    }
+
+    source
+  }
+
+  private def downloadClient(host: String, port: Int): TransportClient = {
+    if (fileDownloadFactory == null) synchronized {
+      if (fileDownloadFactory == null) {
+        val module = "files"
+        val prefix = "spark.rpc.io."
+        val clone = conf.clone()
+
+        // Copy any RPC configuration that is not overridden in the spark.files namespace.
+        conf.getAll.foreach { case (key, value) =>
+          if (key.startsWith(prefix)) {
+            val opt = key.substring(prefix.length())
+            clone.setIfMissing(s"spark.$module.io.$opt", value)
+          }
+        }
+
+        val ioThreads = clone.getInt("spark.files.io.threads", 1)
+        val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
+        val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
+        fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
+      }
+    }
+    fileDownloadFactory.createClient(host, port)
+  }
+
+  private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+
+    @volatile private var error: Throwable = _
+
+    def setError(e: Throwable): Unit = error = e
+
+    override def read(dst: ByteBuffer): Int = {
+      if (error != null) {
+        throw error
+      }
+      source.read(dst)
+    }
+
+    override def close(): Unit = source.close()
+
+    override def isOpen(): Boolean = source.isOpen()
+
+  }
+
+  private class FileDownloadCallback(
+      sink: WritableByteChannel,
+      source: FileDownloadChannel,
+      client: TransportClient) extends StreamCallback {
+
+    override def onData(streamId: String, buf: ByteBuffer): Unit = {
+      while (buf.remaining() > 0) {
+        sink.write(buf)
+      }
+    }
+
+    override def onComplete(streamId: String): Unit = {
+      sink.close()
+    }
+
+    override def onFailure(streamId: String, cause: Throwable): Unit = {
+      logError(s"Error downloading stream $streamId.", cause)
+      source.setError(cause)
+      sink.close()
+    }
+
+  }
+
 }
 
 private[netty] object NettyRpcEnv extends Logging {
@@ -420,7 +533,7 @@ private[netty] class NettyRpcEndpointRef(
 
   override def toString: String = s"NettyRpcEndpointRef(${_address})"
 
-  def toURI: URI = new URI(s"spark://${_address}")
+  def toURI: URI = new URI(_address.toString)
 
   final override def equals(that: Any): Boolean = that match {
     case other: NettyRpcEndpointRef => _address == other._address
@@ -471,7 +584,9 @@ private[netty] case class RpcFailure(e: Throwable)
  * with different `RpcAddress` information).
  */
 private[netty] class NettyRpcHandler(
-    dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+    dispatcher: Dispatcher,
+    nettyEnv: NettyRpcEnv,
+    streamManager: StreamManager) extends RpcHandler with Logging {
 
   // TODO: Can we add connection callback (channel registered) to the underlying framework?
   // A variable to track whether we should dispatch the RemoteProcessConnected message.
@@ -498,7 +613,7 @@ private[netty] class NettyRpcHandler(
     dispatcher.postRemoteMessage(messageToDispatch, callback)
   }
 
-  override def getStreamManager: StreamManager = new OneForOneStreamManager
+  override def getStreamManager: StreamManager = streamManager
 
   override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
     val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
@@ -516,8 +631,8 @@ private[netty] class NettyRpcHandler(
   override def connectionTerminated(client: TransportClient): Unit = {
     val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
-      val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
       clients.remove(client)
+      val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
       nettyEnv.removeOutbox(clientAddr)
       dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
new file mode 100644
index 0000000..eb1d260
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io.File
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.server.StreamManager
+import org.apache.spark.rpc.RpcEnvFileServer
+
+/**
+ * StreamManager implementation for serving files from a NettyRpcEnv.
+ */
+private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
+  extends StreamManager with RpcEnvFileServer {
+
+  private val files = new ConcurrentHashMap[String, File]()
+  private val jars = new ConcurrentHashMap[String, File]()
+
+  override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
+    throw new UnsupportedOperationException()
+  }
+
+  override def openStream(streamId: String): ManagedBuffer = {
+    val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
+    val file = ftype match {
+      case "files" => files.get(fname)
+      case "jars" => jars.get(fname)
+      case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype")
+    }
+
+    require(file != null, s"File not found: $streamId")
+    new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
+  }
+
+  override def addFile(file: File): String = {
+    require(files.putIfAbsent(file.getName(), file) == null,
+      s"File ${file.getName()} already registered.")
+    s"${rpcEnv.address.toSparkURL}/files/${file.getName()}"
+  }
+
+  override def addJar(file: File): String = {
+    require(jars.putIfAbsent(file.getName(), file) == null,
+      s"JAR ${file.getName()} already registered.")
+    s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}"
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1b3acb8..af63234 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,6 +21,7 @@ import java.io._
 import java.lang.management.ManagementFactory
 import java.net._
 import java.nio.ByteBuffer
+import java.nio.channels.Channels
 import java.util.concurrent._
 import java.util.{Locale, Properties, Random, UUID}
 import javax.net.ssl.HttpsURLConnection
@@ -535,6 +536,14 @@ private[spark] object Utils extends Logging {
     val uri = new URI(url)
     val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
     Option(uri.getScheme).getOrElse("file") match {
+      case "spark" =>
+        if (SparkEnv.get == null) {
+          throw new IllegalStateException(
+            "Cannot retrieve files with 'spark' scheme without an active 
SparkEnv.")
+        }
+        val source = SparkEnv.get.rpcEnv.openChannel(url)
+        val is = Channels.newInputStream(source)
+        downloadFile(url, is, targetFile, fileOverwrite)
       case "http" | "https" | "ftp" =>
         var uc: URLConnection = null
         if (securityMgr.isAuthenticationEnabled()) {

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 2f55006..2b664c6 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.rpc
 
-import java.io.NotSerializableException
+import java.io.{File, NotSerializableException}
+import java.util.UUID
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
 
 import scala.collection.mutable
@@ -25,10 +27,14 @@ import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
+import com.google.common.io.Files
+import org.mockito.Mockito.{mock, when}
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.concurrent.Eventually._
 
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.Utils
 
 /**
  * Common tests for an RpcEnv implementation.
@@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
   override def beforeAll(): Unit = {
     val conf = new SparkConf()
     env = createRpcEnv(conf, "local", 0)
+
+    val sparkEnv = mock(classOf[SparkEnv])
+    when(sparkEnv.rpcEnv).thenReturn(env)
+    SparkEnv.set(sparkEnv)
   }
 
   override def afterAll(): Unit = {
     if (env != null) {
       env.shutdown()
     }
+    SparkEnv.set(null)
   }
 
   def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
@@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
   }
 
+  test("file server") {
+    val conf = new SparkConf()
+    val tempDir = Utils.createTempDir()
+    val file = new File(tempDir, "file")
+    Files.write(UUID.randomUUID().toString(), file, UTF_8)
+    val jar = new File(tempDir, "jar")
+    Files.write(UUID.randomUUID().toString(), jar, UTF_8)
+
+    val fileUri = env.fileServer.addFile(file)
+    val jarUri = env.fileServer.addJar(jar)
+
+    val destDir = Utils.createTempDir()
+    val destFile = new File(destDir, file.getName())
+    val destJar = new File(destDir, jar.getName())
+
+    val sm = new SecurityManager(conf)
+    val hc = SparkHadoopUtil.get.conf
+    Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false)
+    Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false)
+
+    assert(Files.equal(file, destFile))
+    assert(Files.equal(jar, destJar))
+  }
+
 }
 
 class UnserializableClass

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f9d8e80..ccca795 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -25,17 +25,19 @@ import org.mockito.Matchers._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.apache.spark.network.server.StreamManager
 import org.apache.spark.rpc._
 
 class NettyRpcHandlerSuite extends SparkFunSuite {
 
   val env = mock(classOf[NettyRpcEnv])
-  when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
-    thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+  val sm = mock(classOf[StreamManager])
+  when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+    .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
 
   test("receive") {
     val dispatcher = mock(classOf[Dispatcher])
-    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
 
     val channel = mock(classOf[Channel])
     val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
@@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
 
   test("connectionTerminated") {
     val dispatcher = mock(classOf[Dispatcher])
-    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+    val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
 
     val channel = mock(classOf[Channel])
     val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
----------------------------------------------------------------------
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 3ee6bd9..55fe156 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -148,7 +148,7 @@ abstract class AbstractCommandBuilder {
       String scala = getScalaVersion();
       List<String> projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx",
         "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver",
-        "yarn", "launcher");
+        "yarn", "launcher", "network/common", "network/shuffle", "network/yarn");
       if (prependClasses) {
         if (!isTesting) {
           System.err.println(

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 659c471..61bafc8 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -170,8 +170,10 @@ public class TransportClientFactory implements Closeable {
   }
 
   /**
-   * Create a completely new {@link TransportClient} to the given remote host / port
-   * But this connection is not pooled.
+   * Create a completely new {@link TransportClient} to the given remote host / port.
+   * This connection is not pooled.
+   *
+   * As with {@link #createClient(String, int)}, this method is blocking.
    */
   public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
       throws IOException {

http://git-wip-us.apache.org/repos/asf/spark/blob/68bcb9b3/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 29d688a..3164e00 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -138,6 +138,7 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
         }
       }
     }
+    ctx.fireUserEventTriggered(evt);
   }
 
   public TransportResponseHandler getResponseHandler() {

