Repository: spark
Updated Branches:
  refs/heads/branch-2.3 eab10f994 -> 16cd9ac52


[SPARKR] Match pyspark features in SparkR communication protocol.

(cherry picked from commit 628c7b517969c4a7ccb26ea67ab3dd61266073ca)
Signed-off-by: Marcelo Vanzin <van...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/16cd9ac5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/16cd9ac5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/16cd9ac5

Branch: refs/heads/branch-2.3
Commit: 16cd9ac5264831e061c033b26fe1173ebc88e5d1
Parents: 323dc3a
Author: Marcelo Vanzin <van...@cloudera.com>
Authored: Tue Apr 17 13:29:43 2018 -0700
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Thu May 10 10:47:37 2018 -0700

----------------------------------------------------------------------
 R/pkg/R/client.R                                |  4 +-
 R/pkg/R/deserialize.R                           | 10 ++--
 R/pkg/R/sparkR.R                                | 39 ++++++++++++--
 R/pkg/inst/worker/daemon.R                      |  4 +-
 R/pkg/inst/worker/worker.R                      |  5 +-
 .../org/apache/spark/api/r/RAuthHelper.scala    | 38 ++++++++++++++
 .../scala/org/apache/spark/api/r/RBackend.scala | 43 ++++++++++++---
 .../spark/api/r/RBackendAuthHandler.scala       | 55 ++++++++++++++++++++
 .../scala/org/apache/spark/api/r/RRunner.scala  | 35 +++++++++----
 .../scala/org/apache/spark/deploy/RRunner.scala |  6 ++-
 10 files changed, 210 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/R/pkg/R/client.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 9d82814..7244cc9 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -19,7 +19,7 @@
 
 # Creates a SparkR client connection object
 # if one doesn't already exist
-connectBackend <- function(hostname, port, timeout) {
+connectBackend <- function(hostname, port, timeout, authSecret) {
   if (exists(".sparkRcon", envir = .sparkREnv)) {
     if (isOpen(.sparkREnv[[".sparkRCon"]])) {
       cat("SparkRBackend client connection already exists\n")
@@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {
 
   con <- socketConnection(host = hostname, port = port, server = FALSE,
                           blocking = TRUE, open = "wb", timeout = timeout)
-
+  doServerAuth(con, authSecret)
   assign(".sparkRCon", con, envir = .sparkREnv)
   con
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/R/pkg/R/deserialize.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index a90f7d3..cb03f16 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
     stop(paste("Unsupported type for deserialization", type)))
 }
 
-readString <- function(con) {
-  stringLen <- readInt(con)
-  raw <- readBin(con, raw(), stringLen, endian = "big")
+readStringData <- function(con, len) {
+  raw <- readBin(con, raw(), len, endian = "big")
   string <- rawToChar(raw)
   Encoding(string) <- "UTF-8"
   string
 }
 
+readString <- function(con) {
+  stringLen <- readInt(con)
+  readStringData(con, stringLen)
+}
+
 readInt <- function(con) {
   readBin(con, integer(), n = 1, endian = "big")
 }
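
For reference, readString()/readStringData() consume the framing SparkR has always used for strings: a 4-byte big-endian length followed by that many bytes of UTF-8 text. A minimal Scala sketch of a compatible writer (illustrative only; the actual JVM-side counterpart is SerDe.writeString, which is not part of this diff and may differ in detail):

    import java.io.DataOutputStream
    import java.nio.charset.StandardCharsets.UTF_8

    object LengthPrefixedStringSketch {
      // DataOutputStream.writeInt is big-endian, matching endian = "big" in the R reader.
      def write(out: DataOutputStream, value: String): Unit = {
        val bytes = value.getBytes(UTF_8)
        out.writeInt(bytes.length)
        out.write(bytes)
      }
    }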

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/R/pkg/R/sparkR.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 965471f..7430d84 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -161,6 +161,10 @@ sparkR.sparkContext <- function(
                     " please use the --packages commandline instead", sep = 
","))
     }
     backendPort <- existingPort
+    authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
+    if (nchar(authSecret) == 0) {
+      stop("Auth secret not provided in environment.")
+    }
   } else {
     path <- tempfile(pattern = "backend_port")
     submitOps <- getClientModeSparkSubmitOpts(
@@ -189,16 +193,27 @@ sparkR.sparkContext <- function(
     monitorPort <- readInt(f)
     rLibPath <- readString(f)
     connectionTimeout <- readInt(f)
+
+    # Don't use readString() so that we can provide a useful
+    # error message if the R and Java versions are mismatched.
+    authSecretLen = readInt(f)
+    if (length(authSecretLen) == 0 || authSecretLen == 0) {
+      stop("Unexpected EOF in JVM connection data. Mismatched versions?")
+    }
+    authSecret <- readStringData(f, authSecretLen)
     close(f)
     file.remove(path)
     if (length(backendPort) == 0 || backendPort == 0 ||
         length(monitorPort) == 0 || monitorPort == 0 ||
-        length(rLibPath) != 1) {
+        length(rLibPath) != 1 || length(authSecret) == 0) {
       stop("JVM failed to launch")
     }
-    assign(".monitorConn",
-           socketConnection(port = monitorPort, timeout = connectionTimeout),
-           envir = .sparkREnv)
+
+    monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
+                                    timeout = connectionTimeout, open = "wb")
+    doServerAuth(monitorConn, authSecret)
+
+    assign(".monitorConn", monitorConn, envir = .sparkREnv)
     assign(".backendLaunched", 1, envir = .sparkREnv)
     if (rLibPath != "") {
       assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -208,7 +223,7 @@ sparkR.sparkContext <- function(
 
   .sparkREnv$backendPort <- backendPort
   tryCatch({
-    connectBackend("localhost", backendPort, timeout = connectionTimeout)
+    connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
   },
   error = function(err) {
     stop("Failed to connect JVM\n")
@@ -694,3 +709,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
     NULL
   }
 }
+
+# Utility function for sending auth data over a socket and checking the server's reply.
+doServerAuth <- function(con, authSecret) {
+  if (nchar(authSecret) == 0) {
+    stop("Auth secret not provided.")
+  }
+  writeString(con, authSecret)
+  flush(con)
+  reply <- readString(con)
+  if (reply != "ok") {
+    close(con)
+    stop("Unexpected reply from server.")
+  }
+}
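
Note: doServerAuth() is the client half of the handshake: write the secret as a length-prefixed string, flush, and expect the literal reply "ok". The server half is implemented by SocketAuthHelper/RAuthHelper in the Scala changes below; the following is only a minimal sketch of what that side has to do over a plain socket, assuming the framing described above (hypothetical code, not the actual implementation):

    import java.io.{DataInputStream, DataOutputStream}
    import java.net.Socket
    import java.nio.charset.StandardCharsets.UTF_8

    object ServerAuthSketch {
      // Read the client's secret, compare it, and reply "ok" or "err".
      def authClient(sock: Socket, expectedSecret: String): Unit = {
        val in = new DataInputStream(sock.getInputStream())
        val out = new DataOutputStream(sock.getOutputStream())
        val len = in.readInt()
        val buf = new Array[Byte](len)
        in.readFully(buf)
        // R's writeString() appends a NUL terminator; drop it before comparing.
        val clientSecret = new String(buf, 0, len - 1, UTF_8)
        val ok = clientSecret == expectedSecret
        val reply = (if (ok) "ok" else "err").getBytes(UTF_8)
        out.writeInt(reply.length)
        out.write(reply)
        out.flush()
        if (!ok) {
          sock.close()
          throw new SecurityException("Auth secret mismatch.")
        }
      }
    }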

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/R/pkg/inst/worker/daemon.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index 2e31dc5..fb9db63 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-    port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+    port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)
+
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 # Waits indefinitely for a socket connecion by default.
 selectTimeout <- NULL

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/R/pkg/inst/worker/worker.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 00789d8..ba458d2 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-    port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+    port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
+
 outputCon <- socketConnection(
     port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 # read the index of the current partition inside the RDD
 partition <- SparkR:::readInt(inputCon)

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
new file mode 100644
index 0000000..ac6826a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.net.Socket
+
+import org.apache.spark.SparkConf
+import org.apache.spark.security.SocketAuthHelper
+
+private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {
+
+  override protected def readUtf8(s: Socket): String = {
+    SerDe.readString(new DataInputStream(s.getInputStream()))
+  }
+
+  override protected def writeUtf8(str: String, s: Socket): Unit = {
+    val out = s.getOutputStream()
+    SerDe.writeString(new DataOutputStream(out), str)
+    out.flush()
+  }
+
+}
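
Note: RAuthHelper only adapts SocketAuthHelper to SparkR's length-prefixed string framing; the secret value and the authClient() check come from the parent class. The usage pattern wired up in RRunner below is: create one helper, hand authHelper.secret to the R process out of band (environment variable or the connection-info file), then authenticate every socket the R process opens back. A condensed sketch of that pattern (RAuthHelper is private[spark], so code like this has to live under org.apache.spark; the object name is made up):

    package org.apache.spark.api.r

    import java.net.ServerSocket

    import org.apache.spark.SparkConf

    object RAuthHelperUsageSketch {
      def main(args: Array[String]): Unit = {
        val authHelper = new RAuthHelper(new SparkConf())
        val serverSocket = new ServerSocket(0)
        // ...launch the R process with SPARKR_WORKER_SECRET set to authHelper.secret...
        val sock = serverSocket.accept()
        // Expected to throw on a mismatched secret; the retry loop in RBackend relies on that.
        authHelper.authClient(sock)
      }
    }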

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 2d1152a..3b2e809 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataOutputStream, File, FileOutputStream, IOException}
-import java.net.{InetAddress, InetSocketAddress, ServerSocket}
+import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException}
+import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
 import java.util.concurrent.TimeUnit
 
 import io.netty.bootstrap.ServerBootstrap
@@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
 
 /**
  * Netty-based backend server that is used to communicate between R and Java.
@@ -45,7 +47,7 @@ private[spark] class RBackend {
   /** Tracks JVM objects returned to R for this RBackend instance. */
   private[r] val jvmObjectTracker = new JVMObjectTracker
 
-  def init(): Int = {
+  def init(): (Int, RAuthHelper) = {
     val conf = new SparkConf()
     val backendConnectionTimeout = conf.getInt(
       "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
@@ -53,6 +55,7 @@ private[spark] class RBackend {
       conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
     val workerGroup = bossGroup
     val handler = new RBackendHandler(this)
+    val authHelper = new RAuthHelper(conf)
 
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
@@ -71,13 +74,16 @@ private[spark] class RBackend {
             new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
           .addLast("decoder", new ByteArrayDecoder())
           .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
+          .addLast(new RBackendAuthHandler(authHelper.secret))
           .addLast("handler", handler)
       }
     })
 
     channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0))
     channelFuture.syncUninterruptibly()
-    channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+
+    val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+    (port, authHelper)
   }
 
   def run(): Unit = {
@@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging {
     val sparkRBackend = new RBackend()
     try {
       // bind to random port
-      val boundPort = sparkRBackend.init()
+      val (boundPort, authHelper) = sparkRBackend.init()
       val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
       val listenPort = serverSocket.getLocalPort()
       // Connection timeout is set by socket client. To make it configurable we will pass the
@@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging {
       dos.writeInt(listenPort)
       SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
       dos.writeInt(backendConnectionTimeout)
+      SerDe.writeString(dos, authHelper.secret)
       dos.close()
       f.renameTo(new File(path))
 
@@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging {
           val buf = new Array[Byte](1024)
           // shutdown JVM if R does not connect back in 10 seconds
           serverSocket.setSoTimeout(10000)
+
+          // Wait for the R process to connect back, ignoring any failed auth attempts. Allow
+          // a max number of connection attempts to avoid looping forever.
           try {
-            val inSocket = serverSocket.accept()
+            var remainingAttempts = 10
+            var inSocket: Socket = null
+            while (inSocket == null) {
+              inSocket = serverSocket.accept()
+              try {
+                authHelper.authClient(inSocket)
+              } catch {
+                case e: Exception =>
+                  remainingAttempts -= 1
+                  if (remainingAttempts == 0) {
+                    val msg = "Too many failed authentication attempts."
+                    logError(msg)
+                    throw new IllegalStateException(msg)
+                  }
+                  logInfo("Client connection failed authentication.")
+                  inSocket = null
+              }
+            }
+
             serverSocket.close()
+
             // wait for the end of socket, closed if R process die
             inSocket.getInputStream().read(buf)
           } finally {
+            serverSocket.close()
             sparkRBackend.close()
             System.exit(0)
           }
@@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging {
     }
     System.exit(0)
   }
+
 }
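
Note: the connection-info file handed to sparkR.sparkContext() gains one field, the auth secret, written after the connection timeout. A sketch of a reader that follows the order the R code consumes the fields in (names are descriptive, not the identifiers used in the source; it assumes JVM-written strings carry no NUL terminator, unlike R-written ones):

    import java.io.{DataInputStream, FileInputStream}
    import java.nio.charset.StandardCharsets.UTF_8

    object ConnectionInfoSketch {
      def read(path: String): (Int, Int, String, Int, String) = {
        val in = new DataInputStream(new FileInputStream(path))
        try {
          def readStr(): String = {
            val len = in.readInt()
            val buf = new Array[Byte](len)
            in.readFully(buf)
            new String(buf, UTF_8)
          }
          val backendPort = in.readInt()
          val monitorPort = in.readInt()
          val rLibPath = readStr()
          val connectionTimeout = in.readInt()
          val authSecret = readStr() // new in this change
          (backendPort, monitorPort, rLibPath, connectionTimeout, authSecret)
        } finally {
          in.close()
        }
      }
    }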

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
new file mode 100644
index 0000000..4162e4a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+
+import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * Authentication handler for connections from the R process.
+ */
+private class RBackendAuthHandler(secret: String)
+  extends SimpleChannelInboundHandler[Array[Byte]] with Logging {
+
+  override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+    // The R code adds a null terminator to serialized strings, so ignore it here.
+    val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
+    try {
+      require(secret == clientSecret, "Auth secret mismatch.")
+      ctx.pipeline().remove(this)
+      writeReply("ok", ctx.channel())
+    } catch {
+      case e: Exception =>
+        logInfo("Authentication failure.", e)
+        writeReply("err", ctx.channel())
+        ctx.close()
+    }
+  }
+
+  private def writeReply(reply: String, chan: Channel): Unit = {
+    val out = new ByteArrayOutputStream()
+    SerDe.writeString(new DataOutputStream(out), reply)
+    chan.writeAndFlush(out.toByteArray())
+  }
+
+}
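
Note: by the time channelRead0() runs, the pipeline's LengthFieldBasedFrameDecoder (see RBackend above) has already stripped the 4-byte length prefix, so msg is the secret's UTF-8 bytes plus the trailing NUL that R's writeString() adds. A sketch of building such a frame, e.g. for exercising the handler in a test (hypothetical helper, not part of this change):

    import java.io.{ByteArrayOutputStream, DataOutputStream}
    import java.nio.charset.StandardCharsets.UTF_8

    object RStringFrameSketch {
      // 4-byte big-endian length, then the UTF-8 bytes followed by a NUL terminator.
      def frame(s: String): Array[Byte] = {
        val payload = s.getBytes(UTF_8) :+ 0.toByte
        val out = new ByteArrayOutputStream()
        val dos = new DataOutputStream(out)
        dos.writeInt(payload.length)
        dos.write(payload)
        dos.flush()
        out.toByteArray()
      }
    }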

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 8811839..e7fdc39 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -74,14 +74,19 @@ private[spark] class RRunner[U](
 
     // the socket used to send out the input of task
     serverSocket.setSoTimeout(10000)
-    val inSocket = serverSocket.accept()
-    startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
-
-    // the socket used to receive the output of task
-    val outSocket = serverSocket.accept()
-    val inputStream = new BufferedInputStream(outSocket.getInputStream)
-    dataStream = new DataInputStream(inputStream)
-    serverSocket.close()
+    dataStream = try {
+      val inSocket = serverSocket.accept()
+      RRunner.authHelper.authClient(inSocket)
+      startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+
+      // the socket used to receive the output of task
+      val outSocket = serverSocket.accept()
+      RRunner.authHelper.authClient(outSocket)
+      val inputStream = new BufferedInputStream(outSocket.getInputStream)
+      new DataInputStream(inputStream)
+    } finally {
+      serverSocket.close()
+    }
 
     try {
       return new Iterator[U] {
@@ -315,6 +320,11 @@ private[r] object RRunner {
   private[this] var errThread: BufferedStreamThread = _
   private[this] var daemonChannel: DataOutputStream = _
 
+  private lazy val authHelper = {
+    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+    new RAuthHelper(conf)
+  }
+
   /**
    * Start a thread to print the process's stderr to ours
    */
@@ -349,6 +359,7 @@ private[r] object RRunner {
     pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
     pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory())
     pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE")
+    pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret)
     pb.redirectErrorStream(true)  // redirect stderr into stdout
     val proc = pb.start()
     val errThread = startStdoutThread(proc)
@@ -370,8 +381,12 @@ private[r] object RRunner {
           // the socket used to send out the input of task
           serverSocket.setSoTimeout(10000)
           val sock = serverSocket.accept()
-          daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-          serverSocket.close()
+          try {
+            authHelper.authClient(sock)
+            daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          } finally {
+            serverSocket.close()
+          }
         }
         try {
           daemonChannel.writeInt(port)

http://git-wip-us.apache.org/repos/asf/spark/blob/16cd9ac5/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index 6eb53a8..e86b362 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -68,10 +68,13 @@ object RRunner {
     // Java system properties etc.
     val sparkRBackend = new RBackend()
     @volatile var sparkRBackendPort = 0
+    @volatile var sparkRBackendSecret: String = null
     val initialized = new Semaphore(0)
     val sparkRBackendThread = new Thread("SparkR backend") {
       override def run() {
-        sparkRBackendPort = sparkRBackend.init()
+        val (port, authHelper) = sparkRBackend.init()
+        sparkRBackendPort = port
+        sparkRBackendSecret = authHelper.secret
         initialized.release()
         sparkRBackend.run()
       }
@@ -91,6 +94,7 @@ object RRunner {
         env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
         env.put("R_PROFILE_USER",
           Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator))
+        env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret)
         builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
         val process = builder.start()
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
