LIVY-303. Add statement retention mechanism. (#279) To avoid OOM in long-running sessions, introduce a statement retention mechanism that removes old statements.
Also refactor the statement state code to make it more clear. Project: http://git-wip-us.apache.org/repos/asf/incubator-livy/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-livy/commit/932d397b Tree: http://git-wip-us.apache.org/repos/asf/incubator-livy/tree/932d397b Diff: http://git-wip-us.apache.org/repos/asf/incubator-livy/diff/932d397b Branch: refs/heads/master Commit: 932d397bf6d6fe7ed9df455f8230227cb0e3761b Parents: 2aa910c Author: Saisai Shao <sai.sai.s...@gmail.com> Authored: Sat Feb 25 02:24:05 2017 +0800 Committer: Alex Man <alex-the-...@users.noreply.github.com> Committed: Fri Feb 24 10:24:05 2017 -0800 ---------------------------------------------------------------------- .../scala/com/cloudera/livy/repl/Session.scala | 91 +++++++++++--------- .../com/cloudera/livy/repl/SessionSpec.scala | 31 +++++++ .../java/com/cloudera/livy/rsc/RSCConf.java | 4 +- .../com/cloudera/livy/rsc/driver/Statement.java | 8 ++ .../livy/rsc/driver/StatementState.java | 46 ++++++++++ 5 files changed, 140 insertions(+), 40 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/932d397b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala index f927e73..bf1f3b4 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala @@ -18,10 +18,12 @@ package com.cloudera.livy.repl +import java.util.{LinkedHashMap => JLinkedHashMap} +import java.util.Map.Entry import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger -import scala.collection.concurrent.TrieMap +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ @@ 
-53,7 +55,7 @@ class Session( extends Logging { import Session._ - private implicit val executor = ExecutionContext.fromExecutorService( + private val interpreterExecutor = ExecutionContext.fromExecutorService( Executors.newSingleThreadExecutor()) private val cancelExecutor = ExecutionContext.fromExecutorService( @@ -64,7 +66,15 @@ class Session( @volatile private[repl] var _sc: Option[SparkContext] = None private var _state: SessionState = SessionState.NotStarted() - private val _statements = TrieMap[Int, Statement]() + + // Number of statements kept in driver's memory + private val numRetainedStatements = livyConf.getInt(RSCConf.Entry.RETAINED_STATEMENT_NUMBER) + + private val _statements = new JLinkedHashMap[Int, Statement] { + protected override def removeEldestEntry(eldest: Entry[Int, Statement]): Boolean = { + size() > numRetainedStatements + } + }.asScala private val newStatementId = new AtomicInteger(0) @@ -77,9 +87,9 @@ class Session( _sc = Option(sc) changeState(SessionState.Idle()) sc - } + }(interpreterExecutor) - future.onFailure { case _ => changeState(SessionState.Error()) } + future.onFailure { case _ => changeState(SessionState.Error()) }(interpreterExecutor) future } @@ -87,72 +97,75 @@ class Session( def state: SessionState = _state - def statements: collection.Map[Int, Statement] = _statements.readOnlySnapshot() + def statements: collection.Map[Int, Statement] = _statements.synchronized { + _statements.toMap + } def execute(code: String): Int = { val statementId = newStatementId.getAndIncrement() - _statements(statementId) = new Statement(statementId, StatementState.Waiting, null) + val statement = new Statement(statementId, StatementState.Waiting, null) + _statements.synchronized { _statements(statementId) = statement } Future { setJobGroup(statementId) - _statements(statementId).state.compareAndSet(StatementState.Waiting, StatementState.Running) + statement.compareAndTransit(StatementState.Waiting, StatementState.Running) - val executeResult = 
if (_statements(statementId).state.get() == StatementState.Running) { - executeCode(statementId, code) - } else { - null + if (statement.state.get() == StatementState.Running) { + statement.output = executeCode(statementId, code) } - _statements(statementId).output = executeResult - _statements(statementId).state.compareAndSet(StatementState.Running, StatementState.Available) - _statements(statementId).state.compareAndSet( - StatementState.Cancelling, StatementState.Cancelled) - } + statement.compareAndTransit(StatementState.Running, StatementState.Available) + statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + }(interpreterExecutor) statementId } def cancel(statementId: Int): Unit = { - if (!_statements.contains(statementId)) { + val statementOpt = _statements.synchronized { _statements.get(statementId) } + if (statementOpt.isEmpty) { return } - if (_statements(statementId).state.get() == StatementState.Available || - _statements(statementId).state.get() == StatementState.Cancelled || - _statements(statementId).state.get() == StatementState.Cancelling) { + val statement = statementOpt.get + if (statement.state.get().isOneOf( + StatementState.Available, StatementState.Cancelled, StatementState.Cancelling)) { return } else { // statement 1 is running and statement 2 is waiting. User cancels // statement 2 then cancels statement 1. The 2nd cancel call will loop and block the 1st // cancel call since cancelExecutor is single threaded. To avoid this, set the statement // state to cancelled when cancelling a waiting statement. 
- _statements(statementId).state.compareAndSet(StatementState.Waiting, StatementState.Cancelled) - _statements(statementId).state.compareAndSet( - StatementState.Running, StatementState.Cancelling) + statement.compareAndTransit(StatementState.Waiting, StatementState.Cancelled) + statement.compareAndTransit(StatementState.Running, StatementState.Cancelling) } info(s"Cancelling statement $statementId...") - Future { - val deadline = livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TIMEOUT).millis.fromNow - while (_statements(statementId).state.get() == StatementState.Cancelling) { - if (deadline.isOverdue()) { - info(s"Failed to cancel statement $statementId.") - _statements(statementId).state.compareAndSet( - StatementState.Cancelling, StatementState.Cancelled) - } else { - _sc.foreach(_.cancelJobGroup(statementId.toString)) + Future { + val deadline = livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TIMEOUT).millis.fromNow + + while (statement.state.get() == StatementState.Cancelling) { + if (deadline.isOverdue()) { + info(s"Failed to cancel statement $statementId.") + statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + } else { + _sc.foreach(_.cancelJobGroup(statementId.toString)) + if (statement.state.get() == StatementState.Cancelling) { + Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL)) } - Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL)) - } - if (_statements(statementId).state.get() == StatementState.Cancelled) { - info(s"Statement $statementId cancelled.") } - }(cancelExecutor) + } + + if (statement.state.get() == StatementState.Cancelled) { + info(s"Statement $statementId cancelled.") + } + }(cancelExecutor) } def close(): Unit = { - executor.shutdown() + interpreterExecutor.shutdown() + cancelExecutor.shutdown() interpreter.close() } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/932d397b/repl/src/test/scala/com/cloudera/livy/repl/SessionSpec.scala 
---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SessionSpec.scala index 6329365..203d15e 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SessionSpec.scala @@ -92,5 +92,36 @@ class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite { actualStateTransitions.toArray shouldBe expectedStateTransitions } } + + it("should remove old statements when reaching threshold") { + val interpreter = mock[Interpreter] + when(interpreter.kind).thenAnswer(new Answer[String] { + override def answer(invocationOnMock: InvocationOnMock): String = "spark" + }) + + rscConf.set(RSCConf.Entry.RETAINED_STATEMENT_NUMBER, 2) + val session = new Session(rscConf, interpreter) + session.start() + + session.statements.size should be (0) + session.execute("") + session.statements.size should be (1) + session.statements.map(_._1).toSet should be (Set(0)) + session.execute("") + session.statements.size should be (2) + session.statements.map(_._1).toSet should be (Set(0, 1)) + session.execute("") + eventually { + session.statements.size should be (2) + session.statements.map(_._1).toSet should be (Set(1, 2)) + } + + // Continue submitting statements, total statements in memory should be 2. 
+ session.execute("") + eventually { + session.statements.size should be (2) + session.statements.map(_._1).toSet should be (Set(2, 3)) + } + } } } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/932d397b/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java ---------------------------------------------------------------------- diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java b/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java index 1520990..11444e2 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/RSCConf.java @@ -75,7 +75,9 @@ public class RSCConf extends ClientConf<RSCConf> { TEST_STUCK_START_DRIVER("test.do_not_use.stuck_start_driver", false), JOB_CANCEL_TRIGGER_INTERVAL("job_cancel.trigger_interval", "100ms"), - JOB_CANCEL_TIMEOUT("job_cancel.timeout", "30s"); + JOB_CANCEL_TIMEOUT("job_cancel.timeout", "30s"), + + RETAINED_STATEMENT_NUMBER("retained_statements", 100); private final String key; private final Object dflt; http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/932d397b/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java ---------------------------------------------------------------------- diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java b/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java index 512b238..c88514e 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/driver/Statement.java @@ -36,4 +36,12 @@ public class Statement { public Statement() { this(null, null, null); } + + public boolean compareAndTransit(final StatementState from, final StatementState to) { + if (state.compareAndSet(from, to)) { + StatementState.validate(from, to); + return true; + } + return false; + } } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/932d397b/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java 
---------------------------------------------------------------------- diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java b/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java index b5414ab..5e084bc 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/driver/StatementState.java @@ -17,7 +17,12 @@ package com.cloudera.livy.rsc.driver; +import java.util.*; + import com.fasterxml.jackson.annotation.JsonValue; +import com.google.common.base.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public enum StatementState { Waiting("waiting"), @@ -26,6 +31,8 @@ public enum StatementState { Cancelling("cancelling"), Cancelled("cancelled"); + private static final Logger LOG = LoggerFactory.getLogger(StatementState.class); + private final String state; StatementState(final String text) { @@ -37,4 +44,43 @@ public enum StatementState { public String toString() { return state; } + + public boolean isOneOf(StatementState... states) { + for (StatementState s : states) { + if (s == this) { + return true; + } + } + return false; + } + + private static final Map<StatementState, List<StatementState>> PREDECESSORS; + + static void put(StatementState key, + Map<StatementState, List<StatementState>> map, + StatementState... 
values) { + map.put(key, Collections.unmodifiableList(Arrays.asList(values))); + } + + static { + final Map<StatementState, List<StatementState>> predecessors = + new EnumMap<>(StatementState.class); + put(Waiting, predecessors); + put(Running, predecessors, Waiting); + put(Available, predecessors, Running); + put(Cancelling, predecessors, Running); + put(Cancelled, predecessors, Waiting, Cancelling); + + PREDECESSORS = Collections.unmodifiableMap(predecessors); + } + + static boolean isValid(StatementState from, StatementState to) { + return PREDECESSORS.get(to).contains(from); + } + + static void validate(StatementState from, StatementState to) { + LOG.debug("{} -> {}", from, to); + + Preconditions.checkState(isValid(from, to), "Illegal Transition: %s -> %s", from, to); + } }