LIVY-272. Support Statement progress for interactive session. (#260)
Project: http://git-wip-us.apache.org/repos/asf/incubator-livy/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-livy/commit/70f23b90 Tree: http://git-wip-us.apache.org/repos/asf/incubator-livy/tree/70f23b90 Diff: http://git-wip-us.apache.org/repos/asf/incubator-livy/diff/70f23b90 Branch: refs/heads/master Commit: 70f23b90f9cc7fbe98663551a4cdff86a7c069dd Parents: 126b57e Author: Saisai Shao <sai.sai.s...@gmail.com> Authored: Sun Mar 12 10:16:52 2017 +0800 Committer: Alex Man <alex-the-...@users.noreply.github.com> Committed: Sat Mar 11 18:16:52 2017 -0800 ---------------------------------------------------------------------- .../cloudera/livy/repl/SparkInterpreter.scala | 4 +- .../livy/repl/SparkInterpreterSpec.scala | 2 +- .../cloudera/livy/repl/SparkInterpreter.scala | 4 +- .../livy/repl/SparkInterpreterSpec.scala | 2 +- .../livy/repl/AbstractSparkInterpreter.scala | 11 +- .../com/cloudera/livy/repl/Interpreter.scala | 16 +- .../cloudera/livy/repl/ProcessInterpreter.scala | 13 +- .../cloudera/livy/repl/PythonInterpreter.scala | 12 +- .../com/cloudera/livy/repl/ReplDriver.scala | 15 +- .../scala/com/cloudera/livy/repl/Session.scala | 3 +- .../cloudera/livy/repl/SparkRInterpreter.scala | 13 +- .../livy/repl/StatementProgressListener.scala | 162 +++++++++++++ .../livy/repl/PythonInterpreterSpec.scala | 7 +- .../cloudera/livy/repl/PythonSessionSpec.scala | 7 +- .../livy/repl/ScalaInterpreterSpec.scala | 5 +- .../livy/repl/SparkRInterpreterSpec.scala | 5 +- .../cloudera/livy/repl/SparkRSessionSpec.scala | 5 +- .../cloudera/livy/repl/SparkSessionSpec.scala | 6 +- .../repl/StatementProgressListenerSpec.scala | 227 +++++++++++++++++++ .../java/com/cloudera/livy/rsc/RSCConf.java | 4 + .../com/cloudera/livy/rsc/driver/Statement.java | 10 + .../InteractiveSessionServletSpec.scala | 1 + .../interactive/InteractiveSessionSpec.scala | 16 ++ 23 files changed, 509 insertions(+), 41 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala b/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala index 3322de1..ec12929 100644 --- a/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala +++ b/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala @@ -33,7 +33,8 @@ import org.apache.spark.repl.SparkIMain /** * This represents a Spark interpreter. It is not thread safe. */ -class SparkInterpreter(conf: SparkConf) +class SparkInterpreter(conf: SparkConf, + override val statementProgressListener: StatementProgressListener) extends AbstractSparkInterpreter with SparkContextInitializer { private var sparkIMain: SparkIMain = _ @@ -103,6 +104,7 @@ class SparkInterpreter(conf: SparkConf) createSparkContext(conf) } + sparkContext.addSparkListener(statementProgressListener) sparkContext } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala b/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala index e2b783a..3df35b5 100644 --- a/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala +++ b/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala @@ -24,7 +24,7 @@ import com.cloudera.livy.LivyBaseUnitTestSuite class SparkInterpreterSpec extends FunSpec with Matchers with LivyBaseUnitTestSuite { describe("SparkInterpreter") { - val interpreter = new SparkInterpreter(null) + val interpreter = new SparkInterpreter(null, null) it("should parse Scala compile error.") { // Regression test for LIVY-260. http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala b/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala index 2bf6347..bf2f680 100644 --- a/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala @@ -33,7 +33,8 @@ import org.apache.spark.repl.SparkILoop /** * Scala 2.11 version of SparkInterpreter */ -class SparkInterpreter(conf: SparkConf) +class SparkInterpreter(conf: SparkConf, + override val statementProgressListener: StatementProgressListener) extends AbstractSparkInterpreter with SparkContextInitializer { protected var sparkContext: SparkContext = _ @@ -89,6 +90,7 @@ class SparkInterpreter(conf: SparkConf) createSparkContext(conf) } + sparkContext.addSparkListener(statementProgressListener) sparkContext } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala b/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala index 5cb88e3..56656d7 100644 --- a/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala +++ b/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala @@ -24,7 +24,7 @@ import com.cloudera.livy.LivyBaseUnitTestSuite class SparkInterpreterSpec extends FunSpec with Matchers with LivyBaseUnitTestSuite { describe("SparkInterpreter") { - val interpreter = new SparkInterpreter(null) + val interpreter = new SparkInterpreter(null, null) it("should parse Scala compile error.") { // Regression test for LIVY-. http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/AbstractSparkInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/AbstractSparkInterpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/AbstractSparkInterpreter.scala index d30bb3b..d117da7 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/AbstractSparkInterpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/AbstractSparkInterpreter.scala @@ -50,12 +50,13 @@ abstract class AbstractSparkInterpreter extends Interpreter with Logging { protected def valueOfTerm(name: String): Option[Any] - override def execute(code: String): Interpreter.ExecuteResponse = restoreContextClassLoader { - require(isStarted()) + override protected[repl] def execute(code: String): Interpreter.ExecuteResponse = + restoreContextClassLoader { + require(isStarted()) - executeLines(code.trim.split("\n").toList, Interpreter.ExecuteSuccess(JObject( - (TEXT_PLAIN, JString("")) - ))) + executeLines(code.trim.split("\n").toList, Interpreter.ExecuteSuccess(JObject( + (TEXT_PLAIN, JString("")) + ))) } private def executeMagic(magic: String, rest: String): Interpreter.ExecuteResponse = { http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala index 069953e..59ad878 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala @@ -37,18 +37,28 @@ trait Interpreter { def kind: String + def statementProgressListener: StatementProgressListener + /** * Start the Interpreter. * - * @return A SparkContext, which may be null. + * @return A SparkContext */ def start(): SparkContext /** - * Execute the code and return the result as a Future as it may + * Execute the code and return the result. + */ + def execute(statementId: Int, code: String): ExecuteResponse = { + statementProgressListener.setCurrentStatementId(statementId) + execute(code) + } + + /** + * Execute the code and return the result, it may * take some time to execute. */ - def execute(code: String): ExecuteResponse + protected[repl] def execute(code: String): ExecuteResponse /** Shut down the interpreter. */ def close(): Unit http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala index c4fb8ca..0414bbb 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala @@ -41,10 +41,9 @@ private case class ShutdownRequest(promise: Promise[Unit]) extends Request * * @param process */ -abstract class ProcessInterpreter(process: Process) - extends Interpreter - with Logging -{ +abstract class ProcessInterpreter(process: Process, + override val statementProgressListener: StatementProgressListener) + extends Interpreter with Logging { protected[this] val stdin = new PrintWriter(process.getOutputStream) protected[this] val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1) @@ -54,11 +53,13 @@ abstract class ProcessInterpreter(process: Process) if (ClientConf.TEST_MODE) { null.asInstanceOf[SparkContext] } else { - SparkContext.getOrCreate() + val sc = SparkContext.getOrCreate() + sc.addSparkListener(statementProgressListener) + sc } } - override def execute(code: String): Interpreter.ExecuteResponse = { + override protected[repl] def execute(code: String): Interpreter.ExecuteResponse = { try { sendExecuteRequest(code) } catch { http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala index 2195d0e..6e80c09 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala @@ -45,7 +45,7 @@ import com.cloudera.livy.sessions._ // scalastyle:off println object PythonInterpreter extends Logging { - def apply(conf: SparkConf, kind: Kind): Interpreter = { + def apply(conf: SparkConf, kind: Kind, listener: StatementProgressListener): Interpreter = { val pythonExec = kind match { case PySpark() => sys.env.getOrElse("PYSPARK_PYTHON", "python") case PySpark3() => sys.env.getOrElse("PYSPARK3_PYTHON", "python3") @@ -72,7 +72,7 @@ object PythonInterpreter extends Logging { env.put("LIVY_SPARK_MAJOR_VERSION", conf.get("spark.livy.spark_major_version", "1")) builder.redirectError(Redirect.PIPE) val process = builder.start() - new PythonInterpreter(process, gatewayServer, kind.toString) + new PythonInterpreter(process, gatewayServer, kind.toString, listener) } private def findPySparkArchives(): Seq[String] = { @@ -187,8 +187,12 @@ object PythonInterpreter extends Logging { } } -private class PythonInterpreter(process: Process, gatewayServer: GatewayServer, pyKind: String) - extends ProcessInterpreter(process) +private class PythonInterpreter( + process: Process, + gatewayServer: GatewayServer, + pyKind: String, + listener: StatementProgressListener) + extends ProcessInterpreter(process, listener) with Logging { implicit val formats = DefaultFormats http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala b/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala index c176412..d368c6a 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala @@ -44,10 +44,11 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) override protected def initializeContext(): JavaSparkContext = { interpreter = kind match { - case PySpark() => PythonInterpreter(conf, PySpark()) - case PySpark3() => PythonInterpreter(conf, PySpark3()) - case Spark() => new SparkInterpreter(conf) - case SparkR() => SparkRInterpreter(conf) + case PySpark() => PythonInterpreter(conf, PySpark(), new StatementProgressListener(livyConf)) + case PySpark3() => + PythonInterpreter(conf, PySpark3(), new StatementProgressListener(livyConf)) + case Spark() => new SparkInterpreter(conf, new StatementProgressListener(livyConf)) + case SparkR() => SparkRInterpreter(conf, new StatementProgressListener(livyConf)) } session = new Session(livyConf, interpreter, { s => broadcast(new ReplState(s.toString)) }) @@ -90,6 +91,12 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) session.statements.filterKeys(id => id >= msg.from && id < until).values.toArray } } + + // Update progress of statements when queried + statements.foreach { s => + s.updateProgress(interpreter.statementProgressListener.progressOfStatement(s.id)) + } + new ReplJobResults(statements.sortBy(_.id)) } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/Session.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala index bf1f3b4..54056a3 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala @@ -116,6 +116,7 @@ class Session( statement.compareAndTransit(StatementState.Running, StatementState.Available) statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + statement.updateProgress(1.0) }(interpreterExecutor) statementId @@ -187,7 +188,7 @@ class Session( } val resultInJson = try { - interpreter.execute(code) match { + interpreter.execute(executionCount, code) match { case Interpreter.ExecuteSuccess(data) => transitToIdle() http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala index cc57d72..8e5f3c0 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala @@ -35,6 +35,7 @@ import org.json4s._ import org.json4s.JsonDSL._ import com.cloudera.livy.client.common.ClientConf +import com.cloudera.livy.rsc.RSCConf // scalastyle:off println object SparkRInterpreter { @@ -64,7 +65,7 @@ object SparkRInterpreter { ")" ).r.unanchored - def apply(conf: SparkConf): SparkRInterpreter = { + def apply(conf: SparkConf, listener: StatementProgressListener): SparkRInterpreter = { val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt val mirror = universe.runtimeMirror(getClass.getClassLoader) val sparkRBackendClass = mirror.classLoader.loadClass("org.apache.spark.api.r.RBackend") @@ -117,7 +118,8 @@ object SparkRInterpreter { val process = builder.start() new SparkRInterpreter(process, backendInstance, backendThread, conf.get("spark.livy.spark_major_version", "1"), - conf.getBoolean("spark.repl.enableHiveContext", false)) + conf.getBoolean("spark.repl.enableHiveContext", false), + listener) } catch { case e: Exception => if (backendThread != null) { @@ -132,15 +134,16 @@ class SparkRInterpreter(process: Process, backendInstance: Any, backendThread: Thread, val sparkMajorVersion: String, - hiveEnabled: Boolean) - extends ProcessInterpreter(process) { + hiveEnabled: Boolean, + statementProgressListener: StatementProgressListener) + extends ProcessInterpreter(process, statementProgressListener) { import SparkRInterpreter._ implicit val formats = DefaultFormats private[this] var executionCount = 0 override def kind: String = "sparkr" - private[this] val isStarted = new CountDownLatch(1); + private[this] val isStarted = new CountDownLatch(1) final override protected def waitUntilReady(): Unit = { // Set the option to catch and ignore errors instead of halting. http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala b/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala new file mode 100644 index 0000000..ae2147b --- /dev/null +++ b/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala @@ -0,0 +1,162 @@ +/* + * Licensed to Cloudera, Inc. under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Cloudera, Inc. licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.cloudera.livy.repl + +import scala.collection.mutable + +import com.google.common.annotations.VisibleForTesting +import org.apache.spark.Success +import org.apache.spark.scheduler._ + +import com.cloudera.livy.rsc.RSCConf + +/** + * [[StatementProgressListener]] is an implementation of SparkListener, used to track the progress + * of submitted statement, this class builds a mapping relation between statement, jobs, stages + * and tasks, and uses the finished task number to calculate the statement progress. + * + * By default 100 latest statement progresses will be kept, users could also configure + * livy.rsc.retained_statements to change the cached number. + * + * This statement progress can only reflect the statement in which has Spark jobs, if + * the statement submitted doesn't generate any Spark job, the progress will always return 0.0 + * until completed. + * + * Also if the statement includes several Spark jobs, the progress will be flipped because we + * don't know the actual number of Spark jobs/tasks generated before the statement executed. + */ +class StatementProgressListener(conf: RSCConf) extends SparkListener { + + case class TaskCount(var currFinishedTasks: Int, var totalTasks: Int) + case class JobState(jobId: Int, var isCompleted: Boolean) + + private val retainedStatements = conf.getInt(RSCConf.Entry.RETAINED_STATEMENT_NUMBER) + + /** Statement id to list of jobs map */ + @VisibleForTesting + private[repl] val statementToJobs = new mutable.LinkedHashMap[Int, Seq[JobState]]() + @VisibleForTesting + private[repl] val jobIdToStatement = new mutable.HashMap[Int, Int]() + /** Job id to list of stage ids map */ + @VisibleForTesting + private[repl] val jobIdToStages = new mutable.HashMap[Int, Seq[Int]]() + /** Stage id to number of finished/total tasks map */ + @VisibleForTesting + private[repl] val stageIdToTaskCount = new mutable.HashMap[Int, TaskCount]() + + @transient private var currentStatementId: Int = _ + + /** + * Set current statement id, onJobStart() will use current statement id to build the mapping + * relations. + */ + def setCurrentStatementId(stmtId: Int): Unit = { + currentStatementId = stmtId + } + + /** + * Get the current progress of given statement id. + */ + def progressOfStatement(stmtId: Int): Double = synchronized { + var finishedTasks = 0 + var totalTasks = 0 + + for { + job <- statementToJobs.getOrElse(stmtId, Seq.empty) + stageId <- jobIdToStages.getOrElse(job.jobId, Seq.empty) + taskCount <- stageIdToTaskCount.get(stageId) + } yield { + finishedTasks += taskCount.currFinishedTasks + totalTasks += taskCount.totalTasks + } + + if (totalTasks == 0) { + 0.0 + } else { + finishedTasks.toDouble / totalTasks + } + } + + /** + * Get the active job ids of the given statement id. + */ + def activeJobsOfStatement(stmtId: Int): Seq[Int] = synchronized { + statementToJobs.getOrElse(stmtId, Seq.empty).filter(!_.isCompleted).map(_.jobId) + } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { + val jobs = statementToJobs.getOrElseUpdate(currentStatementId, Seq.empty) :+ + JobState(jobStart.jobId, isCompleted = false) + statementToJobs.put(currentStatementId, jobs) + jobIdToStatement(jobStart.jobId) = currentStatementId + + jobIdToStages(jobStart.jobId) = jobStart.stageInfos.map(_.stageId) + jobStart.stageInfos.foreach { s => stageIdToTaskCount(s.stageId) = TaskCount(0, s.numTasks) } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + taskEnd.reason match { + case Success => + stageIdToTaskCount.get(taskEnd.stageId).foreach { t => t.currFinishedTasks += 1 } + case _ => + // If task is failed, it will run again, so don't count it. + } + } + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { + // If stage is resubmitted, we should reset the task count of this stage. + stageIdToTaskCount.get(stageSubmitted.stageInfo.stageId).foreach { t => + t.currFinishedTasks = 0 + t.totalTasks = stageSubmitted.stageInfo.numTasks + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { + stageIdToTaskCount.get(stageCompleted.stageInfo.stageId).foreach { t => + t.currFinishedTasks = t.totalTasks + } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { + jobIdToStatement.get(jobEnd.jobId).foreach { stmtId => + statementToJobs.get(stmtId).foreach { jobs => + jobs.filter(_.jobId == jobEnd.jobId).foreach(_.isCompleted = true) + } + } + + // Try to clean the old data when job is finished. This will trigger data cleaning in LRU + // policy. + cleanOldMetadata() + } + + private def cleanOldMetadata(): Unit = { + if (statementToJobs.size > retainedStatements) { + val toRemove = statementToJobs.size - retainedStatements + statementToJobs.take(toRemove).foreach { case (_, jobs) => + jobs.foreach { job => + jobIdToStatement.remove(job.jobId) + jobIdToStages.remove(job.jobId).foreach { stages => + stages.foreach(s => stageIdToTaskCount.remove(s)) + } + } + } + (0 until toRemove).foreach(_ => statementToJobs.remove(statementToJobs.head._1)) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala index 00bbc68..c67d580 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala @@ -23,6 +23,7 @@ import org.json4s.{DefaultFormats, JNull, JValue} import org.json4s.JsonDSL._ import org.scalatest._ +import com.cloudera.livy.rsc.RSCConf import com.cloudera.livy.sessions._ abstract class PythonBaseInterpreterSpec extends BaseInterpreterSpec { @@ -244,7 +245,8 @@ class Python2InterpreterSpec extends PythonBaseInterpreterSpec { implicit val formats = DefaultFormats - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) + override def createInterpreter(): Interpreter = + PythonInterpreter(new SparkConf(), PySpark(), new StatementProgressListener(new RSCConf())) // Scalastyle is treating unicode escape as non ascii characters. Turn off the check. // scalastyle:off non.ascii.character.disallowed @@ -271,7 +273,8 @@ class Python3InterpreterSpec extends PythonBaseInterpreterSpec { test() } - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) + override def createInterpreter(): Interpreter = + PythonInterpreter(new SparkConf(), PySpark3(), new StatementProgressListener(new RSCConf())) it should "check python version is 3.x" in withInterpreter { interpreter => val response = interpreter.execute("""import sys http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala index 4582acd..28f457f 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala @@ -23,6 +23,7 @@ import org.json4s.Extraction import org.json4s.jackson.JsonMethods.parse import org.scalatest._ +import com.cloudera.livy.rsc.RSCConf import com.cloudera.livy.sessions._ abstract class PythonSessionSpec extends BaseSessionSpec { @@ -173,7 +174,8 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } class Python2SessionSpec extends PythonSessionSpec { - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) + override def createInterpreter(): Interpreter = + PythonInterpreter(new SparkConf(), PySpark(), new StatementProgressListener(new RSCConf())) } class Python3SessionSpec extends PythonSessionSpec { @@ -183,7 +185,8 @@ class Python3SessionSpec extends PythonSessionSpec { test() } - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) + override def createInterpreter(): Interpreter = + PythonInterpreter(new SparkConf(), PySpark3(), new StatementProgressListener(new RSCConf())) it should "check python version is 3.x" in withSession { session => val statement = execute(session)( http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala index 63076e7..a9e1e8b 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala @@ -22,11 +22,14 @@ import org.apache.spark.SparkConf import org.json4s.{DefaultFormats, JValue} import org.json4s.JsonDSL._ +import com.cloudera.livy.rsc.RSCConf + class ScalaInterpreterSpec extends BaseInterpreterSpec { implicit val formats = DefaultFormats - override def createInterpreter(): Interpreter = new SparkInterpreter(new SparkConf()) + override def createInterpreter(): Interpreter = + new SparkInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) it should "execute `1 + 2` == 3" in withInterpreter { interpreter => val response = interpreter.execute("1 + 2") http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala index af03581..e9db106 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala @@ -23,6 +23,8 @@ import org.json4s.{DefaultFormats, JValue} import org.json4s.JsonDSL._ import org.scalatest._ +import com.cloudera.livy.rsc.RSCConf + class SparkRInterpreterSpec extends BaseInterpreterSpec { implicit val formats = DefaultFormats @@ -32,7 +34,8 @@ class SparkRInterpreterSpec extends BaseInterpreterSpec { super.withFixture(test) } - override def createInterpreter(): Interpreter = SparkRInterpreter(new SparkConf()) + override def createInterpreter(): Interpreter = + SparkRInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) it should "execute `1 + 2` == 3" in withInterpreter { interpreter => val response = interpreter.execute("1 + 2") http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala index cfa2ba5..5592977 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala @@ -22,6 +22,8 @@ import org.apache.spark.SparkConf import org.json4s.Extraction import org.json4s.jackson.JsonMethods.parse +import com.cloudera.livy.rsc.RSCConf + class SparkRSessionSpec extends BaseSessionSpec { override protected def withFixture(test: NoArgTest) = { @@ -29,7 +31,8 @@ class SparkRSessionSpec extends BaseSessionSpec { super.withFixture(test) } - override def createInterpreter(): Interpreter = SparkRInterpreter(new SparkConf()) + override def createInterpreter(): Interpreter = + SparkRInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) it should "execute `1 + 2` == 3" in withSession { session => val statement = execute(session)("1 + 2") http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala index 2ef7241..a051513 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala @@ -23,15 +23,17 @@ import scala.language.postfixOps import org.apache.spark.SparkConf import org.json4s.Extraction -import org.json4s.jackson.JsonMethods.parse import org.json4s.JsonAST.JValue +import org.json4s.jackson.JsonMethods.parse import org.scalatest.concurrent.Eventually._ +import com.cloudera.livy.rsc.RSCConf import com.cloudera.livy.rsc.driver.StatementState class SparkSessionSpec extends BaseSessionSpec { - override def createInterpreter(): Interpreter = new SparkInterpreter(new SparkConf()) + override def createInterpreter(): Interpreter = + new SparkInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) it should "execute `1 + 2` == 3" in withSession { session => val statement = execute(session)("1 + 2") http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala new file mode 100644 index 0000000..2acee4c --- /dev/null +++ b/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala @@ -0,0 +1,227 @@ +/* + * Licensed to Cloudera, Inc. under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Cloudera, Inc. licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.cloudera.livy.repl + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.{postfixOps, reflectiveCalls} + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler._ +import org.scalatest._ +import org.scalatest.concurrent.Eventually._ + +import com.cloudera.livy.LivyBaseUnitTestSuite +import com.cloudera.livy.rsc.RSCConf + +class StatementProgressListenerSpec extends FlatSpec + with Matchers + with BeforeAndAfterAll + with BeforeAndAfter + with LivyBaseUnitTestSuite { + private val rscConf = new RSCConf() + .set(RSCConf.Entry.RETAINED_STATEMENT_NUMBER, 2) + + private val testListener = new StatementProgressListener(rscConf) { + var onJobStartedCallback: Option[() => Unit] = None + var onJobEndCallback: Option[() => Unit] = None + var onStageEndCallback: Option[() => Unit] = None + var onTaskEndCallback: Option[() => Unit] = None + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + super.onJobStart(jobStart) + onJobStartedCallback.foreach(f => f()) + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + super.onJobEnd(jobEnd) + onJobEndCallback.foreach(f => f()) + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + super.onStageCompleted(stageCompleted) + onStageEndCallback.foreach(f => f()) + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + super.onTaskEnd(taskEnd) + onTaskEndCallback.foreach(f => f()) + } + } + + private val statementId = new AtomicInteger(0) + + private def getStatementId = statementId.getAndIncrement() + + private var sparkInterpreter: SparkInterpreter = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sparkInterpreter = new SparkInterpreter(new SparkConf(), testListener) + sparkInterpreter.start() + } + + override def afterAll(): Unit = { + sparkInterpreter.close() + super.afterAll() + } + + after { + testListener.onJobStartedCallback = None + testListener.onJobEndCallback = None + testListener.onStageEndCallback = None + testListener.onTaskEndCallback = None + } + + it should "correctly calculate progress" in { + val executeCode = + """ + |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() + """.stripMargin + val stmtId = getStatementId + + def verifyJobs(): Unit = { + testListener.statementToJobs.get(stmtId) should not be (None) + + // One job will be submitted + testListener.statementToJobs(stmtId).size should be (1) + val jobId = testListener.statementToJobs(stmtId).head.jobId + testListener.jobIdToStatement(jobId) should be (stmtId) + + // 1 stage will be generated + testListener.jobIdToStages(jobId).size should be (1) + val stageIds = testListener.jobIdToStages(jobId) + + // 2 tasks per stage will be generated + stageIds.foreach { id => + testListener.stageIdToTaskCount(id).currFinishedTasks should be (0) + testListener.stageIdToTaskCount(id).totalTasks should be (2) + } + } + + var taskEndCalls = 0 + def verifyTasks(): Unit = { + taskEndCalls += 1 + testListener.progressOfStatement(stmtId) should be (taskEndCalls.toDouble / 2) + } + + var stageEndCalls = 0 + def verifyStages(): Unit = { + stageEndCalls += 1 + testListener.progressOfStatement(stmtId) should be (stageEndCalls.toDouble / 1) + } + + testListener.onJobStartedCallback = Some(verifyJobs) + testListener.onTaskEndCallback = Some(verifyTasks) + testListener.onStageEndCallback = Some(verifyStages) + sparkInterpreter.execute(stmtId, executeCode) + + eventually(timeout(30 seconds), interval(100 millis)) { + testListener.progressOfStatement(stmtId) should be(1.0) + } + } + + it should "not generate Spark jobs for plain Scala code" in { + val executeCode = """1 + 1""" + val stmtId = getStatementId + + def verifyJobs(): Unit = { + fail("No job will be submitted") + } + + testListener.onJobStartedCallback = Some(verifyJobs) + testListener.progressOfStatement(stmtId) should be (0.0) + sparkInterpreter.execute(stmtId, executeCode) + testListener.progressOfStatement(stmtId) should be (0.0) + } + + it should "handle multiple jobs in one statement" in { + val executeCode = + """ + |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() + |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() + """.stripMargin + val stmtId = getStatementId + + var jobs = 0 + def verifyJobs(): Unit = { + jobs += 1 + + testListener.statementToJobs.get(stmtId) should not be (None) + // One job will be submitted + testListener.statementToJobs(stmtId).size should be (jobs) + val jobId = testListener.statementToJobs(stmtId)(jobs - 1).jobId + testListener.jobIdToStatement(jobId) should be (stmtId) + + // 1 stages will be generated + testListener.jobIdToStages(jobId).size should be (1) + val stageIds = testListener.jobIdToStages(jobId) + + // 2 tasks per stage will be generated + stageIds.foreach { id => + testListener.stageIdToTaskCount(id).currFinishedTasks should be (0) + testListener.stageIdToTaskCount(id).totalTasks should be (2) + } + } + + val taskProgress = ArrayBuffer[Double]() + def verifyTasks(): Unit = { + taskProgress += testListener.progressOfStatement(stmtId) + } + + val stageProgress = ArrayBuffer[Double]() + def verifyStages(): Unit = { + stageProgress += testListener.progressOfStatement(stmtId) + } + + testListener.onJobStartedCallback = Some(verifyJobs) + testListener.onTaskEndCallback = Some(verifyTasks) + testListener.onStageEndCallback = Some(verifyStages) + sparkInterpreter.execute(stmtId, executeCode) + + taskProgress.toArray should be (Array(0.5, 1.0, 0.75, 1.0)) + stageProgress.toArray should be (Array(1.0, 1.0)) + + eventually(timeout(30 seconds), interval(100 millis)) { + testListener.progressOfStatement(stmtId) should be(1.0) + } + } + + it should "remove old statement progress" in { + val executeCode = + """ + |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() + """.stripMargin + val stmtId = getStatementId + + def onJobEnd(): Unit = { + testListener.statementToJobs(stmtId).size should be (1) + testListener.statementToJobs(stmtId).head.isCompleted should be (true) + + testListener.statementToJobs.size should be (2) + testListener.statementToJobs.get(0) should be (None) + testListener.jobIdToStatement.filter(_._2 == 0) should be (Map.empty) + } + + testListener.onJobEndCallback = Some(onJobEnd) + sparkInterpreter.execute(stmtId, executeCode) + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java ---------------------------------------------------------------------- diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java b/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java index 11444e2..0d7b1c1 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java @@ -94,6 +94,10 @@ public class RSCConf extends ClientConf<RSCConf> { public Object dflt() { return dflt; } } + public RSCConf() { + this(new Properties()); + } + public RSCConf(Properties config) { super(config); } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java ---------------------------------------------------------------------- diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java b/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java index c88514e..c1717a9 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java @@ -26,11 +26,13 @@ public class Statement { public final AtomicReference<StatementState> state; @JsonRawValue public volatile String output; + public double progress; public Statement(Integer id, StatementState state, String output) { this.id = id; this.state = new AtomicReference<>(state); this.output = output; + this.progress = 0.0; } public Statement() { @@ -44,4 +46,12 @@ public class Statement { } return false; } + + public void updateProgress(double p) { + if (this.state.get().isOneOf(StatementState.Cancelled, StatementState.Available)) { + this.progress = 1.0; + } else { + this.progress = p; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala index 0a31194..63d605d 100644 --- a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionServletSpec.scala @@ -121,6 +121,7 @@ class InteractiveSessionServletSpec extends BaseInteractiveServletSpec { jpost[Map[String, Any]]("/0/statements", ExecuteRequest("foo")) { data => data("id") should be (0) + data("progress") should be (0.0) data("output") shouldBe 1 } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/70f23b90/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala index 146df9b..28d7157 100644 --- a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala @@ -209,6 +209,22 @@ class InteractiveSessionSpec extends FunSpec } } + withSession("should get statement progress along with statement result") { session => + val code = + """ + |from time import sleep + |sleep(3) + """.stripMargin + val statement = session.executeStatement(ExecuteRequest(code)) + statement.progress should be (0.0) + + eventually(timeout(10 seconds), interval(100 millis)) { + val s = session.getStatement(statement.id).get + s.state.get() shouldBe StatementState.Available + s.progress should be (1.0) + } + } + withSession("should error out the session if the interpreter dies") { session => session.executeStatement(ExecuteRequest("import os; os._exit(666)")) eventually(timeout(30 seconds), interval(100 millis)) {