http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala b/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala new file mode 100644 index 0000000..d482c33 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/sessions/SessionManager.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.livy.sessions

import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.livy.{LivyConf, Logging}
import org.apache.livy.server.batch.{BatchRecoveryMetadata, BatchSession}
import org.apache.livy.server.interactive.{InteractiveRecoveryMetadata, InteractiveSession, SessionHeartbeatWatchdog}
import org.apache.livy.server.recovery.SessionStore
import org.apache.livy.sessions.Session.RecoveryMetadata

/** Values compared against `LivyConf.RECOVERY_MODE`. */
object SessionManager {
  val SESSION_RECOVERY_MODE_OFF = "off"
  val SESSION_RECOVERY_MODE_RECOVERY = "recovery"
}

/** Session manager for batch sessions; recovery is delegated to `BatchSession.recover`. */
class BatchSessionManager(
    livyConf: LivyConf,
    sessionStore: SessionStore,
    mockSessions: Option[Seq[BatchSession]] = None)
  extends SessionManager[BatchSession, BatchRecoveryMetadata] (
    livyConf, BatchSession.recover(_, livyConf, sessionStore), sessionStore, "batch", mockSessions)

/**
 * Session manager for interactive sessions; recovery is delegated to
 * `InteractiveSession.recover`. `start()` is invoked from the constructor body
 * (NOTE(review): presumably provided by SessionHeartbeatWatchdog -- not visible here).
 */
class InteractiveSessionManager(
  livyConf: LivyConf,
  sessionStore: SessionStore,
  mockSessions: Option[Seq[InteractiveSession]] = None)
  extends SessionManager[InteractiveSession, InteractiveRecoveryMetadata] (
    livyConf,
    InteractiveSession.recover(_, livyConf, sessionStore),
    sessionStore,
    "interactive",
    mockSessions)
  with SessionHeartbeatWatchdog[InteractiveSession, InteractiveRecoveryMetadata]
  {
    start()
  }

/**
 * In-memory registry of sessions keyed by session id.
 *
 * On construction the manager registers either the supplied `mockSessions` or the sessions
 * recovered from `sessionStore`, and starts a daemon thread that calls [[collectGarbage]]
 * every 60 seconds.
 *
 * All accesses to the underlying mutable map go through `this.synchronized`, because the map
 * is mutated from request threads, from the garbage-collector thread and from the futures
 * returned by [[delete(session:S)*]].
 *
 * @param livyConf        configuration used for timeouts and the recovery mode
 * @param sessionRecovery function turning recovery metadata into a live session
 * @param sessionStore    store for the next session id and per-session metadata
 * @param sessionType     key under which this manager's data lives in the store
 * @param mockSessions    if defined, these sessions are registered instead of recovering
 */
class SessionManager[S <: Session, R <: RecoveryMetadata : ClassTag](
    protected val livyConf: LivyConf,
    sessionRecovery: R => S,
    sessionStore: SessionStore,
    sessionType: String,
    mockSessions: Option[Seq[S]] = None)
  extends Logging {

  import SessionManager._

  protected implicit def executor: ExecutionContext = ExecutionContext.global

  protected[this] final val idCounter = new AtomicInteger(0)
  protected[this] final val sessions = mutable.LinkedHashMap[Int, S]()

  private[this] final val sessionTimeoutCheck = livyConf.getBoolean(LivyConf.SESSION_TIMEOUT_CHECK)
  private[this] final val sessionTimeout =
    TimeUnit.MILLISECONDS.toNanos(livyConf.getTimeAsMs(LivyConf.SESSION_TIMEOUT))
  // Despite the name, this value is held in nanoseconds (see the toNanos conversion).
  private[this] final val sessionStateRetainedInSec =
    TimeUnit.MILLISECONDS.toNanos(livyConf.getTimeAsMs(LivyConf.SESSION_STATE_RETAIN_TIME))

  mockSessions.getOrElse(recover()).foreach(register)
  new GarbageCollector().start()

  /** Returns a fresh session id and persists the following id to the session store. */
  def nextId(): Int = synchronized {
    val id = idCounter.getAndIncrement()
    sessionStore.saveNextSessionId(sessionType, idCounter.get())
    id
  }

  /** Adds `session` to the registry under its id and returns it. */
  def register(session: S): S = {
    info(s"Registering new session ${session.id}")
    synchronized {
      sessions.put(session.id, session)
    }
    session
  }

  /** Looks up a session by id. */
  def get(id: Int): Option[S] = synchronized { sessions.get(id) }

  /** Number of registered sessions. */
  def size(): Int = synchronized { sessions.size }

  /**
   * Snapshot of the registered sessions in insertion order. A copy is returned so callers can
   * iterate safely while other threads register or delete sessions.
   */
  def all(): Iterable[S] = synchronized { sessions.values.toList }

  /** Deletes the session with the given id, if registered. */
  def delete(id: Int): Option[Future[Unit]] = {
    get(id).map(delete)
  }

  /**
   * Stops `session`; once stopped, removes it from the session store and from the registry.
   * A non-fatal failure during removal is logged and rethrown, failing the returned future.
   */
  def delete(session: S): Future[Unit] = {
    session.stop().map { case _ =>
      try {
        sessionStore.remove(sessionType, session.id)
        synchronized {
          sessions.remove(session.id)
        }
      } catch {
        case NonFatal(e) =>
          error("Exception was thrown during stop session:", e)
          throw e
      }
    }
  }

  /** When recovery is off, stops every registered session and blocks until all have stopped. */
  def shutdown(): Unit = {
    val recoveryEnabled = livyConf.get(LivyConf.RECOVERY_MODE) != SESSION_RECOVERY_MODE_OFF
    if (!recoveryEnabled) {
      all().map(_.stop()).foreach { future =>
        Await.ready(future, Duration.Inf)
      }
    }
  }

  /**
   * Deletes every expired session. A session in a finished state expires once the state
   * retention time has elapsed; any other session expires only when the timeout check is
   * enabled, it is not a batch session, and it has been inactive longer than the timeout.
   */
  def collectGarbage(): Future[Iterable[Unit]] = {
    def expired(session: Session): Boolean = {
      session.state match {
        case s: FinishedSessionState =>
          System.nanoTime() - s.time > sessionStateRetainedInSec
        case _ =>
          sessionTimeoutCheck &&
            !session.isInstanceOf[BatchSession] &&
            System.nanoTime() - session.lastActivity > sessionTimeout
      }
    }

    Future.sequence(all().filter(expired).map(delete))
  }

  private def recover(): Seq[S] = {
    // Recover next session id from state store and create SessionManager.
    idCounter.set(sessionStore.getNextSessionId(sessionType))

    // Retrieve session recovery metadata from state store.
    val sessionMetadata = sessionStore.getAllSessions[R](sessionType)

    // Recover session from session recovery metadata.
    val recoveredSessions = sessionMetadata.flatMap(_.toOption).map(sessionRecovery)

    info(s"Recovered ${recoveredSessions.length} $sessionType sessions." +
      s" Next session id: $idCounter")

    // Print recovery error.
    val recoveryFailure = sessionMetadata.filter(_.isFailure).map(_.failed.get)
    recoveryFailure.foreach(ex => error(ex.getMessage, ex.getCause))

    recoveredSessions
  }

  private class GarbageCollector extends Thread("session gc thread") {

    setDaemon(true)

    override def run(): Unit = {
      while (true) {
        // A failing pass must not terminate the collector; log it and retry on the next tick.
        try {
          collectGarbage()
        } catch {
          case NonFatal(e) =>
            error("Exception was thrown during session garbage collection:", e)
        }
        Thread.sleep(60 * 1000)
      }
    }

  }

}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/Clock.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/Clock.scala b/server/src/main/scala/org/apache/livy/utils/Clock.scala new file mode 100644 index 0000000..8b396c7 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/Clock.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */
package org.apache.livy.utils

/**
 * A lot of Livy code relies on time related functions like Thread.sleep.
 * To remove timing effects from unit tests, this object lets the sleep function be mocked out.
 *
 * Code in Livy should not call Thread.sleep() directly. It should call this object instead.
 */
object Clock {
  // Volatile: the function is swapped by one thread and read through `sleep` by others.
  @volatile private var _sleep: Long => Unit = Thread.sleep

  /**
   * Runs `f` with `mockSleep` installed as the sleep function, then restores whichever sleep
   * function was installed before the call, so nested invocations unwind correctly.
   *
   * @param mockSleep replacement for the sleep function while `f` runs
   * @param f         the code to run against the mocked sleep
   */
  def withSleepMethod(mockSleep: Long => Unit)(f: => Unit): Unit = {
    val previous = _sleep
    try {
      _sleep = mockSleep
      f
    } finally {
      _sleep = previous
    }
  }

  /** The currently installed sleep function (`Thread.sleep` unless mocked). */
  def sleep: Long => Unit = _sleep
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/LineBufferedProcess.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/LineBufferedProcess.scala b/server/src/main/scala/org/apache/livy/utils/LineBufferedProcess.scala new file mode 100644 index 0000000..863f9a6 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/LineBufferedProcess.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.livy.utils

import org.apache.livy.{Logging, Utils}

/**
 * Wraps a `Process` and exposes its standard output and standard error through
 * [[LineBufferedStream]]s.
 *
 * @param process the process whose output and error streams are buffered
 */
class LineBufferedProcess(process: Process) extends Logging {

  private[this] val stdoutBuffer = new LineBufferedStream(process.getInputStream)
  private[this] val stderrBuffer = new LineBufferedStream(process.getErrorStream)

  /** Lines buffered so far from the process's standard output. */
  def inputLines: IndexedSeq[String] = stdoutBuffer.lines

  /** Lines buffered so far from the process's standard error. */
  def errorLines: IndexedSeq[String] = stderrBuffer.lines

  /** A new iterator over the standard output lines. */
  def inputIterator: Iterator[String] = stdoutBuffer.iterator

  /** A new iterator over the standard error lines. */
  def errorIterator: Iterator[String] = stderrBuffer.iterator

  /** Destroys the wrapped process. */
  def destroy(): Unit = process.destroy()

  /** Returns if the process is still actively running. */
  def isAlive: Boolean = Utils.isProcessAlive(process)

  /** Exit value of the wrapped process. */
  def exitValue(): Int = process.exitValue()

  /**
   * Waits for the process to exit, then for the output and error buffers (in that order) to be
   * closed, and returns the process's exit code.
   */
  def waitFor(): Int = {
    val exitCode = process.waitFor()
    Seq(stdoutBuffer, stderrBuffer).foreach(_.waitUntilClose())
    exitCode
  }
}

http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/LineBufferedStream.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/LineBufferedStream.scala b/server/src/main/scala/org/apache/livy/utils/LineBufferedStream.scala new file mode 100644 index 0000000..fb076e1 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/LineBufferedStream.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.livy.utils

import java.io.InputStream
import java.util.concurrent.locks.ReentrantLock

import scala.io.Source

import org.apache.livy.Logging

/**
 * Reads `inputStream` line by line on a daemon thread and accumulates the lines in memory.
 *
 * Accumulated lines are available as a snapshot through [[lines]] or incrementally through
 * [[iterator]], whose `hasNext` blocks until another line arrives or the stream is exhausted.
 *
 * @param inputStream the stream to consume; reading starts as soon as the instance is created
 */
class LineBufferedStream(inputStream: InputStream) extends Logging {

  // Volatile because `lines` and the iterator's fast path read it without holding `_lock`.
  @volatile private[this] var _lines: IndexedSeq[String] = IndexedSeq()

  private[this] val _lock = new ReentrantLock()
  private[this] val _condition = _lock.newCondition()
  // Only read and written while holding `_lock`.
  private[this] var _finished = false

  private val thread = new Thread {
    override def run() = {
      val lines = Source.fromInputStream(inputStream).getLines()
      for (line <- lines) {
        _lock.lock()
        try {
          _lines = _lines :+ line
          _condition.signalAll()
        } finally {
          _lock.unlock()
        }
      }

      // Log every captured line as a single message once the stream is exhausted.
      _lines.foreach { line => info(s"stdout: $line") }
      _lock.lock()
      try {
        _finished = true
        _condition.signalAll()
      } finally {
        _lock.unlock()
      }
    }
  }
  thread.setDaemon(true)
  thread.start()

  /** Snapshot of the lines read so far. */
  def lines: IndexedSeq[String] = _lines

  /** A new iterator starting at the first buffered line. */
  def iterator: Iterator[String] = {
    new LinesIterator
  }

  /** Blocks until the reader thread has terminated. */
  def waitUntilClose(): Unit = thread.join()

  private class LinesIterator extends Iterator[String] {
    private[this] var index = 0

    /**
     * Returns true if another line is available, blocking while the stream is still open and
     * no unread line is buffered. Returns false only once the stream has finished and every
     * buffered line has been consumed.
     */
    override def hasNext: Boolean = {
      if (index < _lines.length) {
        true
      } else {
        // Otherwise we might still have more data.
        _lock.lock()
        try {
          // Re-check under the lock and loop: a line may have been appended (or the stream
          // finished) between the unlocked check above and acquiring the lock, and
          // Condition.await() may wake up spuriously.
          while (index >= _lines.length && !_finished) {
            _condition.await()
          }
          index < _lines.length
        } finally {
          _lock.unlock()
        }
      }
    }

    override def next(): String = {
      val line = _lines(index)
      index += 1
      line
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/LivySparkUtils.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/LivySparkUtils.scala b/server/src/main/scala/org/apache/livy/utils/LivySparkUtils.scala new file mode 100644 index 0000000..df83ca6 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/LivySparkUtils.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.livy.utils

import java.io.{File, IOException}

import scala.collection.SortedMap
import scala.math.Ordering.Implicits._

import org.apache.livy.{LivyConf, Logging}
import org.apache.livy.LivyConf.LIVY_SPARK_SCALA_VERSION

/** Helpers for validating the Spark installation and resolving Spark/Scala versions. */
object LivySparkUtils extends Logging {

  // For each Spark version we supported, we need to add this mapping relation in case Scala
  // version cannot be detected from "spark-submit --version".
  private val _defaultSparkScalaVersion = SortedMap(
    // Spark 2.1 + Scala 2.11
    (2, 1) -> "2.11",
    // Spark 2.0 + Scala 2.11
    (2, 0) -> "2.11",
    // Spark 1.6 + Scala 2.10
    (1, 6) -> "2.10"
  )

  // Supported Spark version
  private val MIN_VERSION = (1, 6)
  private val MAX_VERSION = (2, 2)

  private val sparkVersionRegex = """version (.*)""".r.unanchored
  private val scalaVersionRegex = """Scala version (.*), Java""".r.unanchored

  /**
   * Test that Spark home is configured and configured Spark home is a directory.
   *
   * @param livyConf configuration providing the Spark home
   * @throws IllegalArgumentException if Spark home is not set or is not a directory
   */
  def testSparkHome(livyConf: LivyConf): Unit = {
    val sparkHome = livyConf.sparkHome().getOrElse {
      throw new IllegalArgumentException("Livy requires the SPARK_HOME environment variable")
    }

    require(new File(sparkHome).isDirectory(), "SPARK_HOME path does not exist")
  }

  /**
   * Test that the configured `spark-submit` executable exists.
   *
   * @param livyConf configuration providing the `spark-submit` path
   * @throws IOException if running `spark-submit` fails with an IOException
   */
  def testSparkSubmit(livyConf: LivyConf): Unit = {
    try {
      testSparkVersion(sparkSubmitVersion(livyConf)._1)
    } catch {
      case e: IOException =>
        throw new IOException("Failed to run spark-submit executable", e)
    }
  }

  /**
   * Throw an exception if Spark version is not supported. Versions at or above the maximum
   * verified version are accepted with a warning.
   *
   * @param version Spark version
   */
  def testSparkVersion(version: String): Unit = {
    val v = formatSparkVersion(version)
    require(v >= MIN_VERSION, s"Unsupported Spark version $v")
    if (v >= MAX_VERSION) {
      warn(s"Current Spark $v is not verified in Livy, please use it carefully")
    }
  }

  /**
   * Call `spark-submit --version` and parse its output for Spark and Scala version.
   *
   * @param livyConf configuration providing the `spark-submit` path
   * @return Tuple with Spark and Scala version
   * @throws IOException if the Spark version cannot be found in the output
   */
  def sparkSubmitVersion(livyConf: LivyConf): (String, Option[String]) = {
    val sparkSubmit = livyConf.sparkSubmit()
    val pb = new ProcessBuilder(sparkSubmit, "--version")
    // Merge stderr into stdout so the version text is captured from a single stream.
    pb.redirectErrorStream(true)
    pb.redirectInput(ProcessBuilder.Redirect.PIPE)

    if (LivyConf.TEST_MODE) {
      pb.environment().put("LIVY_TEST_CLASSPATH", sys.props("java.class.path"))
    }

    val process = new LineBufferedProcess(pb.start())
    val exitCode = process.waitFor()
    val output = process.inputIterator.mkString("\n")

    val sparkVersion = output match {
      case sparkVersionRegex(version) => version
      case _ =>
        throw new IOException(f"Unable to determine spark-submit version [$exitCode]:\n$output")
    }

    val scalaVersion = output match {
      case scalaVersionRegex(version) if version.nonEmpty => Some(formatScalaVersion(version))
      case _ => None
    }

    (sparkVersion, scalaVersion)
  }

  /**
   * Resolve the Scala version to use: the value configured in livy.conf wins, then the one
   * detected from spark-submit, then the default for the Spark version.
   *
   * @param formattedSparkVersion (major, minor) Spark version
   * @param scalaVersionFromSparkSubmit Scala version detected from spark-submit, if any
   * @param livyConf configuration that may carry an explicit Scala version
   * @throws IllegalArgumentException if the configured and detected versions disagree
   */
  def sparkScalaVersion(
      formattedSparkVersion: (Int, Int),
      scalaVersionFromSparkSubmit: Option[String],
      livyConf: LivyConf): String = {
    val scalaVersionInLivyConf = Option(livyConf.get(LIVY_SPARK_SCALA_VERSION))
      .filter(_.nonEmpty)
      .map(formatScalaVersion)

    for (vSparkSubmit <- scalaVersionFromSparkSubmit; vLivyConf <- scalaVersionInLivyConf) {
      require(vSparkSubmit == vLivyConf,
        s"Scala version detected from spark-submit ($vSparkSubmit) does not match " +
          s"Scala version configured in livy.conf ($vLivyConf)")
    }

    scalaVersionInLivyConf
      .orElse(scalaVersionFromSparkSubmit)
      .getOrElse(defaultSparkScalaVersion(formattedSparkVersion))
  }

  /**
   * Return formatted Spark version.
   *
   * @param version Spark version
   * @return Two element tuple, one is major version and the other is minor version
   * @throws IllegalArgumentException if the version cannot be parsed
   */
  def formatSparkVersion(version: String): (Int, Int) = {
    val versionPattern = """^(\d+)\.(\d+)(\..*)?$""".r
    versionPattern.findFirstMatchIn(version) match {
      case Some(m) =>
        (m.group(1).toInt, m.group(2).toInt)
      case None =>
        throw new IllegalArgumentException(s"Fail to parse Spark version from $version")
    }
  }

  /**
   * Return Scala binary version.
   * It strips the patch version if specified.
   * Throws if it cannot parse the version.
   *
   * @param scalaVersion Scala binary version String
   * @return Scala binary version
   * @throws IllegalArgumentException if the version cannot be parsed
   */
  def formatScalaVersion(scalaVersion: String): String = {
    // Each group captures the whole number; a repeated single-digit group such as `(\d)+`
    // would keep only the last digit of a multi-digit component.
    val versionPattern = """(\d+)\.(\d+).*""".r
    scalaVersion match {
      case versionPattern(major, minor) => s"$major.$minor"
      case _ => throw new IllegalArgumentException(s"Unrecognized Scala version: $scalaVersion")
    }
  }

  /**
   * Return the default Scala version of a Spark version.
   *
   * @param sparkVersion formatted Spark version.
   * @return Scala binary version
   * @throws IllegalArgumentException if the Spark version is below the oldest mapped version,
   *                                  or lies within the mapped range but has no mapping
   */
  private[utils] def defaultSparkScalaVersion(sparkVersion: (Int, Int)): String = {
    _defaultSparkScalaVersion.get(sparkVersion)
      .orElse {
        if (sparkVersion < _defaultSparkScalaVersion.head._1) {
          throw new IllegalArgumentException(s"Spark version $sparkVersion is less than the " +
            s"minimum version ${_defaultSparkScalaVersion.head._1} supported by Livy")
        } else if (sparkVersion > _defaultSparkScalaVersion.last._1) {
          val (spark, scala) = _defaultSparkScalaVersion.last
          warn(s"Spark version $sparkVersion is greater then the maximum version " +
            s"$spark supported by Livy, will choose Scala version $scala instead, " +
            s"please specify manually if it is the expected Scala version you want")
          Some(scala)
        } else {
          None
        }
      }
      .getOrElse(
        throw new IllegalArgumentException(s"Fail to get Scala version from Spark $sparkVersion"))
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/SparkApp.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/SparkApp.scala b/server/src/main/scala/org/apache/livy/utils/SparkApp.scala new file mode 100644 index 0000000..9afe281 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/SparkApp.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.livy.utils

import scala.collection.JavaConverters._

import org.apache.livy.LivyConf

/** Keys used by [[AppInfo.asJavaMap]]. */
object AppInfo {
  val DRIVER_LOG_URL_NAME = "driverLogUrl"
  val SPARK_UI_URL_NAME = "sparkUiUrl"
}

/**
 * URLs describing a running application.
 *
 * @param driverLogUrl URL of the driver log, if known
 * @param sparkUiUrl   URL of the Spark UI, if known
 */
case class AppInfo(var driverLogUrl: Option[String] = None, var sparkUiUrl: Option[String] = None) {
  import AppInfo._

  /** Both URLs as a Java map; an unknown URL is mapped to `null`. */
  def asJavaMap: java.util.Map[String, String] = {
    val urls = Map(
      DRIVER_LOG_URL_NAME -> driverLogUrl.orNull,
      SPARK_UI_URL_NAME -> sparkUiUrl.orNull)
    urls.asJava
  }
}

/** Callbacks for application events; every callback defaults to a no-op. */
trait SparkAppListener {
  /** Fired when appId is known, even during recovery. */
  def appIdKnown(appId: String): Unit = {}

  /** Fired when the app state in the cluster changes. */
  def stateChanged(oldState: SparkApp.State, newState: SparkApp.State): Unit = {}

  /** Fired when the app info is changed. */
  def infoChanged(appInfo: AppInfo): Unit = {}
}

/**
 * Provide factory methods for SparkApp.
 */
object SparkApp {
  private val SPARK_YARN_TAG_KEY = "spark.yarn.tags"

  object State extends Enumeration {
    val STARTING, RUNNING, FINISHED, FAILED, KILLED = Value
  }
  type State = State.Value

  /**
   * Return cluster manager dependent SparkConf.
   *
   * When running on YARN, `uniqueAppTag` is placed in front of any user supplied YARN tags and
   * `spark.yarn.submit.waitAppCompletion` is set to "false"; otherwise `sparkConf` is returned
   * unchanged.
   *
   * @param uniqueAppTag A tag that can uniquely identify the application.
   * @param livyConf     configuration used to decide whether Livy runs on YARN
   * @param sparkConf    the Spark configuration to extend
   */
  def prepareSparkConf(
      uniqueAppTag: String,
      livyConf: LivyConf,
      sparkConf: Map[String, String]): Map[String, String] = {
    if (!livyConf.isRunningOnYarn()) {
      sparkConf
    } else {
      val tagSuffix = sparkConf.get(SPARK_YARN_TAG_KEY) match {
        case Some(userTags) => "," + userTags
        case None => ""
      }
      val yarnSettings = Map(
        SPARK_YARN_TAG_KEY -> (uniqueAppTag + tagSuffix),
        "spark.yarn.submit.waitAppCompletion" -> "false")
      sparkConf ++ yarnSettings
    }
  }

  /**
   * Return a SparkApp object to control the underlying Spark application via YARN or spark-submit.
   *
   * @param uniqueAppTag A tag that can uniquely identify the application.
   * @param appId        the application id, if already known
   * @param process      the spark-submit process; required when not running on YARN
   * @param livyConf     configuration used to decide whether Livy runs on YARN
   * @param listener     optional listener for application events
   */
  def create(
      uniqueAppTag: String,
      appId: Option[String],
      process: Option[LineBufferedProcess],
      livyConf: LivyConf,
      listener: Option[SparkAppListener]): SparkApp = {
    val onYarn = livyConf.isRunningOnYarn()
    if (onYarn) {
      new SparkYarnApp(uniqueAppTag, appId, process, listener, livyConf)
    } else {
      require(process.isDefined, "process must not be None when Livy master is not YARN.")
      new SparkProcApp(process.get, listener)
    }
  }
}

/**
 * Encapsulate a Spark application.
 */
abstract class SparkApp {
  def kill(): Unit
  def log(): IndexedSeq[String]
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/SparkProcApp.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/SparkProcApp.scala b/server/src/main/scala/org/apache/livy/utils/SparkProcApp.scala new file mode 100644 index 0000000..5fb3e42 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/SparkProcApp.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.livy.utils

import org.apache.livy.{Logging, Utils}

/**
 * Provide a class to control a Spark application using spark-submit.
 *
 * A daemon thread is started on construction; it reports RUNNING, waits for the process to
 * exit, and then reports FINISHED for exit code 0 or FAILED otherwise.
 *
 * @param process The spark-submit process launched the Spark application.
 * @param listener Optional listener notified of every state transition.
 */
class SparkProcApp (
    process: LineBufferedProcess,
    listener: Option[SparkAppListener])
  extends SparkApp with Logging {

  private var state = SparkApp.State.STARTING

  /** Destroys the process if it is still alive and waits for the watcher thread to end. */
  override def kill(): Unit = {
    val stillRunning = process.isAlive
    if (stillRunning) {
      process.destroy()
      waitThread.join()
    }
  }

  /** Lines captured so far from the process's standard output. */
  override def log(): IndexedSeq[String] = process.inputLines

  /** Notifies the listener and records `newState`, but only on an actual transition. */
  private def changeState(newState: SparkApp.State.Value) = {
    val previous = state
    if (previous != newState) {
      listener.foreach(_.stateChanged(previous, newState))
      state = newState
    }
  }

  // Declared last so that everything the watcher thread touches is initialised before it runs.
  private val waitThread = Utils.startDaemonThread(s"SparProcApp_$this") {
    changeState(SparkApp.State.RUNNING)
    val exitCode = process.waitFor()
    if (exitCode == 0) {
      changeState(SparkApp.State.FINISHED)
    } else {
      changeState(SparkApp.State.FAILED)
      error(s"spark-submit exited with code $exitCode")
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala 
b/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala new file mode 100644 index 0000000..66452d1 --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/SparkProcessBuilder.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.livy.utils

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.livy.{LivyConf, Logging}

/**
 * Fluent builder for a `spark-submit` command line.
 *
 * Every setter records its value and returns this builder; [[start]] assembles the arguments,
 * launches the process and wraps it in a [[LineBufferedProcess]].
 *
 * @param livyConf supplies the default `spark-submit` executable and the impersonation flag
 */
class SparkProcessBuilder(livyConf: LivyConf) extends Logging {

  private[this] var _executable: String = livyConf.sparkSubmit()
  private[this] var _master: Option[String] = None
  private[this] var _deployMode: Option[String] = None
  private[this] var _className: Option[String] = None
  private[this] var _name: Option[String] = None
  private[this] val _conf = mutable.HashMap[String, String]()
  private[this] val _driverClassPath: ArrayBuffer[String] = ArrayBuffer()
  private[this] var _proxyUser: Option[String] = None
  private[this] var _queue: Option[String] = None
  private[this] val _env: ArrayBuffer[(String, String)] = ArrayBuffer()
  private[this] var _redirectOutput: Option[ProcessBuilder.Redirect] = None
  private[this] var _redirectError: Option[ProcessBuilder.Redirect] = None
  private[this] var _redirectErrorStream: Option[Boolean] = None

  /** Applies `change` and returns this builder, so setters can be chained. */
  private def updated(change: => Unit): SparkProcessBuilder = {
    change
    this
  }

  def executable(executable: String): SparkProcessBuilder = updated { _executable = executable }

  def master(masterUrl: String): SparkProcessBuilder = updated { _master = Some(masterUrl) }

  def deployMode(deployMode: String): SparkProcessBuilder =
    updated { _deployMode = Some(deployMode) }

  def className(className: String): SparkProcessBuilder = updated { _className = Some(className) }

  def name(name: String): SparkProcessBuilder = updated { _name = Some(name) }

  /** Returns the value currently recorded for the Spark configuration `key`, if any. */
  def conf(key: String): Option[String] = _conf.get(key)

  /** Records a Spark configuration entry. The `admin` flag is accepted but not used here. */
  def conf(key: String, value: String, admin: Boolean = false): SparkProcessBuilder =
    updated { _conf(key) = value }

  /** Records every entry of `conf` as a Spark configuration entry. */
  def conf(conf: Traversable[(String, String)]): SparkProcessBuilder =
    updated { conf.foreach { case (key, value) => this.conf(key, value) } }

  def driverJavaOptions(driverJavaOptions: String): SparkProcessBuilder = {
    conf("spark.driver.extraJavaOptions", driverJavaOptions)
  }

  def driverClassPath(classPath: String): SparkProcessBuilder =
    updated { _driverClassPath += classPath }

  def driverClassPaths(classPaths: Traversable[String]): SparkProcessBuilder =
    updated { _driverClassPath ++= classPaths }

  def driverCores(driverCores: Int): SparkProcessBuilder = {
    this.driverCores(driverCores.toString)
  }

  def driverMemory(driverMemory: String): SparkProcessBuilder = {
    conf("spark.driver.memory", driverMemory)
  }

  def driverCores(driverCores: String): SparkProcessBuilder = {
    conf("spark.driver.cores", driverCores)
  }

  def executorCores(executorCores: Int): SparkProcessBuilder = {
    this.executorCores(executorCores.toString)
  }

  def executorCores(executorCores: String): SparkProcessBuilder = {
    conf("spark.executor.cores", executorCores)
  }

  def executorMemory(executorMemory: String): SparkProcessBuilder = {
    conf("spark.executor.memory", executorMemory)
  }

  def numExecutors(numExecutors: Int): SparkProcessBuilder = {
    this.numExecutors(numExecutors.toString)
  }

  def numExecutors(numExecutors: String): SparkProcessBuilder = {
    this.conf("spark.executor.instances", numExecutors)
  }

  def proxyUser(proxyUser: String): SparkProcessBuilder = updated { _proxyUser = Some(proxyUser) }

  def queue(queue: String): SparkProcessBuilder = updated { _queue = Some(queue) }

  /** Adds an environment variable for the launched process. */
  def env(key: String, value: String): SparkProcessBuilder = updated { _env += ((key, value)) }

  def redirectOutput(redirect: ProcessBuilder.Redirect): SparkProcessBuilder =
    updated { _redirectOutput = Some(redirect) }

  def redirectError(redirect: ProcessBuilder.Redirect): SparkProcessBuilder =
    updated { _redirectError = Some(redirect) }

  def redirectErrorStream(redirect: Boolean): SparkProcessBuilder =
    updated { _redirectErrorStream = Some(redirect) }

  /**
   * Assembles the command line and launches the process.
   *
   * @param file application file to submit; "spark-internal" is used when absent
   * @param args arguments appended after the application file
   * @return the launched process wrapped in a [[LineBufferedProcess]]
   */
  def start(file: Option[String], args: Traversable[String]): LineBufferedProcess = {
    val command = ArrayBuffer(_executable)

    // Appends "<flag> <value>" when the value is set.
    def addOpt(flag: String, value: Option[String]): Unit =
      value.foreach { v => command ++= Seq(flag, v) }

    // Appends "<flag> <comma separated values>" when there is at least one value.
    def addList(flag: String, values: Traversable[String]): Unit =
      if (values.nonEmpty) {
        command ++= Seq(flag, values.mkString(","))
      }

    addOpt("--master", _master)
    addOpt("--deploy-mode", _deployMode)
    addOpt("--name", _name)
    addOpt("--class", _className)
    _conf.foreach {
      // Python files are passed through their dedicated flag rather than as a --conf entry.
      case ("spark.submit.pyFiles", value) =>
        command += "--py-files"
        command += s"$value"
      case (key, value) =>
        command += "--conf"
        command += s"$key=$value"
    }
    addList("--driver-class-path", _driverClassPath)

    // The proxy user is only passed on when impersonation is enabled.
    val impersonationEnabled = livyConf.getBoolean(LivyConf.IMPERSONATION_ENABLED)
    if (impersonationEnabled) {
      addOpt("--proxy-user", _proxyUser)
    }

    addOpt("--queue", _queue)

    command += file.getOrElse("spark-internal")
    command ++= args

    // Single-quoted rendering of the command, used for logging only.
    val argsString = command.map { arg =>
      val escaped = arg.replace("'", "\\'")
      "'" + escaped + "'"
    }.mkString(" ")

    info(s"Running $argsString")

    val pb = new ProcessBuilder(command.asJava)
    val processEnv = pb.environment()
    _env.foreach { case (key, value) => processEnv.put(key, value) }

    _redirectOutput.foreach(pb.redirectOutput)
    _redirectError.foreach(pb.redirectError)
    _redirectErrorStream.foreach(pb.redirectErrorStream)

    new LineBufferedProcess(pb.start())
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/main/scala/org/apache/livy/utils/SparkYarnApp.scala ---------------------------------------------------------------------- diff --git a/server/src/main/scala/org/apache/livy/utils/SparkYarnApp.scala b/server/src/main/scala/org/apache/livy/utils/SparkYarnApp.scala new file mode 100644 index 0000000..b2a828f --- /dev/null +++ b/server/src/main/scala/org/apache/livy/utils/SparkYarnApp.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.livy.utils + +import java.util.concurrent.TimeoutException + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Try + +import org.apache.hadoop.yarn.api.records.{ApplicationId, ApplicationReport, FinalApplicationStatus, YarnApplicationState} +import org.apache.hadoop.yarn.client.api.YarnClient +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException +import org.apache.hadoop.yarn.util.ConverterUtils + +import org.apache.livy.{LivyConf, Logging, Utils} + +object SparkYarnApp extends Logging { + + def init(livyConf: LivyConf): Unit = { + sessionLeakageCheckInterval = livyConf.getTimeAsMs(LivyConf.YARN_APP_LEAKAGE_CHECK_INTERVAL) + sessionLeakageCheckTimeout = livyConf.getTimeAsMs(LivyConf.YARN_APP_LEAKAGE_CHECK_TIMEOUT) + leakedAppsGCThread.setDaemon(true) + leakedAppsGCThread.setName("LeakedAppsGCThread") + leakedAppsGCThread.start() + } + + // YarnClient is thread safe. Create once, share it across threads. 
+  lazy val yarnClient = {
+    val c = YarnClient.createYarnClient()
+    c.init(new YarnConfiguration())
+    c.start()
+    c
+  }
+
+  // Reads LivyConf.YARN_APP_LOOKUP_TIMEOUT (milliseconds) as a FiniteDuration.
+  private def getYarnTagToAppIdTimeout(livyConf: LivyConf): FiniteDuration =
+    livyConf.getTimeAsMs(LivyConf.YARN_APP_LOOKUP_TIMEOUT) milliseconds
+
+  // Reads LivyConf.YARN_POLL_INTERVAL (milliseconds) as a FiniteDuration.
+  private def getYarnPollInterval(livyConf: LivyConf): FiniteDuration =
+    livyConf.getTimeAsMs(LivyConf.YARN_POLL_INTERVAL) milliseconds
+
+  // Application type filter passed to YarnClient.getApplications.
+  private val appType = Set("SPARK").asJava
+
+  // App tag -> wall-clock time (ms) at which the tag was recorded as leaked.
+  private val leakedAppTags = new java.util.concurrent.ConcurrentHashMap[String, Long]()
+
+  // Both values are assigned in init() before the GC thread is started.
+  private var sessionLeakageCheckTimeout: Long = _
+
+  private var sessionLeakageCheckInterval: Long = _
+
+  /**
+   * Background loop that walks `leakedAppTags` every `sessionLeakageCheckInterval` ms.
+   * For each recorded tag it either kills the matching YARN application and forgets the tag,
+   * or, when no application carries the tag, forgets the tag once it has been recorded for
+   * longer than `sessionLeakageCheckTimeout` ms.
+   */
+  private val leakedAppsGCThread = new Thread() {
+    override def run(): Unit = {
+      while (true) {
+        if (!leakedAppTags.isEmpty) {
+          val iter = leakedAppTags.entrySet().iterator()
+          val now = System.currentTimeMillis()
+          // One listing per sweep; every tag is matched against the same snapshot.
+          val apps = yarnClient.getApplications(appType).asScala
+          while (iter.hasNext) {
+            val entry = iter.next()
+            // Decide per entry: a kill for one tag must not suppress the expiry check
+            // of the tags that follow it.
+            apps.find(_.getApplicationTags.contains(entry.getKey)) match {
+              case Some(app) =>
+                info(s"Kill leaked app ${app.getApplicationId}")
+                yarnClient.killApplication(app.getApplicationId)
+                iter.remove()
+              case None =>
+                // The stored value is the time the tag was recorded, so its age is
+                // now - value (value - now is never positive and would never expire).
+                if ((now - entry.getValue) > sessionLeakageCheckTimeout) {
+                  iter.remove()
+                  info(s"Remove leaked yarn app tag ${entry.getKey}")
+                }
+            }
+          }
+        }
+        Thread.sleep(sessionLeakageCheckInterval)
+      }
+    }
+  }
+
+
+}
+
+/**
+ * Provide a class to control a Spark application using YARN API.
+ *
+ * @param appTag An app tag that can uniquely identify the YARN app.
+ * @param appIdOption The appId of the YARN app. If this is None, SparkYarnApp will find it
+ *                    using appTag.
+ * @param process The spark-submit process that launched the YARN application. This is optional.
+ *                If it's provided, SparkYarnApp.log() will include its log.
+ * @param listener Optional listener for notification of appId discovery and app state changes. + */ +class SparkYarnApp private[utils] ( + appTag: String, + appIdOption: Option[String], + process: Option[LineBufferedProcess], + listener: Option[SparkAppListener], + livyConf: LivyConf, + yarnClient: => YarnClient = SparkYarnApp.yarnClient) // For unit test. + extends SparkApp + with Logging { + import SparkYarnApp._ + + private val appIdPromise: Promise[ApplicationId] = Promise() + private[utils] var state: SparkApp.State = SparkApp.State.STARTING + private var yarnDiagnostics: IndexedSeq[String] = IndexedSeq.empty[String] + + override def log(): IndexedSeq[String] = + ("stdout: " +: process.map(_.inputLines).getOrElse(ArrayBuffer.empty[String])) ++ + ("\nstderr: " +: process.map(_.errorLines).getOrElse(ArrayBuffer.empty[String])) ++ + ("\nYARN Diagnostics: " +: yarnDiagnostics) + + override def kill(): Unit = synchronized { + if (isRunning) { + try { + val timeout = SparkYarnApp.getYarnTagToAppIdTimeout(livyConf) + yarnClient.killApplication(Await.result(appIdPromise.future, timeout)) + } catch { + // We cannot kill the YARN app without the app id. + // There's a chance the YARN app hasn't been submitted during a livy-server failure. + // We don't want a stuck session that can't be deleted. Emit a warning and move on. + case _: TimeoutException | _: InterruptedException => + warn("Deleting a session while its YARN application is not found.") + yarnAppMonitorThread.interrupt() + } finally { + process.foreach(_.destroy()) + } + } + } + + private def changeState(newState: SparkApp.State.Value): Unit = { + if (state != newState) { + listener.foreach(_.stateChanged(state, newState)) + state = newState + } + } + + /** + * Find the corresponding YARN application id from an application tag. + * + * @param appTag The application tag tagged on the target application. + * If the tag is not unique, it returns the first application it found. 
+   *               It will be converted to lower case to match YARN's behaviour.
+   * @param pollInterval How long to sleep between two lookups.
+   * @param deadline Point in time after which the lookup gives up. On giving up, the
+   *                 spark-submit process (if any) is destroyed and the lower-cased tag is
+   *                 recorded in `leakedAppTags` before the exception is thrown.
+   * @return ApplicationId or the failure.
+   */
+  @tailrec
+  private def getAppIdFromTag(
+      appTag: String,
+      pollInterval: Duration,
+      deadline: Deadline): ApplicationId = {
+    val appTagLowerCase = appTag.toLowerCase()
+
+    // FIXME Should not loop thru all YARN applications but YarnClient doesn't offer an API.
+    // Consider calling rmClient in YarnClient directly.
+    yarnClient.getApplications(appType).asScala.find(_.getApplicationTags.contains(appTagLowerCase))
+    match {
+      case Some(app) => app.getApplicationId
+      case None =>
+        if (deadline.isOverdue) {
+          process.foreach(_.destroy())
+          // Record the lower-cased tag: the leaked-app sweep matches recorded tags against
+          // YARN's tags with an exact `contains`, so the original casing could never match.
+          leakedAppTags.put(appTagLowerCase, System.currentTimeMillis())
+          throw new Exception(s"No YARN application is found with tag $appTagLowerCase in " +
+            livyConf.getTimeAsMs(LivyConf.YARN_APP_LOOKUP_TIMEOUT)/1000 + " seconds. " +
+            "Please check your cluster status, it is may be very busy.")
+        } else {
+          Clock.sleep(pollInterval.toMillis)
+          getAppIdFromTag(appTagLowerCase, pollInterval, deadline)
+        }
+    }
+  }
+
+  /**
+   * Split the report's diagnostics text into lines, prefixed with a "YARN Diagnostics:" header.
+   * Returns an empty sequence when the report has no (or empty) diagnostics.
+   */
+  private def getYarnDiagnostics(appReport: ApplicationReport): IndexedSeq[String] = {
+    Option(appReport.getDiagnostics)
+      .filter(_.nonEmpty)
+      .map[IndexedSeq[String]]("YARN Diagnostics:" +: _.split("\n"))
+      .getOrElse(IndexedSeq.empty)
+  }
+
+  /** True until the app reaches one of the terminal states FAILED, FINISHED or KILLED. */
+  private def isRunning: Boolean = {
+    state != SparkApp.State.FAILED && state != SparkApp.State.FINISHED &&
+      state != SparkApp.State.KILLED
+  }
+
+  // Exposed for unit test.
+  /**
+   * Translate a YARN application state (and, for FINISHED apps, the final status) into a
+   * SparkApp state.
+   *
+   * @param appId Only used in the error log for an unrecognised final status.
+   * @param yarnAppState State reported by YARN for the application.
+   * @param finalAppStatus Final status; consulted only when yarnAppState is FINISHED.
+   * @return STARTING for NEW/NEW_SAVING/SUBMITTED/ACCEPTED, otherwise the matching
+   *         RUNNING/FINISHED/FAILED/KILLED state. A FINISHED app whose final status is not
+   *         SUCCEEDED/FAILED/KILLED is logged and reported as FAILED.
+   */
+  private[utils] def mapYarnState(
+      appId: ApplicationId,
+      yarnAppState: YarnApplicationState,
+      finalAppStatus: FinalApplicationStatus): SparkApp.State.Value = {
+    yarnAppState match {
+      case (YarnApplicationState.NEW |
+            YarnApplicationState.NEW_SAVING |
+            YarnApplicationState.SUBMITTED |
+            YarnApplicationState.ACCEPTED) => SparkApp.State.STARTING
+      case YarnApplicationState.RUNNING => SparkApp.State.RUNNING
+      case YarnApplicationState.FINISHED =>
+        finalAppStatus match {
+          case FinalApplicationStatus.SUCCEEDED => SparkApp.State.FINISHED
+          case FinalApplicationStatus.FAILED => SparkApp.State.FAILED
+          case FinalApplicationStatus.KILLED => SparkApp.State.KILLED
+          case s =>
+            error(s"Unknown YARN final status $appId $s")
+            SparkApp.State.FAILED
+        }
+      case YarnApplicationState.FAILED => SparkApp.State.FAILED
+      case YarnApplicationState.KILLED => SparkApp.State.KILLED
+    }
+  }
+
+  // Exposed for unit test.
+  // TODO Instead of spawning a thread for every session, create a centralized thread and
+  // batch YARN queries.
+  private[utils] val yarnAppMonitorThread = Utils.startDaemonThread(s"yarnAppMonitorThread-$this") {
+    try {
+      // Wait for spark-submit to finish submitting the app to YARN.
+      process.foreach { p =>
+        val exitCode = p.waitFor()
+        if (exitCode != 0) {
+          // Report the exit code followed by what spark-submit wrote to its stdout.
+          throw new Exception(s"spark-submit exited with code $exitCode.\n" +
+            s"${p.inputLines.mkString("\n")}")
+        }
+      }
+
+      // If appId is not known, query YARN by appTag to get it.
+ val appId = try { + appIdOption.map(ConverterUtils.toApplicationId).getOrElse { + val pollInterval = getYarnPollInterval(livyConf) + val deadline = getYarnTagToAppIdTimeout(livyConf).fromNow + getAppIdFromTag(appTag, pollInterval, deadline) + } + } catch { + case e: Exception => + appIdPromise.failure(e) + throw e + } + appIdPromise.success(appId) + + Thread.currentThread().setName(s"yarnAppMonitorThread-$appId") + listener.foreach(_.appIdKnown(appId.toString)) + + val pollInterval = SparkYarnApp.getYarnPollInterval(livyConf) + var appInfo = AppInfo() + while (isRunning) { + try { + Clock.sleep(pollInterval.toMillis) + + // Refresh application state + val appReport = yarnClient.getApplicationReport(appId) + yarnDiagnostics = getYarnDiagnostics(appReport) + changeState(mapYarnState( + appReport.getApplicationId, + appReport.getYarnApplicationState, + appReport.getFinalApplicationStatus)) + + val latestAppInfo = { + val attempt = + yarnClient.getApplicationAttemptReport(appReport.getCurrentApplicationAttemptId) + val driverLogUrl = + Try(yarnClient.getContainerReport(attempt.getAMContainerId).getLogUrl) + .toOption + AppInfo(driverLogUrl, Option(appReport.getTrackingUrl)) + } + + if (appInfo != latestAppInfo) { + listener.foreach(_.infoChanged(latestAppInfo)) + appInfo = latestAppInfo + } + } catch { + // This exception might be thrown during app is starting up. It's transient. 
+ case e: ApplicationAttemptNotFoundException => + // Workaround YARN-4411: No enum constant FINAL_SAVING from getApplicationAttemptReport() + case e: IllegalArgumentException => + if (e.getMessage.contains("FINAL_SAVING")) { + debug("Encountered YARN-4411.") + } else { + throw e + } + } + } + + debug(s"$appId $state ${yarnDiagnostics.mkString(" ")}") + } catch { + case e: InterruptedException => + yarnDiagnostics = ArrayBuffer("Session stopped by user.") + changeState(SparkApp.State.KILLED) + case e: Throwable => + error(s"Error whiling refreshing YARN state: $e") + yarnDiagnostics = ArrayBuffer(e.toString, e.getStackTrace().mkString(" ")) + changeState(SparkApp.State.FAILED) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/ApiVersioningSupportSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/ApiVersioningSupportSpec.scala b/server/src/test/scala/com/cloudera/livy/server/ApiVersioningSupportSpec.scala deleted file mode 100644 index 4c530f3..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/ApiVersioningSupportSpec.scala +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server - -import javax.servlet.http.HttpServletResponse - -import org.scalatest.FunSpecLike -import org.scalatra.ScalatraServlet -import org.scalatra.test.scalatest.ScalatraSuite - -import com.cloudera.livy.LivyBaseUnitTestSuite - -class ApiVersioningSupportSpec extends ScalatraSuite with FunSpecLike with LivyBaseUnitTestSuite { - val LatestVersionOutput = "latest" - - object FakeApiVersions extends Enumeration { - type FakeApiVersions = Value - val v0_1 = Value("0.1") - val v0_2 = Value("0.2") - val v1_0 = Value("1.0") - } - - import FakeApiVersions._ - - class MockServlet extends ScalatraServlet with AbstractApiVersioningSupport { - override val apiVersions = FakeApiVersions - override type ApiVersionType = FakeApiVersions.Value - - get("/test") { - response.writer.write(LatestVersionOutput) - } - - get("/test", apiVersion <= v0_2) { - response.writer.write(v0_2.toString) - } - - get("/test", apiVersion <= v0_1) { - response.writer.write(v0_1.toString) - } - - get("/droppedApi", apiVersion <= v0_2) { - } - - get("/newApi", apiVersion >= v0_2) { - } - } - - var mockServlet: MockServlet = new MockServlet - addServlet(mockServlet, "/*") - - def generateHeader(acceptHeader: String): Map[String, String] = { - if (acceptHeader != null) Map("Accept" -> acceptHeader) else Map.empty - } - - def shouldReturn(url: String, acceptHeader: String, expectedVersion: String = null): Unit = { - get(url, headers = generateHeader(acceptHeader)) { - status should equal(200) - if (expectedVersion != null) { - body should equal(expectedVersion) - } - } - } - - def shouldFail(url: String, acceptHeader: String, expectedErrorCode: Int): Unit = { - get(url, headers = generateHeader(acceptHeader)) { - status should equal(expectedErrorCode) - } - } - - it("should pick the latest API version if Accept header is unspecified") { - shouldReturn("/test", 
null, LatestVersionOutput) - } - - it("should pick the latest API version if Accept header does not specify any version") { - shouldReturn("/test", "foo", LatestVersionOutput) - shouldReturn("/test", "application/vnd.random.v1.1", LatestVersionOutput) - shouldReturn("/test", "application/vnd.livy.+json", LatestVersionOutput) - } - - it("should pick the correct API version") { - shouldReturn("/test", "application/vnd.livy.v0.1", v0_1.toString) - shouldReturn("/test", "application/vnd.livy.v0.2+", v0_2.toString) - shouldReturn("/test", "application/vnd.livy.v0.1+bar", v0_1.toString) - shouldReturn("/test", "application/vnd.livy.v0.2+foo", v0_2.toString) - shouldReturn("/test", "application/vnd.livy.v0.1+vnd.livy.v0.2", v0_1.toString) - shouldReturn("/test", "application/vnd.livy.v0.2++++++++++++++++", v0_2.toString) - shouldReturn("/test", "application/vnd.livy.v1.0", LatestVersionOutput) - } - - it("should return error when the specified API version does not exist") { - shouldFail("/test", "application/vnd.livy.v", HttpServletResponse.SC_NOT_ACCEPTABLE) - shouldFail("/test", "application/vnd.livy.v+json", HttpServletResponse.SC_NOT_ACCEPTABLE) - shouldFail("/test", "application/vnd.livy.v666.666", HttpServletResponse.SC_NOT_ACCEPTABLE) - shouldFail("/test", "application/vnd.livy.v666.666+json", HttpServletResponse.SC_NOT_ACCEPTABLE) - shouldFail("/test", "application/vnd.livy.v1.1+json", HttpServletResponse.SC_NOT_ACCEPTABLE) - } - - it("should not see a dropped API") { - shouldReturn("/droppedApi", "application/vnd.livy.v0.1+json") - shouldReturn("/droppedApi", "application/vnd.livy.v0.2+json") - shouldFail("/droppedApi", "application/vnd.livy.v1.0+json", HttpServletResponse.SC_NOT_FOUND) - } - - it("should not see a new API at an older version") { - shouldFail("/newApi", "application/vnd.livy.v0.1+json", HttpServletResponse.SC_NOT_FOUND) - shouldReturn("/newApi", "application/vnd.livy.v0.2+json") - shouldReturn("/newApi", "application/vnd.livy.v1.0+json") - } -} 
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala deleted file mode 100644 index 1d8d38a..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server - -import java.io.ByteArrayOutputStream -import javax.servlet.http.HttpServletResponse._ - -import scala.reflect.ClassTag - -import com.fasterxml.jackson.databind.ObjectMapper -import org.scalatest.FunSpecLike -import org.scalatra.test.scalatest.ScalatraSuite - -import com.cloudera.livy.LivyBaseUnitTestSuite - -/** - * Base class that enhances ScalatraSuite so that it's easier to test JsonServlet - * implementations. Variants of the test methods (get, post, etc) exist with the "j" - * prefix; these automatically serialize the body of the request to JSON, and - * deserialize the result from JSON. 
- * - * In case the response is not JSON, the expected type for the test function should be - * `Unit`, and the `response` object should be checked directly. - */ -abstract class BaseJsonServletSpec extends ScalatraSuite - with FunSpecLike with LivyBaseUnitTestSuite { - - protected val mapper = new ObjectMapper() - .registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule) - - protected val defaultHeaders: Map[String, String] = Map("Content-Type" -> "application/json") - - protected def jdelete[R: ClassTag]( - uri: String, - expectedStatus: Int = SC_OK, - headers: Map[String, String] = defaultHeaders) - (fn: R => Unit): Unit = { - delete(uri, headers = headers)(doTest(expectedStatus, fn)) - } - - protected def jget[R: ClassTag]( - uri: String, - expectedStatus: Int = SC_OK, - headers: Map[String, String] = defaultHeaders) - (fn: R => Unit): Unit = { - get(uri, headers = headers)(doTest(expectedStatus, fn)) - } - - protected def jpatch[R: ClassTag]( - uri: String, - body: AnyRef, - expectedStatus: Int = SC_OK, - headers: Map[String, String] = defaultHeaders) - (fn: R => Unit): Unit = { - patch(uri, body = toJson(body), headers = headers)(doTest(expectedStatus, fn)) - } - - protected def jpost[R: ClassTag]( - uri: String, - body: AnyRef, - expectedStatus: Int = SC_CREATED, - headers: Map[String, String] = defaultHeaders) - (fn: R => Unit): Unit = { - post(uri, body = toJson(body), headers = headers)(doTest(expectedStatus, fn)) - } - - /** A version of jpost specific for testing file upload. 
*/ - protected def jupload[R: ClassTag]( - uri: String, - files: Iterable[(String, Any)], - headers: Map[String, String] = Map(), - expectedStatus: Int = SC_OK) - (fn: R => Unit): Unit = { - post(uri, Map.empty, files)(doTest(expectedStatus, fn)) - } - - protected def jput[R: ClassTag]( - uri: String, - body: AnyRef, - expectedStatus: Int = SC_OK, - headers: Map[String, String] = defaultHeaders) - (fn: R => Unit): Unit = { - put(uri, body = toJson(body), headers = headers)(doTest(expectedStatus, fn)) - } - - private def doTest[R: ClassTag](expectedStatus: Int, fn: R => Unit) - (implicit klass: ClassTag[R]): Unit = { - if (status != expectedStatus) { - // Yeah this is weird, but we don't want to evaluate "response.body" if there's no error. - assert(status === expectedStatus, - s"Unexpected response status: $status != $expectedStatus (${response.body})") - } - // Only try to parse the body if response is in the "OK" range (20x). - if ((status / 100) * 100 == SC_OK) { - val result = - if (header("Content-Type").startsWith("application/json")) { - // Sometimes there's an empty body with no "Content-Length" header. So read the whole - // body first, and only send it to Jackson if there's content. 
- val in = response.inputStream - val out = new ByteArrayOutputStream() - val buf = new Array[Byte](1024) - var read = 0 - while (read >= 0) { - read = in.read(buf) - if (read > 0) { - out.write(buf, 0, read) - } - } - - val data = out.toByteArray() - if (data.length > 0) { - mapper.readValue(data, klass.runtimeClass) - } else { - null - } - } else { - assert(klass.runtimeClass == classOf[Unit]) - () - } - fn(result.asInstanceOf[R]) - } - } - - private def toJson(obj: Any): Array[Byte] = mapper.writeValueAsBytes(obj) - -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala deleted file mode 100644 index 14eb2e6..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.cloudera.livy.server - -import javax.servlet.http.HttpServletRequest - -import org.scalatest.BeforeAndAfterAll - -import com.cloudera.livy.LivyConf -import com.cloudera.livy.sessions.Session -import com.cloudera.livy.sessions.Session.RecoveryMetadata - -object BaseSessionServletSpec { - - /** Header used to override the user remote user in tests. */ - val REMOTE_USER_HEADER = "X-Livy-SessionServlet-User" - -} - -abstract class BaseSessionServletSpec[S <: Session, R <: RecoveryMetadata] - extends BaseJsonServletSpec - with BeforeAndAfterAll { - - /** Config map containing option that is blacklisted. */ - protected val BLACKLISTED_CONFIG = Map("spark.do_not_set" -> "true") - - /** Name of the admin user. */ - protected val ADMIN = "__admin__" - - /** Create headers that identify a specific user in tests. */ - protected def makeUserHeaders(user: String): Map[String, String] = { - defaultHeaders ++ Map(BaseSessionServletSpec.REMOTE_USER_HEADER -> user) - } - - protected val adminHeaders = makeUserHeaders(ADMIN) - - /** Create a LivyConf with impersonation enabled and a superuser. 
*/ - protected def createConf(): LivyConf = { - new LivyConf() - .set(LivyConf.IMPERSONATION_ENABLED, true) - .set(LivyConf.SUPERUSERS, ADMIN) - .set(LivyConf.LOCAL_FS_WHITELIST, sys.props("java.io.tmpdir")) - } - - override def afterAll(): Unit = { - super.afterAll() - servlet.shutdown() - } - - def createServlet(): SessionServlet[S, R] - - protected val servlet = createServlet() - - addServlet(servlet, "/*") - - protected def toJson(msg: AnyRef): Array[Byte] = mapper.writeValueAsBytes(msg) - -} - -trait RemoteUserOverride { - this: SessionServlet[_, _] => - - override protected def remoteUser(req: HttpServletRequest): String = { - req.getHeader(BaseSessionServletSpec.REMOTE_USER_HEADER) - } - -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/JsonServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/JsonServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/JsonServletSpec.scala deleted file mode 100644 index 3713443..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/JsonServletSpec.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server - -import java.nio.charset.StandardCharsets.UTF_8 -import javax.servlet.http.HttpServletResponse._ - -import org.scalatra._ - -class JsonServletSpec extends BaseJsonServletSpec { - - addServlet(new TestJsonServlet(), "/*") - - describe("The JSON servlet") { - - it("should serialize result of delete") { - jdelete[MethodReturn]("/delete") { result => - assert(result.value === "delete") - } - } - - it("should serialize result of get") { - jget[MethodReturn]("/get") { result => - assert(result.value === "get") - } - } - - it("should serialize an ActionResult's body") { - jpost[MethodReturn]("/post", MethodArg("post")) { result => - assert(result.value === "post") - } - } - - it("should wrap a raw result") { - jput[MethodReturn]("/put", MethodArg("put")) { result => - assert(result.value === "put") - } - } - - it("should bypass non-json results") { - jpatch[Unit]("/patch", MethodArg("patch"), expectedStatus = SC_NOT_FOUND) { _ => - assert(response.body === "patch") - } - } - - it("should translate JSON errors to BadRequest") { - post("/post", "abcde".getBytes(UTF_8), headers = defaultHeaders) { - assert(status === SC_BAD_REQUEST) - } - } - - it("should translate bad param name to BadRequest") { - post("/post", """{"value1":"1"}""".getBytes(UTF_8), headers = defaultHeaders) { - assert(status === SC_BAD_REQUEST) - } - } - - it("should translate type mismatch to BadRequest") { - post("/postlist", """{"listParam":"1"}""".getBytes(UTF_8), headers = defaultHeaders) { - assert(status === SC_BAD_REQUEST) - } - } - - it("should respect user-installed error handlers") { - post("/error", headers = defaultHeaders) { - assert(status === SC_SERVICE_UNAVAILABLE) - assert(response.body === "error") - } - } - - it("should handle empty return values") { - jget[MethodReturn]("/empty") { result => - assert(result == null) - } - } - - } - -} - 
-private case class MethodArg(value: String) - -private case class MethodReturn(value: String) - -private case class MethodReturnList(listParam: List[String] = List()) - -private class TestJsonServlet extends JsonServlet { - - before() { - contentType = "application/json" - } - - delete("/delete") { - Ok(MethodReturn("delete")) - } - - get("/get") { - Ok(MethodReturn("get")) - } - - jpost[MethodArg]("/post") { arg => - Created(MethodReturn(arg.value)) - } - - jpost[MethodReturnList]("/postlist") { arg => - Created() - } - - jput[MethodArg]("/put") { arg => - MethodReturn(arg.value) - } - - jpatch[MethodArg]("/patch") { arg => - contentType = "text/plain" - NotFound(arg.value) - } - - get("/empty") { - () - } - - post("/error") { - throw new IllegalStateException("error") - } - - // Install an error handler to make sure the parent's still work. - error { - case e: IllegalStateException => - contentType = "text/plain" - ServiceUnavailable(e.getMessage()) - } - -} - http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala deleted file mode 100644 index 1ae8a25..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server - - -import javax.servlet.http.HttpServletRequest -import javax.servlet.http.HttpServletResponse._ - -import org.scalatest.mock.MockitoSugar.mock - -import com.cloudera.livy.LivyConf -import com.cloudera.livy.server.recovery.SessionStore -import com.cloudera.livy.sessions.{Session, SessionManager, SessionState} -import com.cloudera.livy.sessions.Session.RecoveryMetadata - -object SessionServletSpec { - - val PROXY_USER = "proxyUser" - - class MockSession(id: Int, owner: String, livyConf: LivyConf) - extends Session(id, owner, livyConf) { - - case class MockRecoveryMetadata(id: Int) extends RecoveryMetadata() - - override val proxyUser = None - - override def recoveryMetadata: RecoveryMetadata = MockRecoveryMetadata(0) - - override def state: SessionState = SessionState.Idle() - - override protected def stopSession(): Unit = () - - override def logLines(): IndexedSeq[String] = IndexedSeq("log") - - } - - case class MockSessionView(id: Int, owner: String, logs: Seq[String]) - -} - -class SessionServletSpec - extends BaseSessionServletSpec[Session, RecoveryMetadata] { - - import SessionServletSpec._ - - override def createServlet(): SessionServlet[Session, RecoveryMetadata] = { - val conf = createConf() - val sessionManager = new SessionManager[Session, RecoveryMetadata]( - conf, - { _ => assert(false).asInstanceOf[Session] }, - mock[SessionStore], - "test", - Some(Seq.empty)) - - new SessionServlet(sessionManager, conf) with RemoteUserOverride { - override protected def createSession(req: HttpServletRequest): 
Session = { - val params = bodyAs[Map[String, String]](req) - checkImpersonation(params.get(PROXY_USER), req) - new MockSession(sessionManager.nextId(), remoteUser(req), conf) - } - - override protected def clientSessionView( - session: Session, - req: HttpServletRequest): Any = { - val logs = if (hasAccess(session.owner, req)) session.logLines() else Nil - MockSessionView(session.id, session.owner, logs) - } - } - } - - private val aliceHeaders = makeUserHeaders("alice") - private val bobHeaders = makeUserHeaders("bob") - - private def delete(id: Int, headers: Map[String, String], expectedStatus: Int): Unit = { - jdelete[Map[String, Any]](s"/$id", headers = headers, expectedStatus = expectedStatus) { _ => - // Nothing to do. - } - } - - describe("SessionServlet") { - - it("should return correct Location in header") { - // mount to "/sessions/*" to test. If request URI is "/session", getPathInfo() will - // return null, since there's no extra path. - // mount to "/*" will always return "/", so that it cannot reflect the issue. 
- addServlet(servlet, "/sessions/*") - jpost[MockSessionView]("/sessions", Map(), headers = aliceHeaders) { res => - assert(header("Location") === "/sessions/0") - jdelete[Map[String, Any]]("/sessions/0", SC_OK, aliceHeaders) { _ => } - } - } - - it("should attach owner information to sessions") { - jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => - assert(res.owner === "alice") - assert(res.logs === IndexedSeq("log")) - delete(res.id, aliceHeaders, SC_OK) - } - } - - it("should allow other users to see non-sensitive information") { - jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => - jget[MockSessionView](s"/${res.id}", headers = bobHeaders) { res => - assert(res.owner === "alice") - assert(res.logs === Nil) - } - delete(res.id, aliceHeaders, SC_OK) - } - } - - it("should prevent non-owners from modifying sessions") { - jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => - delete(res.id, bobHeaders, SC_FORBIDDEN) - } - } - - it("should allow admins to access all sessions") { - jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => - jget[MockSessionView](s"/${res.id}", headers = adminHeaders) { res => - assert(res.owner === "alice") - assert(res.logs === IndexedSeq("log")) - } - delete(res.id, adminHeaders, SC_OK) - } - } - - it("should not allow regular users to impersonate others") { - jpost[MockSessionView]("/", Map(PROXY_USER -> "bob"), headers = aliceHeaders, - expectedStatus = SC_FORBIDDEN) { _ => } - } - - it("should allow admins to impersonate anyone") { - jpost[MockSessionView]("/", Map(PROXY_USER -> "bob"), headers = adminHeaders) { res => - delete(res.id, bobHeaders, SC_FORBIDDEN) - delete(res.id, adminHeaders, SC_OK) - } - } - - } - -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/batch/BatchServletSpec.scala ---------------------------------------------------------------------- diff --git 
a/server/src/test/scala/com/cloudera/livy/server/batch/BatchServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/batch/BatchServletSpec.scala deleted file mode 100644 index 8a79593..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/batch/BatchServletSpec.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.cloudera.livy.server.batch - -import java.io.FileWriter -import java.nio.file.{Files, Path} -import java.util.concurrent.TimeUnit -import javax.servlet.http.HttpServletRequest -import javax.servlet.http.HttpServletResponse._ - -import scala.concurrent.duration.Duration - -import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar.mock - -import com.cloudera.livy.Utils -import com.cloudera.livy.server.BaseSessionServletSpec -import com.cloudera.livy.server.recovery.SessionStore -import com.cloudera.livy.sessions.{BatchSessionManager, SessionState} -import com.cloudera.livy.utils.AppInfo - -class BatchServletSpec extends BaseSessionServletSpec[BatchSession, BatchRecoveryMetadata] { - - val script: Path = { - val script = Files.createTempFile("livy-test", ".py") - script.toFile.deleteOnExit() - val writer = new FileWriter(script.toFile) - try { - writer.write( - """ - |print "hello world" - """.stripMargin) - } finally { - writer.close() - } - script - } - - override def createServlet(): BatchSessionServlet = { - val livyConf = createConf() - val sessionStore = mock[SessionStore] - new BatchSessionServlet( - new BatchSessionManager(livyConf, sessionStore, Some(Seq.empty)), - sessionStore, - livyConf) - } - - describe("Batch Servlet") { - it("should create and tear down a batch") { - jget[Map[String, Any]]("/") { data => - data("sessions") should equal (Seq()) - } - - val createRequest = new CreateBatchRequest() - createRequest.file = script.toString - createRequest.conf = Map("spark.driver.extraClassPath" -> sys.props("java.class.path")) - - jpost[Map[String, Any]]("/", createRequest) { data => - header("Location") should equal("/0") - data("id") should equal (0) - - val batch = servlet.sessionManager.get(0) - batch should be (defined) - } - - // Wait for the process to finish. 
- { - val batch = servlet.sessionManager.get(0).get - Utils.waitUntil({ () => !batch.state.isActive }, Duration(10, TimeUnit.SECONDS)) - (batch.state match { - case SessionState.Success(_) => true - case _ => false - }) should be (true) - } - - jget[Map[String, Any]]("/0") { data => - data("id") should equal (0) - data("state") should equal ("success") - - val batch = servlet.sessionManager.get(0) - batch should be (defined) - } - - jget[Map[String, Any]]("/0/log?size=1000") { data => - data("id") should equal (0) - data("log").asInstanceOf[Seq[String]] should contain ("hello world") - - val batch = servlet.sessionManager.get(0) - batch should be (defined) - } - - jdelete[Map[String, Any]]("/0") { data => - data should equal (Map("msg" -> "deleted")) - - val batch = servlet.sessionManager.get(0) - batch should not be defined - } - } - - it("should respect config black list") { - val createRequest = new CreateBatchRequest() - createRequest.file = script.toString - createRequest.conf = BLACKLISTED_CONFIG - jpost[Map[String, Any]]("/", createRequest, expectedStatus = SC_BAD_REQUEST) { _ => } - } - - it("should show session properties") { - val id = 0 - val state = SessionState.Running() - val appId = "appid" - val appInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL")) - val log = IndexedSeq[String]("log1", "log2") - - val session = mock[BatchSession] - when(session.id).thenReturn(id) - when(session.state).thenReturn(state) - when(session.appId).thenReturn(Some(appId)) - when(session.appInfo).thenReturn(appInfo) - when(session.logLines()).thenReturn(log) - - val req = mock[HttpServletRequest] - - val view = servlet.asInstanceOf[BatchSessionServlet].clientSessionView(session, req) - .asInstanceOf[BatchSessionView] - - view.id shouldEqual id - view.state shouldEqual state.toString - view.appId shouldEqual Some(appId) - view.appInfo shouldEqual appInfo - view.log shouldEqual log - } - } - -} 
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala b/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala deleted file mode 100644 index 0aa8d28..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.cloudera.livy.server.batch - -import java.io.FileWriter -import java.nio.file.{Files, Path} -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration - -import org.mockito.Matchers -import org.mockito.Matchers.anyObject -import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, FunSpec, ShouldMatchers} -import org.scalatest.mock.MockitoSugar.mock - -import com.cloudera.livy.{LivyBaseUnitTestSuite, LivyConf, Utils} -import com.cloudera.livy.server.recovery.SessionStore -import com.cloudera.livy.sessions.SessionState -import com.cloudera.livy.utils.{AppInfo, SparkApp} - -class BatchSessionSpec - extends FunSpec - with BeforeAndAfter - with ShouldMatchers - with LivyBaseUnitTestSuite { - - val script: Path = { - val script = Files.createTempFile("livy-test", ".py") - script.toFile.deleteOnExit() - val writer = new FileWriter(script.toFile) - try { - writer.write( - """ - |print "hello world" - """.stripMargin) - } finally { - writer.close() - } - script - } - - describe("A Batch process") { - var sessionStore: SessionStore = null - - before { - sessionStore = mock[SessionStore] - } - - it("should create a process") { - val req = new CreateBatchRequest() - req.file = script.toString - req.conf = Map("spark.driver.extraClassPath" -> sys.props("java.class.path")) - - val conf = new LivyConf().set(LivyConf.LOCAL_FS_WHITELIST, sys.props("java.io.tmpdir")) - val batch = BatchSession.create(0, req, conf, null, None, sessionStore) - - Utils.waitUntil({ () => !batch.state.isActive }, Duration(10, TimeUnit.SECONDS)) - (batch.state match { - case SessionState.Success(_) => true - case _ => false - }) should be (true) - - batch.logLines() should contain("hello world") - } - - it("should update appId and appInfo") { - val conf = new LivyConf() - val req = new CreateBatchRequest() - val mockApp = mock[SparkApp] - val batch = BatchSession.create(0, req, conf, null, None, sessionStore, Some(mockApp)) - - val expectedAppId = 
"APPID" - batch.appIdKnown(expectedAppId) - verify(sessionStore, atLeastOnce()).save( - Matchers.eq(BatchSession.RECOVERY_SESSION_TYPE), anyObject()) - batch.appId shouldEqual Some(expectedAppId) - - val expectedAppInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL")) - batch.infoChanged(expectedAppInfo) - batch.appInfo shouldEqual expectedAppInfo - } - - it("should recover session") { - val conf = new LivyConf() - val req = new CreateBatchRequest() - val mockApp = mock[SparkApp] - val m = BatchRecoveryMetadata(99, None, "appTag", null, None) - val batch = BatchSession.recover(m, conf, sessionStore, Some(mockApp)) - - batch.state shouldBe a[SessionState.Recovering] - - batch.appIdKnown("appId") - verify(sessionStore, atLeastOnce()).save( - Matchers.eq(BatchSession.RECOVERY_SESSION_TYPE), anyObject()) - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/batch/CreateBatchRequestSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/batch/CreateBatchRequestSpec.scala b/server/src/test/scala/com/cloudera/livy/server/batch/CreateBatchRequestSpec.scala deleted file mode 100644 index 9119e79..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/batch/CreateBatchRequestSpec.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server.batch - -import com.fasterxml.jackson.databind.{JsonMappingException, ObjectMapper} -import org.scalatest.FunSpec - -import com.cloudera.livy.LivyBaseUnitTestSuite - -class CreateBatchRequestSpec extends FunSpec with LivyBaseUnitTestSuite { - - private val mapper = new ObjectMapper() - .registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule) - - describe("CreateBatchRequest") { - - it("should have default values for fields after deserialization") { - val json = """{ "file" : "foo" }""" - val req = mapper.readValue(json, classOf[CreateBatchRequest]) - assert(req.file === "foo") - assert(req.proxyUser === None) - assert(req.args === List()) - assert(req.className === None) - assert(req.jars === List()) - assert(req.pyFiles === List()) - assert(req.files === List()) - assert(req.driverMemory === None) - assert(req.driverCores === None) - assert(req.executorMemory === None) - assert(req.executorCores === None) - assert(req.numExecutors === None) - assert(req.archives === List()) - assert(req.queue === None) - assert(req.name === None) - assert(req.conf === Map()) - } - - } - -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/server/interactive/BaseInteractiveServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/interactive/BaseInteractiveServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/interactive/BaseInteractiveServletSpec.scala 
deleted file mode 100644 index fc48643..0000000 --- a/server/src/test/scala/com/cloudera/livy/server/interactive/BaseInteractiveServletSpec.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.server.interactive - -import java.io.File -import java.nio.file.Files - -import org.apache.commons.io.FileUtils -import org.apache.spark.launcher.SparkLauncher - -import com.cloudera.livy.LivyConf -import com.cloudera.livy.rsc.RSCConf -import com.cloudera.livy.server.BaseSessionServletSpec -import com.cloudera.livy.sessions.{Kind, SessionKindModule, Spark} - -abstract class BaseInteractiveServletSpec - extends BaseSessionServletSpec[InteractiveSession, InteractiveRecoveryMetadata] { - - mapper.registerModule(new SessionKindModule()) - - protected var tempDir: File = _ - - override def afterAll(): Unit = { - super.afterAll() - if (tempDir != null) { - scala.util.Try(FileUtils.deleteDirectory(tempDir)) - tempDir = null - } - } - - override protected def createConf(): LivyConf = synchronized { - if (tempDir == null) { - tempDir = Files.createTempDirectory("client-test").toFile() - } - super.createConf() - .set(LivyConf.SESSION_STAGING_DIR, tempDir.toURI().toString()) - 
.set(LivyConf.REPL_JARS, "dummy.jar") - .set(LivyConf.LIVY_SPARK_VERSION, "1.6.0") - .set(LivyConf.LIVY_SPARK_SCALA_VERSION, "2.10.5") - } - - protected def createRequest( - inProcess: Boolean = true, - extraConf: Map[String, String] = Map(), - kind: Kind = Spark()): CreateInteractiveRequest = { - val classpath = sys.props("java.class.path") - val request = new CreateInteractiveRequest() - request.kind = kind - request.conf = extraConf ++ Map( - RSCConf.Entry.LIVY_JARS.key() -> "", - RSCConf.Entry.CLIENT_IN_PROCESS.key() -> inProcess.toString, - SparkLauncher.SPARK_MASTER -> "local", - SparkLauncher.DRIVER_EXTRA_CLASSPATH -> classpath, - SparkLauncher.EXECUTOR_EXTRA_CLASSPATH -> classpath - ) - request - } - -}