Repository: spark
Updated Branches:
  refs/heads/master 34e71c6d8 -> a88c66ca8


[SPARK-11098][CORE] Add Outbox to cache outgoing messages and resolve the 
message ordering issue

The current NettyRpc has a message ordering issue because it uses a thread pool 
to send messages. E.g., running the following two lines in the same thread:

```
ref.send("A")
ref.send("B")
```

The remote endpoint may see "B" before "A" because "A" and "B" are sent in 
parallel.
To resolve this issue, this PR adds an Outbox for each connection: while we are 
still connecting to the remote node, outgoing messages are cached in the Outbox 
and then sent one by one, in order, once the connection is established.

Author: zsxwing <zsxw...@gmail.com>

Closes #9197 from zsxwing/rpc-outbox.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a88c66ca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a88c66ca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a88c66ca

Branch: refs/heads/master
Commit: a88c66ca8780c7228dc909f904d31cd9464ee0e3
Parents: 34e71c6
Author: zsxwing <zsxw...@gmail.com>
Authored: Thu Oct 22 21:01:01 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Oct 22 21:01:01 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/rpc/netty/NettyRpcEnv.scala    | 145 +++++++-----
 .../org/apache/spark/rpc/netty/Outbox.scala     | 222 +++++++++++++++++++
 2 files changed, 310 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a88c66ca/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index e01cf1a..284284e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -20,6 +20,7 @@ import java.io._
 import java.net.{InetSocketAddress, URI}
 import java.nio.ByteBuffer
 import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -70,12 +71,30 @@ private[netty] class NettyRpcEnv(
  // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
   // to implement non-blocking send/ask.
   // TODO: a non-blocking TransportClientFactory.createClient in future
-  private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
+  private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
     "netty-rpc-connection",
     conf.getInt("spark.rpc.connect.threads", 64))
 
   @volatile private var server: TransportServer = _
 
+  private val stopped = new AtomicBoolean(false)
+
+  /**
+   * A map from [[RpcAddress]] to [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
+   * we just put messages into its [[Outbox]] to implement a non-blocking `send` method.
+   */
+  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
+
+  /**
+   * Remove the address's Outbox and stop it.
+   */
+  private[netty] def removeOutbox(address: RpcAddress): Unit = {
+    val outbox = outboxes.remove(address)
+    if (outbox != null) {
+      outbox.stop()
+    }
+  }
+
   def start(port: Int): Unit = {
     val bootstraps: java.util.List[TransportServerBootstrap] =
       if (securityManager.isAuthenticationEnabled()) {
@@ -116,6 +135,30 @@ private[netty] class NettyRpcEnv(
     dispatcher.stop(endpointRef)
   }
 
+  private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
+    val targetOutbox = {
+      val outbox = outboxes.get(address)
+      if (outbox == null) {
+        val newOutbox = new Outbox(this, address)
+        val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
+        if (oldOutbox == null) {
+          newOutbox
+        } else {
+          oldOutbox
+        }
+      } else {
+        outbox
+      }
+    }
+    if (stopped.get) {
+      // It's possible that we put `targetOutbox` after stopping. So we need to clean it up.
+      outboxes.remove(address)
+      targetOutbox.stop()
+    } else {
+      targetOutbox.send(message)
+    }
+  }
+
   private[netty] def send(message: RequestMessage): Unit = {
     val remoteAddr = message.receiver.address
     if (remoteAddr == address) {
@@ -127,37 +170,28 @@ private[netty] class NettyRpcEnv(
           val ack = response.asInstanceOf[Ack]
           logTrace(s"Received ack from ${ack.sender}")
         case Failure(e) =>
-          logError(s"Exception when sending $message", e)
+          logWarning(s"Exception when sending $message", e)
       }(ThreadUtils.sameThread)
     } else {
       // Message to a remote RPC endpoint.
-      try {
-        // `createClient` will block if it cannot find a known connection, so we should run it in
-        // clientConnectionExecutor
-        clientConnectionExecutor.execute(new Runnable {
-          override def run(): Unit = Utils.tryLogNonFatalError {
-            val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
-            client.sendRpc(serialize(message), new RpcResponseCallback {
-
-              override def onFailure(e: Throwable): Unit = {
-                logError(s"Exception when sending $message", e)
-              }
-
-              override def onSuccess(response: Array[Byte]): Unit = {
-                val ack = deserialize[Ack](response)
-                logDebug(s"Receive ack from ${ack.sender}")
-              }
-            })
-          }
-        })
-      } catch {
-        case e: RejectedExecutionException =>
-          // `send` after shutting clientConnectionExecutor down, ignore it
-          logWarning(s"Cannot send $message because RpcEnv is stopped")
-      }
+      postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
+
+        override def onFailure(e: Throwable): Unit = {
+          logWarning(s"Exception when sending $message", e)
+        }
+
+        override def onSuccess(response: Array[Byte]): Unit = {
+          val ack = deserialize[Ack](response)
+          logDebug(s"Receive ack from ${ack.sender}")
+        }
+      }))
     }
   }
 
+  private[netty] def createClient(address: RpcAddress): TransportClient = {
+    clientFactory.createClient(address.host, address.port)
+  }
+
   private[netty] def ask(message: RequestMessage): Future[Any] = {
     val promise = Promise[Any]()
     val remoteAddr = message.receiver.address
@@ -180,39 +214,25 @@ private[netty] class NettyRpcEnv(
           }
       }(ThreadUtils.sameThread)
     } else {
-      try {
-        // `createClient` will block if it cannot find a known connection, so we should run it in
-        // clientConnectionExecutor
-        clientConnectionExecutor.execute(new Runnable {
-          override def run(): Unit = {
-            val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
-            client.sendRpc(serialize(message), new RpcResponseCallback {
-
-              override def onFailure(e: Throwable): Unit = {
-                if (!promise.tryFailure(e)) {
-                  logWarning("Ignore Exception", e)
-                }
-              }
-
-              override def onSuccess(response: Array[Byte]): Unit = {
-                val reply = deserialize[AskResponse](response)
-                if (reply.reply.isInstanceOf[RpcFailure]) {
-                  if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
-                    logWarning(s"Ignore failure: ${reply.reply}")
-                  }
-                } else if (!promise.trySuccess(reply.reply)) {
-                  logWarning(s"Ignore message: ${reply}")
-                }
-              }
-            })
-          }
-        })
-      } catch {
-        case e: RejectedExecutionException =>
+      postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
+
+        override def onFailure(e: Throwable): Unit = {
           if (!promise.tryFailure(e)) {
-            logWarning(s"Ignore failure", e)
+            logWarning("Ignore Exception", e)
           }
-      }
+        }
+
+        override def onSuccess(response: Array[Byte]): Unit = {
+          val reply = deserialize[AskResponse](response)
+          if (reply.reply.isInstanceOf[RpcFailure]) {
+            if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
+              logWarning(s"Ignore failure: ${reply.reply}")
+            }
+          } else if (!promise.trySuccess(reply.reply)) {
+            logWarning(s"Ignore message: ${reply}")
+          }
+        }
+      }))
     }
     promise.future
   }
@@ -245,6 +265,16 @@ private[netty] class NettyRpcEnv(
   }
 
   private def cleanup(): Unit = {
+    if (!stopped.compareAndSet(false, true)) {
+      return
+    }
+
+    val iter = outboxes.values().iterator()
+    while (iter.hasNext()) {
+      val outbox = iter.next()
+      outboxes.remove(outbox.address)
+      outbox.stop()
+    }
     if (timeoutScheduler != null) {
       timeoutScheduler.shutdownNow()
     }
@@ -463,6 +493,7 @@ private[netty] class NettyRpcHandler(
     val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+      nettyEnv.removeOutbox(clientAddr)
       val messageOpt: Option[RemoteProcessDisconnected] =
       synchronized {
         remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>

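(Aside: the `stopped` re-check in `postToOutbox` above guards a shutdown race,
since a sender may insert a new Outbox into the map concurrently with
`cleanup()` draining it. A hedged, self-contained sketch of the pattern, with
hypothetical names:)

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

// cleanup() may finish draining `outboxes` between our putIfAbsent and the
// send. Re-checking `stopped` afterwards ensures that a late insertion is
// also removed and stopped, whichever thread wins the race.
object StoppedRaceSketch {
  trait Box { def stop(): Unit; def send(msg: String): Unit }

  val stopped = new AtomicBoolean(false)
  val outboxes = new ConcurrentHashMap[String, Box]()

  def postTo(address: String, newBox: Box, msg: String): Unit = {
    val existing = outboxes.putIfAbsent(address, newBox)
    val target = if (existing == null) newBox else existing
    if (stopped.get) {
      outboxes.remove(address)  // our insert may have landed after cleanup()
      target.stop()
    } else {
      target.send(msg)
    }
  }
}
```
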
http://git-wip-us.apache.org/repos/asf/spark/blob/a88c66ca/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
new file mode 100644
index 0000000..7d9d593
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent.Callable
+import javax.annotation.concurrent.GuardedBy
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkException
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.rpc.RpcAddress
+
+private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback)
+
+private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
+
+  outbox => // Give this an alias so we can use it more clearly in closures.
+
+  @GuardedBy("this")
+  private val messages = new java.util.LinkedList[OutboxMessage]
+
+  @GuardedBy("this")
+  private var client: TransportClient = null
+
+  /**
+   * connectFuture points to the connect task. If there is no connect task, connectFuture will be
+   * null.
+   */
+  @GuardedBy("this")
+  private var connectFuture: java.util.concurrent.Future[Unit] = null
+
+  @GuardedBy("this")
+  private var stopped = false
+
+  /**
+   * Whether any thread is currently draining the message queue.
+   */
+  @GuardedBy("this")
+  private var draining = false
+
+  /**
+   * Send a message. If there is no active connection, cache it and launch a new connection. If
+   * [[Outbox]] is stopped, the sender will be notified with a [[SparkException]].
+   */
+  def send(message: OutboxMessage): Unit = {
+    val dropped = synchronized {
+      if (stopped) {
+        true
+      } else {
+        messages.add(message)
+        false
+      }
+    }
+    if (dropped) {
+      message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+    } else {
+      drainOutbox()
+    }
+  }
+
+  /**
+   * Drain the message queue. If another thread is already draining, just exit. If the connection
+   * has not been established, launch a task in the `nettyEnv.clientConnectionExecutor` to set up
+   * the connection.
+   */
+  private def drainOutbox(): Unit = {
+    var message: OutboxMessage = null
+    synchronized {
+      if (stopped) {
+        return
+      }
+      if (connectFuture != null) {
+        // We are connecting to the remote address, so just exit
+        return
+      }
+      if (client == null) {
+        // There is no connect task and the client is null, so we need to launch the connect task.
+        launchConnectTask()
+        return
+      }
+      if (draining) {
+        // There is some thread draining, so just exit
+        return
+      }
+      message = messages.poll()
+      if (message == null) {
+        return
+      }
+      draining = true
+    }
+    while (true) {
+      try {
+        val _client = synchronized { client }
+        if (_client != null) {
+          _client.sendRpc(message.content, message.callback)
+        } else {
+          assert(stopped)
+        }
+      } catch {
+        case NonFatal(e) =>
+          handleNetworkFailure(e)
+          return
+      }
+      synchronized {
+        if (stopped) {
+          return
+        }
+        message = messages.poll()
+        if (message == null) {
+          draining = false
+          return
+        }
+      }
+    }
+  }
+
+  private def launchConnectTask(): Unit = {
+    connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {
+
+      override def call(): Unit = {
+        try {
+          val _client = nettyEnv.createClient(address)
+          outbox.synchronized {
+            client = _client
+            if (stopped) {
+              closeClient()
+            }
+          }
+        } catch {
+          case ie: InterruptedException =>
+            // exit
+            return
+          case NonFatal(e) =>
+            outbox.synchronized { connectFuture = null }
+            handleNetworkFailure(e)
+            return
+        }
+        outbox.synchronized { connectFuture = null }
+        // It's possible that no thread is draining now. If we don't drain here, we cannot send the
+        // messages until the next message arrives.
+        drainOutbox()
+      }
+    })
+  }
+
+  /**
+   * Stop [[Outbox]] and notify the waiting messages with the cause.
+   */
+  private def handleNetworkFailure(e: Throwable): Unit = {
+    synchronized {
+      assert(connectFuture == null)
+      if (stopped) {
+        return
+      }
+      stopped = true
+      closeClient()
+    }
+    // Remove this Outbox from nettyEnv so that further messages will create a new Outbox along
+    // with a new connection
+    nettyEnv.removeOutbox(address)
+
+    // Notify the connection failure for the remaining messages
+    //
+    // We always check `stopped` before updating messages, so here we can make sure no thread will
+    // update messages and it's safe to just drain the queue.
+    var message = messages.poll()
+    while (message != null) {
+      message.callback.onFailure(e)
+      message = messages.poll()
+    }
+    assert(messages.isEmpty)
+  }
+
+  private def closeClient(): Unit = synchronized {
+    // Not sure if `client.close` is idempotent. Just for safety.
+    if (client != null) {
+      client.close()
+    }
+    client = null
+  }
+
+  /**
+   * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a
+   * [[SparkException]].
+   */
+  def stop(): Unit = {
+    synchronized {
+      if (stopped) {
+        return
+      }
+      stopped = true
+      if (connectFuture != null) {
+        connectFuture.cancel(true)
+      }
+      closeClient()
+    }
+
+    // We always check `stopped` before updating messages, so here we can make sure no thread will
+    // update messages and it's safe to just drain the queue.
+    var message = messages.poll()
+    while (message != null) {
+      message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+      message = messages.poll()
+    }
+  }
+}
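
For readers following the `draining` flag in `drainOutbox`: a condensed,
hedged sketch of the single-drainer handoff (simplified; `transmit` is a
hypothetical stand-in for `client.sendRpc`). Threads that find the flag set
leave immediately, and the active drainer clears it only when the queue is
observed empty under the lock, so no message is stranded between "queue empty"
and "flag cleared".

```scala
// Single-drainer sketch: at most one thread transmits at a time; the others
// only enqueue. Transmission happens outside the lock, as in drainOutbox.
class DrainSketch(transmit: String => Unit) {
  private val queue = new java.util.LinkedList[String]
  private var draining = false

  def post(m: String): Unit = {
    synchronized { queue.add(m) }
    drain()
  }

  private def drain(): Unit = {
    var msg: String = null
    synchronized {
      if (!draining) {                  // otherwise another thread is draining
        msg = queue.poll()
        if (msg != null) draining = true
      }
    }
    while (msg != null) {
      transmit(msg)                     // outside the lock, like client.sendRpc
      synchronized {
        msg = queue.poll()
        if (msg == null) draining = false
      }
    }
  }
}
```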

