http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/com/cloudera/livy/utils/SparkYarnAppSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/com/cloudera/livy/utils/SparkYarnAppSpec.scala b/server/src/test/scala/com/cloudera/livy/utils/SparkYarnAppSpec.scala deleted file mode 100644 index 37f001a..0000000 --- a/server/src/test/scala/com/cloudera/livy/utils/SparkYarnAppSpec.scala +++ /dev/null @@ -1,353 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.cloudera.livy.utils - -import java.util.concurrent.{CountDownLatch, TimeUnit} - -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.api.records.FinalApplicationStatus.UNDEFINED -import org.apache.hadoop.yarn.api.records.YarnApplicationState._ -import org.apache.hadoop.yarn.client.api.YarnClient -import org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException -import org.apache.hadoop.yarn.util.ConverterUtils -import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.FunSpec -import org.scalatest.mock.MockitoSugar.mock - -import com.cloudera.livy.{LivyBaseUnitTestSuite, LivyConf} -import com.cloudera.livy.util.LineBufferedProcess -import com.cloudera.livy.utils.SparkApp._ - -class SparkYarnAppSpec extends FunSpec with LivyBaseUnitTestSuite { - private def cleanupThread(t: Thread)(f: => Unit) = { - try { f } finally { t.interrupt() } - } - - private def mockSleep(ms: Long) = { - Thread.`yield`() - } - - describe("SparkYarnApp") { - val TEST_TIMEOUT = 30 seconds - val appId = ConverterUtils.toApplicationId("application_1467912463905_0021") - val appIdOption = Some(appId.toString) - val appTag = "fakeTag" - val livyConf = new LivyConf() - livyConf.set(LivyConf.YARN_APP_LOOKUP_TIMEOUT, "30s") - - it("should poll YARN state and terminate") { - Clock.withSleepMethod(mockSleep) { - val mockYarnClient = mock[YarnClient] - val mockAppListener = mock[SparkAppListener] - - val mockAppReport = mock[ApplicationReport] - when(mockAppReport.getApplicationId).thenReturn(appId) - when(mockAppReport.getFinalApplicationStatus).thenReturn(FinalApplicationStatus.SUCCEEDED) - // Simulate YARN app state progression. 
- when(mockAppReport.getYarnApplicationState).thenAnswer(new Answer[YarnApplicationState]() { - private var stateSeq = List(ACCEPTED, RUNNING, FINISHED) - - override def answer(invocation: InvocationOnMock): YarnApplicationState = { - val currentState = stateSeq.head - if (stateSeq.tail.nonEmpty) { - stateSeq = stateSeq.tail - } - currentState - } - }) - when(mockYarnClient.getApplicationReport(appId)).thenReturn(mockAppReport) - - val app = new SparkYarnApp( - appTag, - appIdOption, - None, - Some(mockAppListener), - livyConf, - mockYarnClient) - cleanupThread(app.yarnAppMonitorThread) { - app.yarnAppMonitorThread.join(TEST_TIMEOUT.toMillis) - assert(!app.yarnAppMonitorThread.isAlive, - "YarnAppMonitorThread should terminate after YARN app is finished.") - verify(mockYarnClient, atLeast(1)).getApplicationReport(appId) - verify(mockAppListener).stateChanged(State.STARTING, State.RUNNING) - verify(mockAppListener).stateChanged(State.RUNNING, State.FINISHED) - } - } - } - - it("should kill yarn app") { - Clock.withSleepMethod(mockSleep) { - val diag = "DIAG" - val mockYarnClient = mock[YarnClient] - - val mockAppReport = mock[ApplicationReport] - when(mockAppReport.getApplicationId).thenReturn(appId) - when(mockAppReport.getDiagnostics).thenReturn(diag) - when(mockAppReport.getFinalApplicationStatus).thenReturn(FinalApplicationStatus.SUCCEEDED) - - var appKilled = false - when(mockAppReport.getYarnApplicationState).thenAnswer(new Answer[YarnApplicationState]() { - override def answer(invocation: InvocationOnMock): YarnApplicationState = { - if (!appKilled) { - RUNNING - } else { - KILLED - } - } - }) - when(mockYarnClient.getApplicationReport(appId)).thenReturn(mockAppReport) - - val app = new SparkYarnApp(appTag, appIdOption, None, None, livyConf, mockYarnClient) - cleanupThread(app.yarnAppMonitorThread) { - app.kill() - appKilled = true - - app.yarnAppMonitorThread.join(TEST_TIMEOUT.toMillis) - assert(!app.yarnAppMonitorThread.isAlive, - "YarnAppMonitorThread 
should terminate after YARN app is finished.") - verify(mockYarnClient, atLeast(1)).getApplicationReport(appId) - verify(mockYarnClient).killApplication(appId) - assert(app.log().mkString.contains(diag)) - } - } - } - - it("should return spark-submit log") { - Clock.withSleepMethod(mockSleep) { - val mockYarnClient = mock[YarnClient] - val mockSparkSubmit = mock[LineBufferedProcess] - val sparkSubmitInfoLog = IndexedSeq("SPARK-SUBMIT", "LOG") - val sparkSubmitErrorLog = IndexedSeq("SPARK-SUBMIT", "error log") - val sparkSubmitLog = ("stdout: " +: sparkSubmitInfoLog) ++ - ("\nstderr: " +: sparkSubmitErrorLog) :+ "\nYARN Diagnostics: " - when(mockSparkSubmit.inputLines).thenReturn(sparkSubmitInfoLog) - when(mockSparkSubmit.errorLines).thenReturn(sparkSubmitErrorLog) - val waitForCalledLatch = new CountDownLatch(1) - when(mockSparkSubmit.waitFor()).thenAnswer(new Answer[Int]() { - override def answer(invocation: InvocationOnMock): Int = { - waitForCalledLatch.countDown() - 0 - } - }) - - val mockAppReport = mock[ApplicationReport] - when(mockAppReport.getApplicationId).thenReturn(appId) - when(mockAppReport.getYarnApplicationState).thenReturn(YarnApplicationState.FINISHED) - when(mockAppReport.getDiagnostics).thenReturn(null) - when(mockYarnClient.getApplicationReport(appId)).thenReturn(mockAppReport) - - val app = new SparkYarnApp( - appTag, - appIdOption, - Some(mockSparkSubmit), - None, - livyConf, - mockYarnClient) - cleanupThread(app.yarnAppMonitorThread) { - waitForCalledLatch.await(TEST_TIMEOUT.toMillis, TimeUnit.MILLISECONDS) - assert(app.log() == sparkSubmitLog, "Expect spark-submit log") - } - } - } - - it("can kill spark-submit while it's running") { - Clock.withSleepMethod(mockSleep) { - val livyConf = new LivyConf() - livyConf.set(LivyConf.YARN_APP_LOOKUP_TIMEOUT, "0") - - val mockYarnClient = mock[YarnClient] - val mockSparkSubmit = mock[LineBufferedProcess] - - val sparkSubmitRunningLatch = new CountDownLatch(1) - // Simulate a running spark-submit - 
when(mockSparkSubmit.waitFor()).thenAnswer(new Answer[Int]() { - override def answer(invocation: InvocationOnMock): Int = { - sparkSubmitRunningLatch.await() - 0 - } - }) - - val app = new SparkYarnApp( - appTag, - appIdOption, - Some(mockSparkSubmit), - None, - livyConf, - mockYarnClient) - cleanupThread(app.yarnAppMonitorThread) { - app.kill() - verify(mockSparkSubmit, times(1)).destroy() - sparkSubmitRunningLatch.countDown() - } - } - } - - it("should map YARN state to SparkApp.State correctly") { - val app = new SparkYarnApp(appTag, appIdOption, None, None, livyConf) - cleanupThread(app.yarnAppMonitorThread) { - assert(app.mapYarnState(appId, NEW, UNDEFINED) == State.STARTING) - assert(app.mapYarnState(appId, NEW_SAVING, UNDEFINED) == State.STARTING) - assert(app.mapYarnState(appId, SUBMITTED, UNDEFINED) == State.STARTING) - assert(app.mapYarnState(appId, ACCEPTED, UNDEFINED) == State.STARTING) - assert(app.mapYarnState(appId, RUNNING, UNDEFINED) == State.RUNNING) - assert( - app.mapYarnState(appId, FINISHED, FinalApplicationStatus.SUCCEEDED) == State.FINISHED) - assert(app.mapYarnState(appId, FINISHED, FinalApplicationStatus.FAILED) == State.FAILED) - assert(app.mapYarnState(appId, FINISHED, FinalApplicationStatus.KILLED) == State.KILLED) - assert(app.mapYarnState(appId, FINISHED, UNDEFINED) == State.FAILED) - assert(app.mapYarnState(appId, FAILED, UNDEFINED) == State.FAILED) - assert(app.mapYarnState(appId, KILLED, UNDEFINED) == State.KILLED) - } - } - - it("should expose driver log url and Spark UI url") { - Clock.withSleepMethod(mockSleep) { - val mockYarnClient = mock[YarnClient] - val driverLogUrl = "DRIVER LOG URL" - val sparkUiUrl = "SPARK UI URL" - - val mockApplicationAttemptId = mock[ApplicationAttemptId] - val mockAppReport = mock[ApplicationReport] - when(mockAppReport.getApplicationId).thenReturn(appId) - when(mockAppReport.getFinalApplicationStatus).thenReturn(FinalApplicationStatus.SUCCEEDED) - 
when(mockAppReport.getTrackingUrl).thenReturn(sparkUiUrl) - when(mockAppReport.getCurrentApplicationAttemptId).thenReturn(mockApplicationAttemptId) - var done = false - when(mockAppReport.getYarnApplicationState).thenAnswer(new Answer[YarnApplicationState]() { - override def answer(invocation: InvocationOnMock): YarnApplicationState = { - if (!done) { - RUNNING - } else { - FINISHED - } - } - }) - when(mockYarnClient.getApplicationReport(appId)).thenReturn(mockAppReport) - - val mockAttemptReport = mock[ApplicationAttemptReport] - val mockContainerId = mock[ContainerId] - when(mockAttemptReport.getAMContainerId).thenReturn(mockContainerId) - when(mockYarnClient.getApplicationAttemptReport(mockApplicationAttemptId)) - .thenReturn(mockAttemptReport) - - val mockContainerReport = mock[ContainerReport] - when(mockYarnClient.getContainerReport(mockContainerId)).thenReturn(mockContainerReport) - - // Block test until getLogUrl is called 10 times. - val getLogUrlCountDown = new CountDownLatch(10) - when(mockContainerReport.getLogUrl).thenAnswer(new Answer[String] { - override def answer(invocation: InvocationOnMock): String = { - getLogUrlCountDown.countDown() - driverLogUrl - } - }) - - val mockListener = mock[SparkAppListener] - - val app = new SparkYarnApp( - appTag, appIdOption, None, Some(mockListener), livyConf, mockYarnClient) - cleanupThread(app.yarnAppMonitorThread) { - getLogUrlCountDown.await(TEST_TIMEOUT.length, TEST_TIMEOUT.unit) - done = true - - app.yarnAppMonitorThread.join(TEST_TIMEOUT.toMillis) - assert(!app.yarnAppMonitorThread.isAlive, - "YarnAppMonitorThread should terminate after YARN app is finished.") - - verify(mockYarnClient, atLeast(1)).getApplicationReport(appId) - verify(mockAppReport, atLeast(1)).getTrackingUrl() - verify(mockContainerReport, atLeast(1)).getLogUrl() - verify(mockListener).appIdKnown(appId.toString) - verify(mockListener).infoChanged(AppInfo(Some(driverLogUrl), Some(sparkUiUrl))) - } - } - } - - it("should not die on 
YARN-4411") { - Clock.withSleepMethod(mockSleep) { - val mockYarnClient = mock[YarnClient] - - // Block test until getApplicationReport is called 10 times. - val pollCountDown = new CountDownLatch(10) - when(mockYarnClient.getApplicationReport(appId)).thenAnswer(new Answer[ApplicationReport] { - override def answer(invocation: InvocationOnMock): ApplicationReport = { - pollCountDown.countDown() - throw new IllegalArgumentException("No enum constant " + - "org.apache.hadoop.yarn.api.records.YarnApplicationAttemptState.FINAL_SAVING") - } - }) - - val app = new SparkYarnApp(appTag, appIdOption, None, None, livyConf, mockYarnClient) - cleanupThread(app.yarnAppMonitorThread) { - pollCountDown.await(TEST_TIMEOUT.length, TEST_TIMEOUT.unit) - assert(app.state == SparkApp.State.STARTING) - - app.state = SparkApp.State.FINISHED - app.yarnAppMonitorThread.join(TEST_TIMEOUT.toMillis) - } - } - } - - it("should not die on ApplicationAttemptNotFoundException") { - Clock.withSleepMethod(mockSleep) { - val mockYarnClient = mock[YarnClient] - val mockAppReport = mock[ApplicationReport] - val mockApplicationAttemptId = mock[ApplicationAttemptId] - var done = false - - when(mockAppReport.getApplicationId).thenReturn(appId) - when(mockAppReport.getYarnApplicationState).thenAnswer( - new Answer[YarnApplicationState]() { - override def answer(invocation: InvocationOnMock): YarnApplicationState = { - if (done) { - FINISHED - } else { - RUNNING - } - } - }) - when(mockAppReport.getFinalApplicationStatus).thenReturn(FinalApplicationStatus.SUCCEEDED) - when(mockAppReport.getCurrentApplicationAttemptId).thenReturn(mockApplicationAttemptId) - when(mockYarnClient.getApplicationReport(appId)).thenReturn(mockAppReport) - - // Block test until getApplicationReport is called 10 times. 
- val pollCountDown = new CountDownLatch(10) - when(mockYarnClient.getApplicationAttemptReport(mockApplicationAttemptId)).thenAnswer( - new Answer[ApplicationReport] { - override def answer(invocation: InvocationOnMock): ApplicationReport = { - pollCountDown.countDown() - throw new ApplicationAttemptNotFoundException("unit test") - } - }) - - val app = new SparkYarnApp(appTag, appIdOption, None, None, livyConf, mockYarnClient) - cleanupThread(app.yarnAppMonitorThread) { - pollCountDown.await(TEST_TIMEOUT.length, TEST_TIMEOUT.unit) - assert(app.state == SparkApp.State.RUNNING) - done = true - - app.yarnAppMonitorThread.join(TEST_TIMEOUT.toMillis) - } - } - } - } -}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/ApiVersioningSupportSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/ApiVersioningSupportSpec.scala b/server/src/test/scala/org/apache/livy/server/ApiVersioningSupportSpec.scala new file mode 100644 index 0000000..0f50ced --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/ApiVersioningSupportSpec.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.server + +import javax.servlet.http.HttpServletResponse + +import org.scalatest.FunSpecLike +import org.scalatra.ScalatraServlet +import org.scalatra.test.scalatest.ScalatraSuite + +import org.apache.livy.LivyBaseUnitTestSuite + +class ApiVersioningSupportSpec extends ScalatraSuite with FunSpecLike with LivyBaseUnitTestSuite { + val LatestVersionOutput = "latest" + + object FakeApiVersions extends Enumeration { + type FakeApiVersions = Value + val v0_1 = Value("0.1") + val v0_2 = Value("0.2") + val v1_0 = Value("1.0") + } + + import FakeApiVersions._ + + class MockServlet extends ScalatraServlet with AbstractApiVersioningSupport { + override val apiVersions = FakeApiVersions + override type ApiVersionType = FakeApiVersions.Value + + get("/test") { + response.writer.write(LatestVersionOutput) + } + + get("/test", apiVersion <= v0_2) { + response.writer.write(v0_2.toString) + } + + get("/test", apiVersion <= v0_1) { + response.writer.write(v0_1.toString) + } + + get("/droppedApi", apiVersion <= v0_2) { + } + + get("/newApi", apiVersion >= v0_2) { + } + } + + var mockServlet: MockServlet = new MockServlet + addServlet(mockServlet, "/*") + + def generateHeader(acceptHeader: String): Map[String, String] = { + if (acceptHeader != null) Map("Accept" -> acceptHeader) else Map.empty + } + + def shouldReturn(url: String, acceptHeader: String, expectedVersion: String = null): Unit = { + get(url, headers = generateHeader(acceptHeader)) { + status should equal(200) + if (expectedVersion != null) { + body should equal(expectedVersion) + } + } + } + + def shouldFail(url: String, acceptHeader: String, expectedErrorCode: Int): Unit = { + get(url, headers = generateHeader(acceptHeader)) { + status should equal(expectedErrorCode) + } + } + + it("should pick the latest API version if Accept header is unspecified") { + shouldReturn("/test", null, LatestVersionOutput) + } + + it("should pick the latest API version if Accept header does not specify any 
version") { + shouldReturn("/test", "foo", LatestVersionOutput) + shouldReturn("/test", "application/vnd.random.v1.1", LatestVersionOutput) + shouldReturn("/test", "application/vnd.livy.+json", LatestVersionOutput) + } + + it("should pick the correct API version") { + shouldReturn("/test", "application/vnd.livy.v0.1", v0_1.toString) + shouldReturn("/test", "application/vnd.livy.v0.2+", v0_2.toString) + shouldReturn("/test", "application/vnd.livy.v0.1+bar", v0_1.toString) + shouldReturn("/test", "application/vnd.livy.v0.2+foo", v0_2.toString) + shouldReturn("/test", "application/vnd.livy.v0.1+vnd.livy.v0.2", v0_1.toString) + shouldReturn("/test", "application/vnd.livy.v0.2++++++++++++++++", v0_2.toString) + shouldReturn("/test", "application/vnd.livy.v1.0", LatestVersionOutput) + } + + it("should return error when the specified API version does not exist") { + shouldFail("/test", "application/vnd.livy.v", HttpServletResponse.SC_NOT_ACCEPTABLE) + shouldFail("/test", "application/vnd.livy.v+json", HttpServletResponse.SC_NOT_ACCEPTABLE) + shouldFail("/test", "application/vnd.livy.v666.666", HttpServletResponse.SC_NOT_ACCEPTABLE) + shouldFail("/test", "application/vnd.livy.v666.666+json", HttpServletResponse.SC_NOT_ACCEPTABLE) + shouldFail("/test", "application/vnd.livy.v1.1+json", HttpServletResponse.SC_NOT_ACCEPTABLE) + } + + it("should not see a dropped API") { + shouldReturn("/droppedApi", "application/vnd.livy.v0.1+json") + shouldReturn("/droppedApi", "application/vnd.livy.v0.2+json") + shouldFail("/droppedApi", "application/vnd.livy.v1.0+json", HttpServletResponse.SC_NOT_FOUND) + } + + it("should not see a new API at an older version") { + shouldFail("/newApi", "application/vnd.livy.v0.1+json", HttpServletResponse.SC_NOT_FOUND) + shouldReturn("/newApi", "application/vnd.livy.v0.2+json") + shouldReturn("/newApi", "application/vnd.livy.v1.0+json") + } +} 
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/BaseJsonServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/BaseJsonServletSpec.scala b/server/src/test/scala/org/apache/livy/server/BaseJsonServletSpec.scala new file mode 100644 index 0000000..959707a --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/BaseJsonServletSpec.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.server + +import java.io.ByteArrayOutputStream +import javax.servlet.http.HttpServletResponse._ + +import scala.reflect.ClassTag + +import com.fasterxml.jackson.databind.ObjectMapper +import org.scalatest.FunSpecLike +import org.scalatra.test.scalatest.ScalatraSuite + +import org.apache.livy.LivyBaseUnitTestSuite + +/** + * Base class that enhances ScalatraSuite so that it's easier to test JsonServlet + * implementations. Variants of the test methods (get, post, etc) exist with the "j" + * prefix; these automatically serialize the body of the request to JSON, and + * deserialize the result from JSON. 
+ * + * In case the response is not JSON, the expected type for the test function should be + * `Unit`, and the `response` object should be checked directly. + */ +abstract class BaseJsonServletSpec extends ScalatraSuite + with FunSpecLike with LivyBaseUnitTestSuite { + + protected val mapper = new ObjectMapper() + .registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule) + + protected val defaultHeaders: Map[String, String] = Map("Content-Type" -> "application/json") + + protected def jdelete[R: ClassTag]( + uri: String, + expectedStatus: Int = SC_OK, + headers: Map[String, String] = defaultHeaders) + (fn: R => Unit): Unit = { + delete(uri, headers = headers)(doTest(expectedStatus, fn)) + } + + protected def jget[R: ClassTag]( + uri: String, + expectedStatus: Int = SC_OK, + headers: Map[String, String] = defaultHeaders) + (fn: R => Unit): Unit = { + get(uri, headers = headers)(doTest(expectedStatus, fn)) + } + + protected def jpatch[R: ClassTag]( + uri: String, + body: AnyRef, + expectedStatus: Int = SC_OK, + headers: Map[String, String] = defaultHeaders) + (fn: R => Unit): Unit = { + patch(uri, body = toJson(body), headers = headers)(doTest(expectedStatus, fn)) + } + + protected def jpost[R: ClassTag]( + uri: String, + body: AnyRef, + expectedStatus: Int = SC_CREATED, + headers: Map[String, String] = defaultHeaders) + (fn: R => Unit): Unit = { + post(uri, body = toJson(body), headers = headers)(doTest(expectedStatus, fn)) + } + + /** A version of jpost specific for testing file upload. 
*/ + protected def jupload[R: ClassTag]( + uri: String, + files: Iterable[(String, Any)], + headers: Map[String, String] = Map(), + expectedStatus: Int = SC_OK) + (fn: R => Unit): Unit = { + post(uri, Map.empty, files)(doTest(expectedStatus, fn)) + } + + protected def jput[R: ClassTag]( + uri: String, + body: AnyRef, + expectedStatus: Int = SC_OK, + headers: Map[String, String] = defaultHeaders) + (fn: R => Unit): Unit = { + put(uri, body = toJson(body), headers = headers)(doTest(expectedStatus, fn)) + } + + private def doTest[R: ClassTag](expectedStatus: Int, fn: R => Unit) + (implicit klass: ClassTag[R]): Unit = { + if (status != expectedStatus) { + // Yeah this is weird, but we don't want to evaluate "response.body" if there's no error. + assert(status === expectedStatus, + s"Unexpected response status: $status != $expectedStatus (${response.body})") + } + // Only try to parse the body if response is in the "OK" range (20x). + if ((status / 100) * 100 == SC_OK) { + val result = + if (header("Content-Type").startsWith("application/json")) { + // Sometimes there's an empty body with no "Content-Length" header. So read the whole + // body first, and only send it to Jackson if there's content. 
+ val in = response.inputStream + val out = new ByteArrayOutputStream() + val buf = new Array[Byte](1024) + var read = 0 + while (read >= 0) { + read = in.read(buf) + if (read > 0) { + out.write(buf, 0, read) + } + } + + val data = out.toByteArray() + if (data.length > 0) { + mapper.readValue(data, klass.runtimeClass) + } else { + null + } + } else { + assert(klass.runtimeClass == classOf[Unit]) + () + } + fn(result.asInstanceOf[R]) + } + } + + private def toJson(obj: Any): Array[Byte] = mapper.writeValueAsBytes(obj) + +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/BaseSessionServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/BaseSessionServletSpec.scala b/server/src/test/scala/org/apache/livy/server/BaseSessionServletSpec.scala new file mode 100644 index 0000000..203f1f7 --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/BaseSessionServletSpec.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.server + +import javax.servlet.http.HttpServletRequest + +import org.scalatest.BeforeAndAfterAll + +import org.apache.livy.LivyConf +import org.apache.livy.sessions.Session +import org.apache.livy.sessions.Session.RecoveryMetadata + +object BaseSessionServletSpec { + + /** Header used to override the user remote user in tests. */ + val REMOTE_USER_HEADER = "X-Livy-SessionServlet-User" + +} + +abstract class BaseSessionServletSpec[S <: Session, R <: RecoveryMetadata] + extends BaseJsonServletSpec + with BeforeAndAfterAll { + + /** Config map containing option that is blacklisted. */ + protected val BLACKLISTED_CONFIG = Map("spark.do_not_set" -> "true") + + /** Name of the admin user. */ + protected val ADMIN = "__admin__" + + /** Create headers that identify a specific user in tests. */ + protected def makeUserHeaders(user: String): Map[String, String] = { + defaultHeaders ++ Map(BaseSessionServletSpec.REMOTE_USER_HEADER -> user) + } + + protected val adminHeaders = makeUserHeaders(ADMIN) + + /** Create a LivyConf with impersonation enabled and a superuser. 
*/ + protected def createConf(): LivyConf = { + new LivyConf() + .set(LivyConf.IMPERSONATION_ENABLED, true) + .set(LivyConf.SUPERUSERS, ADMIN) + .set(LivyConf.LOCAL_FS_WHITELIST, sys.props("java.io.tmpdir")) + } + + override def afterAll(): Unit = { + super.afterAll() + servlet.shutdown() + } + + def createServlet(): SessionServlet[S, R] + + protected val servlet = createServlet() + + addServlet(servlet, "/*") + + protected def toJson(msg: AnyRef): Array[Byte] = mapper.writeValueAsBytes(msg) + +} + +trait RemoteUserOverride { + this: SessionServlet[_, _] => + + override protected def remoteUser(req: HttpServletRequest): String = { + req.getHeader(BaseSessionServletSpec.REMOTE_USER_HEADER) + } + +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/JsonServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/JsonServletSpec.scala b/server/src/test/scala/org/apache/livy/server/JsonServletSpec.scala new file mode 100644 index 0000000..5ca3997 --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/JsonServletSpec.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.server + +import java.nio.charset.StandardCharsets.UTF_8 +import javax.servlet.http.HttpServletResponse._ + +import org.scalatra._ + +class JsonServletSpec extends BaseJsonServletSpec { + + addServlet(new TestJsonServlet(), "/*") + + describe("The JSON servlet") { + + it("should serialize result of delete") { + jdelete[MethodReturn]("/delete") { result => + assert(result.value === "delete") + } + } + + it("should serialize result of get") { + jget[MethodReturn]("/get") { result => + assert(result.value === "get") + } + } + + it("should serialize an ActionResult's body") { + jpost[MethodReturn]("/post", MethodArg("post")) { result => + assert(result.value === "post") + } + } + + it("should wrap a raw result") { + jput[MethodReturn]("/put", MethodArg("put")) { result => + assert(result.value === "put") + } + } + + it("should bypass non-json results") { + jpatch[Unit]("/patch", MethodArg("patch"), expectedStatus = SC_NOT_FOUND) { _ => + assert(response.body === "patch") + } + } + + it("should translate JSON errors to BadRequest") { + post("/post", "abcde".getBytes(UTF_8), headers = defaultHeaders) { + assert(status === SC_BAD_REQUEST) + } + } + + it("should translate bad param name to BadRequest") { + post("/post", """{"value1":"1"}""".getBytes(UTF_8), headers = defaultHeaders) { + assert(status === SC_BAD_REQUEST) + } + } + + it("should translate type mismatch to BadRequest") { + post("/postlist", """{"listParam":"1"}""".getBytes(UTF_8), headers = defaultHeaders) { + assert(status === SC_BAD_REQUEST) + } + } + + it("should respect user-installed error handlers") { + post("/error", headers = defaultHeaders) { + assert(status === SC_SERVICE_UNAVAILABLE) + assert(response.body === "error") + } + } + + it("should handle empty return values") { + jget[MethodReturn]("/empty") { result => + assert(result == null) + } + } + + } + +} + 
+private case class MethodArg(value: String) + +private case class MethodReturn(value: String) + +private case class MethodReturnList(listParam: List[String] = List()) + +private class TestJsonServlet extends JsonServlet { + + before() { + contentType = "application/json" + } + + delete("/delete") { + Ok(MethodReturn("delete")) + } + + get("/get") { + Ok(MethodReturn("get")) + } + + jpost[MethodArg]("/post") { arg => + Created(MethodReturn(arg.value)) + } + + jpost[MethodReturnList]("/postlist") { arg => + Created() + } + + jput[MethodArg]("/put") { arg => + MethodReturn(arg.value) + } + + jpatch[MethodArg]("/patch") { arg => + contentType = "text/plain" + NotFound(arg.value) + } + + get("/empty") { + () + } + + post("/error") { + throw new IllegalStateException("error") + } + + // Install an error handler to make sure the parent's still work. + error { + case e: IllegalStateException => + contentType = "text/plain" + ServiceUnavailable(e.getMessage()) + } + +} + http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala b/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala new file mode 100644 index 0000000..292a9cd --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/SessionServletSpec.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.server + +import javax.servlet.http.HttpServletRequest +import javax.servlet.http.HttpServletResponse._ + +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.livy.LivyConf +import org.apache.livy.server.recovery.SessionStore +import org.apache.livy.sessions.{Session, SessionManager, SessionState} +import org.apache.livy.sessions.Session.RecoveryMetadata + +object SessionServletSpec { + + val PROXY_USER = "proxyUser" + + class MockSession(id: Int, owner: String, livyConf: LivyConf) + extends Session(id, owner, livyConf) { + + case class MockRecoveryMetadata(id: Int) extends RecoveryMetadata() + + override val proxyUser = None + + override def recoveryMetadata: RecoveryMetadata = MockRecoveryMetadata(0) + + override def state: SessionState = SessionState.Idle() + + override protected def stopSession(): Unit = () + + override def logLines(): IndexedSeq[String] = IndexedSeq("log") + + } + + case class MockSessionView(id: Int, owner: String, logs: Seq[String]) + +} + +class SessionServletSpec + extends BaseSessionServletSpec[Session, RecoveryMetadata] { + + import SessionServletSpec._ + + override def createServlet(): SessionServlet[Session, RecoveryMetadata] = { + val conf = createConf() + val sessionManager = new SessionManager[Session, RecoveryMetadata]( + conf, + { _ => assert(false).asInstanceOf[Session] }, + mock[SessionStore], + "test", + Some(Seq.empty)) + + new SessionServlet(sessionManager, conf) with RemoteUserOverride { + override protected def createSession(req: HttpServletRequest): Session = { + 
val params = bodyAs[Map[String, String]](req) + checkImpersonation(params.get(PROXY_USER), req) + new MockSession(sessionManager.nextId(), remoteUser(req), conf) + } + + override protected def clientSessionView( + session: Session, + req: HttpServletRequest): Any = { + val logs = if (hasAccess(session.owner, req)) session.logLines() else Nil + MockSessionView(session.id, session.owner, logs) + } + } + } + + private val aliceHeaders = makeUserHeaders("alice") + private val bobHeaders = makeUserHeaders("bob") + + private def delete(id: Int, headers: Map[String, String], expectedStatus: Int): Unit = { + jdelete[Map[String, Any]](s"/$id", headers = headers, expectedStatus = expectedStatus) { _ => + // Nothing to do. + } + } + + describe("SessionServlet") { + + it("should return correct Location in header") { + // mount to "/sessions/*" to test. If request URI is "/session", getPathInfo() will + // return null, since there's no extra path. + // mount to "/*" will always return "/", so that it cannot reflect the issue. 
+ addServlet(servlet, "/sessions/*") + jpost[MockSessionView]("/sessions", Map(), headers = aliceHeaders) { res => + assert(header("Location") === "/sessions/0") + jdelete[Map[String, Any]]("/sessions/0", SC_OK, aliceHeaders) { _ => } + } + } + + it("should attach owner information to sessions") { + jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => + assert(res.owner === "alice") + assert(res.logs === IndexedSeq("log")) + delete(res.id, aliceHeaders, SC_OK) + } + } + + it("should allow other users to see non-sensitive information") { + jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => + jget[MockSessionView](s"/${res.id}", headers = bobHeaders) { res => + assert(res.owner === "alice") + assert(res.logs === Nil) + } + delete(res.id, aliceHeaders, SC_OK) + } + } + + it("should prevent non-owners from modifying sessions") { + jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => + delete(res.id, bobHeaders, SC_FORBIDDEN) + } + } + + it("should allow admins to access all sessions") { + jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => + jget[MockSessionView](s"/${res.id}", headers = adminHeaders) { res => + assert(res.owner === "alice") + assert(res.logs === IndexedSeq("log")) + } + delete(res.id, adminHeaders, SC_OK) + } + } + + it("should not allow regular users to impersonate others") { + jpost[MockSessionView]("/", Map(PROXY_USER -> "bob"), headers = aliceHeaders, + expectedStatus = SC_FORBIDDEN) { _ => } + } + + it("should allow admins to impersonate anyone") { + jpost[MockSessionView]("/", Map(PROXY_USER -> "bob"), headers = adminHeaders) { res => + delete(res.id, bobHeaders, SC_FORBIDDEN) + delete(res.id, adminHeaders, SC_OK) + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala ---------------------------------------------------------------------- diff --git 
a/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala b/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala new file mode 100644 index 0000000..7f5e33d --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/batch/BatchServletSpec.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.server.batch + +import java.io.FileWriter +import java.nio.file.{Files, Path} +import java.util.concurrent.TimeUnit +import javax.servlet.http.HttpServletRequest +import javax.servlet.http.HttpServletResponse._ + +import scala.concurrent.duration.Duration + +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.livy.Utils +import org.apache.livy.server.BaseSessionServletSpec +import org.apache.livy.server.recovery.SessionStore +import org.apache.livy.sessions.{BatchSessionManager, SessionState} +import org.apache.livy.utils.AppInfo + +class BatchServletSpec extends BaseSessionServletSpec[BatchSession, BatchRecoveryMetadata] { + + val script: Path = { + val script = Files.createTempFile("livy-test", ".py") + script.toFile.deleteOnExit() + val writer = new FileWriter(script.toFile) + try { + writer.write( + """ + |print "hello world" + """.stripMargin) + } finally { + writer.close() + } + script + } + + override def createServlet(): BatchSessionServlet = { + val livyConf = createConf() + val sessionStore = mock[SessionStore] + new BatchSessionServlet( + new BatchSessionManager(livyConf, sessionStore, Some(Seq.empty)), + sessionStore, + livyConf) + } + + describe("Batch Servlet") { + it("should create and tear down a batch") { + jget[Map[String, Any]]("/") { data => + data("sessions") should equal (Seq()) + } + + val createRequest = new CreateBatchRequest() + createRequest.file = script.toString + createRequest.conf = Map("spark.driver.extraClassPath" -> sys.props("java.class.path")) + + jpost[Map[String, Any]]("/", createRequest) { data => + header("Location") should equal("/0") + data("id") should equal (0) + + val batch = servlet.sessionManager.get(0) + batch should be (defined) + } + + // Wait for the process to finish. 
+ { + val batch = servlet.sessionManager.get(0).get + Utils.waitUntil({ () => !batch.state.isActive }, Duration(10, TimeUnit.SECONDS)) + (batch.state match { + case SessionState.Success(_) => true + case _ => false + }) should be (true) + } + + jget[Map[String, Any]]("/0") { data => + data("id") should equal (0) + data("state") should equal ("success") + + val batch = servlet.sessionManager.get(0) + batch should be (defined) + } + + jget[Map[String, Any]]("/0/log?size=1000") { data => + data("id") should equal (0) + data("log").asInstanceOf[Seq[String]] should contain ("hello world") + + val batch = servlet.sessionManager.get(0) + batch should be (defined) + } + + jdelete[Map[String, Any]]("/0") { data => + data should equal (Map("msg" -> "deleted")) + + val batch = servlet.sessionManager.get(0) + batch should not be defined + } + } + + it("should respect config black list") { + val createRequest = new CreateBatchRequest() + createRequest.file = script.toString + createRequest.conf = BLACKLISTED_CONFIG + jpost[Map[String, Any]]("/", createRequest, expectedStatus = SC_BAD_REQUEST) { _ => } + } + + it("should show session properties") { + val id = 0 + val state = SessionState.Running() + val appId = "appid" + val appInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL")) + val log = IndexedSeq[String]("log1", "log2") + + val session = mock[BatchSession] + when(session.id).thenReturn(id) + when(session.state).thenReturn(state) + when(session.appId).thenReturn(Some(appId)) + when(session.appInfo).thenReturn(appInfo) + when(session.logLines()).thenReturn(log) + + val req = mock[HttpServletRequest] + + val view = servlet.asInstanceOf[BatchSessionServlet].clientSessionView(session, req) + .asInstanceOf[BatchSessionView] + + view.id shouldEqual id + view.state shouldEqual state.toString + view.appId shouldEqual Some(appId) + view.appInfo shouldEqual appInfo + view.log shouldEqual log + } + } + +} 
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala b/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala new file mode 100644 index 0000000..eb80bef --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/batch/BatchSessionSpec.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.server.batch + +import java.io.FileWriter +import java.nio.file.{Files, Path} +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.Duration + +import org.mockito.Matchers +import org.mockito.Matchers.anyObject +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, FunSpec, ShouldMatchers} +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.livy.{LivyBaseUnitTestSuite, LivyConf, Utils} +import org.apache.livy.server.recovery.SessionStore +import org.apache.livy.sessions.SessionState +import org.apache.livy.utils.{AppInfo, SparkApp} + +class BatchSessionSpec + extends FunSpec + with BeforeAndAfter + with ShouldMatchers + with LivyBaseUnitTestSuite { + + val script: Path = { + val script = Files.createTempFile("livy-test", ".py") + script.toFile.deleteOnExit() + val writer = new FileWriter(script.toFile) + try { + writer.write( + """ + |print "hello world" + """.stripMargin) + } finally { + writer.close() + } + script + } + + describe("A Batch process") { + var sessionStore: SessionStore = null + + before { + sessionStore = mock[SessionStore] + } + + it("should create a process") { + val req = new CreateBatchRequest() + req.file = script.toString + req.conf = Map("spark.driver.extraClassPath" -> sys.props("java.class.path")) + + val conf = new LivyConf().set(LivyConf.LOCAL_FS_WHITELIST, sys.props("java.io.tmpdir")) + val batch = BatchSession.create(0, req, conf, null, None, sessionStore) + + Utils.waitUntil({ () => !batch.state.isActive }, Duration(10, TimeUnit.SECONDS)) + (batch.state match { + case SessionState.Success(_) => true + case _ => false + }) should be (true) + + batch.logLines() should contain("hello world") + } + + it("should update appId and appInfo") { + val conf = new LivyConf() + val req = new CreateBatchRequest() + val mockApp = mock[SparkApp] + val batch = BatchSession.create(0, req, conf, null, None, sessionStore, Some(mockApp)) + + val expectedAppId = "APPID" + 
batch.appIdKnown(expectedAppId) + verify(sessionStore, atLeastOnce()).save( + Matchers.eq(BatchSession.RECOVERY_SESSION_TYPE), anyObject()) + batch.appId shouldEqual Some(expectedAppId) + + val expectedAppInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL")) + batch.infoChanged(expectedAppInfo) + batch.appInfo shouldEqual expectedAppInfo + } + + it("should recover session") { + val conf = new LivyConf() + val req = new CreateBatchRequest() + val mockApp = mock[SparkApp] + val m = BatchRecoveryMetadata(99, None, "appTag", null, None) + val batch = BatchSession.recover(m, conf, sessionStore, Some(mockApp)) + + batch.state shouldBe a[SessionState.Recovering] + + batch.appIdKnown("appId") + verify(sessionStore, atLeastOnce()).save( + Matchers.eq(BatchSession.RECOVERY_SESSION_TYPE), anyObject()) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala b/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala new file mode 100644 index 0000000..7fef3c2 --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/batch/CreateBatchRequestSpec.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.server.batch + +import com.fasterxml.jackson.databind.{JsonMappingException, ObjectMapper} +import org.scalatest.FunSpec + +import org.apache.livy.LivyBaseUnitTestSuite + +class CreateBatchRequestSpec extends FunSpec with LivyBaseUnitTestSuite { + + private val mapper = new ObjectMapper() + .registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule) + + describe("CreateBatchRequest") { + + it("should have default values for fields after deserialization") { + val json = """{ "file" : "foo" }""" + val req = mapper.readValue(json, classOf[CreateBatchRequest]) + assert(req.file === "foo") + assert(req.proxyUser === None) + assert(req.args === List()) + assert(req.className === None) + assert(req.jars === List()) + assert(req.pyFiles === List()) + assert(req.files === List()) + assert(req.driverMemory === None) + assert(req.driverCores === None) + assert(req.executorMemory === None) + assert(req.executorCores === None) + assert(req.numExecutors === None) + assert(req.archives === List()) + assert(req.queue === None) + assert(req.name === None) + assert(req.conf === Map()) + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/interactive/BaseInteractiveServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/interactive/BaseInteractiveServletSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/BaseInteractiveServletSpec.scala new file 
mode 100644 index 0000000..b401a92 --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/interactive/BaseInteractiveServletSpec.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.server.interactive + +import java.io.File +import java.nio.file.Files + +import org.apache.commons.io.FileUtils +import org.apache.spark.launcher.SparkLauncher + +import org.apache.livy.LivyConf +import org.apache.livy.rsc.RSCConf +import org.apache.livy.server.BaseSessionServletSpec +import org.apache.livy.sessions.{Kind, SessionKindModule, Spark} + +abstract class BaseInteractiveServletSpec + extends BaseSessionServletSpec[InteractiveSession, InteractiveRecoveryMetadata] { + + mapper.registerModule(new SessionKindModule()) + + protected var tempDir: File = _ + + override def afterAll(): Unit = { + super.afterAll() + if (tempDir != null) { + scala.util.Try(FileUtils.deleteDirectory(tempDir)) + tempDir = null + } + } + + override protected def createConf(): LivyConf = synchronized { + if (tempDir == null) { + tempDir = Files.createTempDirectory("client-test").toFile() + } + super.createConf() + .set(LivyConf.SESSION_STAGING_DIR, tempDir.toURI().toString()) + .set(LivyConf.REPL_JARS, 
"dummy.jar") + .set(LivyConf.LIVY_SPARK_VERSION, "1.6.0") + .set(LivyConf.LIVY_SPARK_SCALA_VERSION, "2.10.5") + } + + protected def createRequest( + inProcess: Boolean = true, + extraConf: Map[String, String] = Map(), + kind: Kind = Spark()): CreateInteractiveRequest = { + val classpath = sys.props("java.class.path") + val request = new CreateInteractiveRequest() + request.kind = kind + request.conf = extraConf ++ Map( + RSCConf.Entry.LIVY_JARS.key() -> "", + RSCConf.Entry.CLIENT_IN_PROCESS.key() -> inProcess.toString, + SparkLauncher.SPARK_MASTER -> "local", + SparkLauncher.DRIVER_EXTRA_CLASSPATH -> classpath, + SparkLauncher.EXECUTOR_EXTRA_CLASSPATH -> classpath + ) + request + } + +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala new file mode 100644 index 0000000..a67c725 --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/interactive/CreateInteractiveRequestSpec.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.server.interactive + +import com.fasterxml.jackson.databind.ObjectMapper +import org.scalatest.FunSpec + +import org.apache.livy.LivyBaseUnitTestSuite +import org.apache.livy.sessions.{PySpark, SessionKindModule} + +class CreateInteractiveRequestSpec extends FunSpec with LivyBaseUnitTestSuite { + + private val mapper = new ObjectMapper() + .registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule) + .registerModule(new SessionKindModule()) + + describe("CreateInteractiveRequest") { + + it("should have default values for fields after deserialization") { + val json = """{ "kind" : "pyspark" }""" + val req = mapper.readValue(json, classOf[CreateInteractiveRequest]) + assert(req.kind === PySpark()) + assert(req.proxyUser === None) + assert(req.jars === List()) + assert(req.pyFiles === List()) + assert(req.files === List()) + assert(req.driverMemory === None) + assert(req.driverCores === None) + assert(req.executorMemory === None) + assert(req.executorCores === None) + assert(req.numExecutors === None) + assert(req.archives === List()) + assert(req.queue === None) + assert(req.name === None) + assert(req.conf === Map()) + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala new file mode 100644 index 0000000..372fe76 --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionServletSpec.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.server.interactive + +import java.util.concurrent.atomic.AtomicInteger +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import scala.collection.JavaConverters._ +import scala.concurrent.Future + +import org.json4s.jackson.Json4sScalaModule +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.Entry +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.livy.{ExecuteRequest, LivyConf} +import org.apache.livy.client.common.HttpMessages.SessionInfo +import org.apache.livy.rsc.driver.{Statement, StatementState} +import org.apache.livy.server.recovery.SessionStore +import org.apache.livy.sessions._ +import org.apache.livy.utils.AppInfo + +class InteractiveSessionServletSpec extends BaseInteractiveServletSpec { + + mapper.registerModule(new Json4sScalaModule()) + + class MockInteractiveSessionServlet( + sessionManager: InteractiveSessionManager, + conf: LivyConf) + extends InteractiveSessionServlet(sessionManager, mock[SessionStore], conf) { + + private var statements = IndexedSeq[Statement]() + + override protected def createSession(req: HttpServletRequest): InteractiveSession = { + val 
statementCounter = new AtomicInteger() + + val session = mock[InteractiveSession] + when(session.kind).thenReturn(Spark()) + when(session.appId).thenReturn(None) + when(session.appInfo).thenReturn(AppInfo()) + when(session.logLines()).thenReturn(IndexedSeq()) + when(session.state).thenReturn(SessionState.Idle()) + when(session.stop()).thenReturn(Future.successful(())) + when(session.proxyUser).thenReturn(None) + when(session.statements).thenAnswer( + new Answer[IndexedSeq[Statement]]() { + override def answer(args: InvocationOnMock): IndexedSeq[Statement] = statements + }) + when(session.executeStatement(any(classOf[ExecuteRequest]))).thenAnswer( + new Answer[Statement]() { + override def answer(args: InvocationOnMock): Statement = { + val id = statementCounter.getAndIncrement + val statement = new Statement(id, "1+1", StatementState.Available, "1") + + statements :+= statement + statement + } + }) + when(session.cancelStatement(anyInt())).thenAnswer( + new Answer[Unit] { + override def answer(args: InvocationOnMock): Unit = { + statements = IndexedSeq( + new Statement(statementCounter.get(), null, StatementState.Cancelled, null)) + } + } + ) + + session + } + + } + + override def createServlet(): InteractiveSessionServlet = { + val conf = createConf() + val sessionManager = new InteractiveSessionManager(conf, mock[SessionStore], Some(Seq.empty)) + new MockInteractiveSessionServlet(sessionManager, conf) + } + + it("should setup and tear down an interactive session") { + jget[Map[String, Any]]("/") { data => + data("sessions") should equal(Seq()) + } + + jpost[Map[String, Any]]("/", createRequest()) { data => + header("Location") should equal("/0") + data("id") should equal (0) + + val session = servlet.sessionManager.get(0) + session should be (defined) + } + + jget[Map[String, Any]]("/0") { data => + data("id") should equal (0) + data("state") should equal ("idle") + + val batch = servlet.sessionManager.get(0) + batch should be (defined) + } + + jpost[Map[String, 
Any]]("/0/statements", ExecuteRequest("foo")) { data => + data("id") should be (0) + data("code") shouldBe "1+1" + data("progress") should be (0.0) + data("output") shouldBe 1 + } + + jget[Map[String, Any]]("/0/statements") { data => + data("total_statements") should be (1) + data("statements").asInstanceOf[Seq[Map[String, Any]]](0)("id") should be (0) + } + + jpost[Map[String, Any]]("/0/statements/0/cancel", null, HttpServletResponse.SC_OK) { data => + data should equal(Map("msg" -> "canceled")) + } + + jget[Map[String, Any]]("/0/statements") { data => + data("total_statements") should be (1) + data("statements").asInstanceOf[Seq[Map[String, Any]]](0)("state") should be ("cancelled") + } + + jdelete[Map[String, Any]]("/0") { data => + data should equal (Map("msg" -> "deleted")) + + val session = servlet.sessionManager.get(0) + session should not be defined + } + } + + it("should show session properties") { + val id = 0 + val appId = "appid" + val owner = "owner" + val proxyUser = "proxyUser" + val state = SessionState.Running() + val kind = Spark() + val appInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL")) + val log = IndexedSeq[String]("log1", "log2") + + val session = mock[InteractiveSession] + when(session.id).thenReturn(id) + when(session.appId).thenReturn(Some(appId)) + when(session.owner).thenReturn(owner) + when(session.proxyUser).thenReturn(Some(proxyUser)) + when(session.state).thenReturn(state) + when(session.kind).thenReturn(kind) + when(session.appInfo).thenReturn(appInfo) + when(session.logLines()).thenReturn(log) + + val req = mock[HttpServletRequest] + + val view = servlet.asInstanceOf[InteractiveSessionServlet].clientSessionView(session, req) + .asInstanceOf[SessionInfo] + + view.id shouldEqual id + view.appId shouldEqual appId + view.owner shouldEqual owner + view.proxyUser shouldEqual proxyUser + view.state shouldEqual state.toString + view.kind shouldEqual kind.toString + view.appInfo should contain (Entry(AppInfo.DRIVER_LOG_URL_NAME, 
appInfo.driverLogUrl.get)) + view.appInfo should contain (Entry(AppInfo.SPARK_UI_URL_NAME, appInfo.sparkUiUrl.get)) + view.log shouldEqual log.asJava + } + +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala ---------------------------------------------------------------------- diff --git a/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala new file mode 100644 index 0000000..d2ae9ae --- /dev/null +++ b/server/src/test/scala/org/apache/livy/server/interactive/InteractiveSessionSpec.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.livy.server.interactive
+
+import java.net.URI
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.apache.spark.launcher.SparkLauncher
+import org.json4s.{DefaultFormats, Extraction, JValue}
+import org.json4s.jackson.JsonMethods.parse
+import org.mockito.{Matchers => MockitoMatchers}
+import org.mockito.Matchers._
+import org.mockito.Mockito.{atLeastOnce, verify, when}
+import org.scalatest.{BeforeAndAfterAll, FunSpec, Matchers}
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.mock.MockitoSugar.mock
+
+import org.apache.livy.{ExecuteRequest, JobHandle, LivyBaseUnitTestSuite, LivyConf}
+import org.apache.livy.rsc.{PingJob, RSCClient, RSCConf}
+import org.apache.livy.rsc.driver.StatementState
+import org.apache.livy.server.recovery.SessionStore
+import org.apache.livy.sessions.{PySpark, SessionState, Spark}
+import org.apache.livy.utils.{AppInfo, SparkApp}
+
+class InteractiveSessionSpec extends FunSpec
+    with Matchers with BeforeAndAfterAll with LivyBaseUnitTestSuite {
+
+  private val livyConf = new LivyConf()
+  livyConf.set(LivyConf.REPL_JARS, "dummy.jar")
+    .set(LivyConf.LIVY_SPARK_VERSION, "1.6.0")
+    .set(LivyConf.LIVY_SPARK_SCALA_VERSION, "2.10.5")
+
+  implicit val formats = DefaultFormats
+
+  private var session: InteractiveSession = null
+
+  private def createSession(
+      sessionStore: SessionStore = mock[SessionStore],
+      mockApp: Option[SparkApp] = None): InteractiveSession = {
+    assume(sys.env.get("SPARK_HOME").isDefined, "SPARK_HOME is not set.")
+
+    val req = new CreateInteractiveRequest()
+    req.kind = PySpark()
+    req.driverMemory = Some("512m")
+    req.driverCores = Some(1)
+    req.executorMemory = Some("512m")
+    req.executorCores = Some(1)
+    req.name = Some("InteractiveSessionSpec")
+    req.conf = Map(
+      SparkLauncher.DRIVER_EXTRA_CLASSPATH -> sys.props("java.class.path"),
+      RSCConf.Entry.LIVY_JARS.key() -> ""
+    )
+    InteractiveSession.create(0, null, None, livyConf, req, sessionStore, mockApp)
+  }
+
+  private def executeStatement(code: String): JValue = {
+    val id = session.executeStatement(ExecuteRequest(code)).id
+    eventually(timeout(30 seconds), interval(100 millis)) {
+      val s = session.getStatement(id).get
+      s.state.get() shouldBe StatementState.Available
+      parse(s.output)
+    }
+  }
+
+  override def afterAll(): Unit = {
+    if (session != null) {
+      Await.ready(session.stop(), 30 seconds)
+      session = null
+    }
+    super.afterAll()
+  }
+
+  private def withSession(desc: String)(fn: (InteractiveSession) => Unit): Unit = {
+    it(desc) {
+      assume(session != null, "No active session.")
+      eventually(timeout(30 seconds), interval(100 millis)) {
+        session.state shouldBe a[SessionState.Idle]
+      }
+      fn(session)
+    }
+  }
+
+  describe("A spark session") {
+
+    it("should get scala version matched jars with livy.repl.jars") {
+      val testedJars = Seq(
+        "test_2.10-0.1.jar",
+        "local://dummy-path/test/test1_2.10-1.0.jar",
+        "file:///dummy-path/test/test2_2.11-1.0-SNAPSHOT.jar",
+        "hdfs:///dummy-path/test/test3.jar",
+        "non-jar",
+        "dummy.jar"
+      )
+      val livyConf = new LivyConf(false)
+        .set(LivyConf.REPL_JARS, testedJars.mkString(","))
+        .set(LivyConf.LIVY_SPARK_VERSION, "1.6.2")
+        .set(LivyConf.LIVY_SPARK_SCALA_VERSION, "2.10")
+      val properties = InteractiveSession.prepareBuilderProp(Map.empty, Spark(), livyConf)
+      assert(properties(LivyConf.SPARK_JARS).split(",").toSet === Set("test_2.10-0.1.jar",
+        "local://dummy-path/test/test1_2.10-1.0.jar",
+        "hdfs:///dummy-path/test/test3.jar",
+        "dummy.jar"))
+
+      livyConf.set(LivyConf.LIVY_SPARK_SCALA_VERSION, "2.11")
+      val properties1 = InteractiveSession.prepareBuilderProp(Map.empty, Spark(), livyConf)
+      assert(properties1(LivyConf.SPARK_JARS).split(",").toSet === Set(
+        "file:///dummy-path/test/test2_2.11-1.0-SNAPSHOT.jar",
+        "hdfs:///dummy-path/test/test3.jar",
+        "dummy.jar"))
+    }
+
+    // livy.rsc.jars from LivyConf must reach the RSC conf, unless the RSC conf already sets it.
+    it("should set rsc jars through livy conf") {
+      val rscJars = Set(
+        "dummy.jar",
+        "local:///dummy-path/dummy1.jar",
+        "file:///dummy-path/dummy2.jar",
+        "hdfs:///dummy-path/dummy3.jar")
+      val livyConf = new LivyConf(false)
+        .set(LivyConf.REPL_JARS, "dummy.jar")
+        .set(LivyConf.RSC_JARS, rscJars.mkString(","))
+        .set(LivyConf.LIVY_SPARK_VERSION, "1.6.2")
+        .set(LivyConf.LIVY_SPARK_SCALA_VERSION, "2.10")
+      val properties = InteractiveSession.prepareBuilderProp(Map.empty, Spark(), livyConf)
+      // if livy.rsc.jars is configured in LivyConf, it should be passed to RSCConf.
+      assert(properties(RSCConf.Entry.LIVY_JARS.key()).split(",").toSet === rscJars)
+
+      val rscJars1 = Set(
+        "foo.jar",
+        "local:///dummy-path/foo1.jar",
+        "file:///dummy-path/foo2.jar",
+        "hdfs:///dummy-path/foo3.jar")
+      val properties1 = InteractiveSession.prepareBuilderProp(
+        Map(RSCConf.Entry.LIVY_JARS.key() -> rscJars1.mkString(",")), Spark(), livyConf)
+      // if rsc jars are configured both in LivyConf and RSCConf, RSCConf should take precedence.
+      assert(properties1(RSCConf.Entry.LIVY_JARS.key()).split(",").toSet === rscJars1)
+    }
+
+    it("should start in the idle state") {
+      session = createSession()
+      session.state should (be(a[SessionState.Starting]) or be(a[SessionState.Idle]))
+    }
+
+    it("should update appId and appInfo and session store") {
+      val mockApp = mock[SparkApp]
+      val sessionStore = mock[SessionStore]
+      val session = createSession(sessionStore, Some(mockApp))
+
+      val expectedAppId = "APPID"
+      session.appIdKnown(expectedAppId)
+      session.appId shouldEqual Some(expectedAppId)
+
+      val expectedAppInfo = AppInfo(Some("DRIVER LOG URL"), Some("SPARK UI URL"))
+      session.infoChanged(expectedAppInfo)
+      session.appInfo shouldEqual expectedAppInfo
+
+      verify(sessionStore, atLeastOnce()).save(
+        MockitoMatchers.eq(InteractiveSession.RECOVERY_SESSION_TYPE), anyObject())
+    }
+
+    withSession("should execute `1 + 2` == 3") { session =>
+      val result = executeStatement("1 + 2")
+      val expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 0,
+        "data" -> Map(
+          "text/plain" -> "3"
+        )
+      ))
+
+      result should equal (expectedResult)
+    }
+
+    withSession("should report an error if accessing an unknown variable") { session =>
+      val result = executeStatement("x")
+      val expectedResult = Extraction.decompose(Map(
+        "status" -> "error",
+        "execution_count" -> 1,
+        "ename" -> "NameError",
+        "evalue" -> "name 'x' is not defined",
+        "traceback" -> List(
+          "Traceback (most recent call last):\n",
+          "NameError: name 'x' is not defined\n"
+        )
+      ))
+
+      result should equal (expectedResult)
+      eventually(timeout(10 seconds), interval(30 millis)) {
+        session.state shouldBe a[SessionState.Idle]
+      }
+    }
+
+    withSession("should get statement progress along with statement result") { session =>
+      val code =
+        """
+          |from time import sleep
+          |sleep(3)
+        """.stripMargin
+      val statement = session.executeStatement(ExecuteRequest(code))
+      statement.progress should be (0.0)
+
+      eventually(timeout(10 seconds), interval(100 millis)) {
+        val s = session.getStatement(statement.id).get
+        s.state.get() shouldBe StatementState.Available
+        s.progress should be (1.0)
+      }
+    }
+
+    withSession("should error out the session if the interpreter dies") { session =>
+      session.executeStatement(ExecuteRequest("import os; os._exit(666)"))
+      eventually(timeout(30 seconds), interval(100 millis)) {
+        session.state shouldBe a[SessionState.Error]
+      }
+    }
+  }
+
+  describe("recovery") {
+    it("should recover session") {
+      val conf = new LivyConf()
+      val sessionStore = mock[SessionStore]
+      val mockClient = mock[RSCClient]
+      when(mockClient.submit(any(classOf[PingJob]))).thenReturn(mock[JobHandle[Void]])
+      val m =
+        InteractiveRecoveryMetadata(
+          78, None, "appTag", Spark(), 0, null, None, Some(URI.create("")))
+      val s = InteractiveSession.recover(m, conf, sessionStore, None, Some(mockClient))
+
+      s.state shouldBe a[SessionState.Recovering]
+
+      s.appIdKnown("appId")
+      verify(sessionStore, atLeastOnce()).save(
+        MockitoMatchers.eq(InteractiveSession.RECOVERY_SESSION_TYPE), anyObject())
+    }
+
+    it("should recover session to dead state if rscDriverUri is unknown") {
+      val conf = new LivyConf()
+      val sessionStore = mock[SessionStore]
+      val m = InteractiveRecoveryMetadata(
+        78, Some("appId"), "appTag", Spark(), 0, null, None, None)
+      val s = InteractiveSession.recover(m, conf, sessionStore, None)
+
+      s.state shouldBe a[SessionState.Dead]
+      s.logLines().mkString should include("RSCDriver URI is unknown")
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala
----------------------------------------------------------------------
diff --git a/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala
new file mode 100644
index 0000000..697a953
--- /dev/null
+++ b/server/src/test/scala/org/apache/livy/server/interactive/JobApiSpec.scala
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.server.interactive
+
+import java.io.File
+import java.net.URI
+import java.nio.ByteBuffer
+import java.nio.file.{Files, Paths}
+import javax.servlet.http.HttpServletResponse._
+
+import scala.concurrent.duration._
+import scala.io.Source
+import scala.language.postfixOps
+
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.mock.MockitoSugar.mock
+
+import org.apache.livy.{Job, JobHandle}
+import org.apache.livy.client.common.{BufferUtils, Serializer}
+import org.apache.livy.client.common.HttpMessages._
+import org.apache.livy.server.RemoteUserOverride
+import org.apache.livy.server.recovery.SessionStore
+import org.apache.livy.sessions.{InteractiveSessionManager, SessionState}
+import org.apache.livy.test.jobs.{Echo, GetCurrentUser}
+
+class JobApiSpec extends BaseInteractiveServletSpec {
+
+  private val PROXY = "__proxy__"
+
+  private var sessionId: Int = -1
+
+  override def createServlet(): InteractiveSessionServlet = {
+    val conf = createConf()
+    val sessionStore = mock[SessionStore]
+    val sessionManager = new InteractiveSessionManager(conf, sessionStore, Some(Seq.empty))
+    new InteractiveSessionServlet(sessionManager, sessionStore, conf) with RemoteUserOverride
+  }
+
+  def withSessionId(desc: String)(fn: (Int) => Unit): Unit = {
+    it(desc) {
+      assume(sessionId != -1, "No active session.")
+      fn(sessionId)
+    }
+  }
+
+  describe("Interactive Servlet") {
+
+    it("should create sessions") {
+      jpost[SessionInfo]("/", createRequest()) { data =>
+        waitForIdle(data.id)
+        header("Location") should equal("/0")
+        data.id should equal (0)
+        sessionId = data.id
+      }
+    }
+
+    withSessionId("should handle asynchronous jobs") { testJobSubmission(_, false) }
+
+    withSessionId("should handle synchronous jobs") { testJobSubmission(_, true) }
+
+    // Test that the file gets copied over to the Livy staging dir (tempDir here) - does not test end
+    // to end that the RSCClient class copies it over to the app.
+    withSessionId("should support file uploads") { id =>
+      testResourceUpload("file", id)
+    }
+
+    withSessionId("should support jar uploads") { id =>
+      testResourceUpload("jar", id)
+    }
+
+    withSessionId("should monitor async Spark jobs") { sid =>
+      val ser = new Serializer()
+      val job = BufferUtils.toByteArray(ser.serialize(new Echo("hello")))
+      var jobId: Long = -1L
+      jpost[JobStatus](s"/$sid/submit-job", new SerializedJob(job)) { status =>
+        jobId = status.id
+      }
+
+      eventually(timeout(1 minute), interval(100 millis)) {
+        jget[JobStatus](s"/$sid/jobs/$jobId") { status =>
+          status.state should be (JobHandle.State.SUCCEEDED)
+        }
+      }
+    }
+
+    withSessionId("should update last activity on connect") { sid =>
+      val currentActivity = servlet.sessionManager.get(sid).get.lastActivity
+      jpost[SessionInfo](s"/$sid/connect", null, expectedStatus = SC_OK) { _ =>
+        val newActivity = servlet.sessionManager.get(sid).get.lastActivity
+        assert(newActivity > currentActivity)
+      }
+    }
+
+    withSessionId("should tear down sessions") { id =>
+      jdelete[Map[String, Any]](s"/$id") { data =>
+        data should equal (Map("msg" -> "deleted"))
+      }
+      jget[Map[String, Any]]("/") { data =>
+        data("sessions") match {
+          case contents: Seq[_] => contents.size should equal (0)
+          case _ => fail("Response is not an array.")
+        }
+      }
+
+      // Make sure the session's staging directory was cleaned up.
+      assert(tempDir.listFiles().length === 0)
+    }
+
+    it("should support user impersonation") {
+      val headers = makeUserHeaders(PROXY)
+      jpost[SessionInfo]("/", createRequest(inProcess = false), headers = headers) { data =>
+        try {
+          waitForIdle(data.id)
+          data.owner should be (PROXY)
+          data.proxyUser should be (PROXY)
+          val user = runJob(data.id, new GetCurrentUser(), headers = headers)
+          user should be (PROXY)
+        } finally {
+          deleteSession(data.id)
+        }
+      }
+    }
+
+    it("should honor impersonation requests") {
+      val request = createRequest(inProcess = false)
+      request.proxyUser = Some(PROXY)
+      jpost[SessionInfo]("/", request, headers = adminHeaders) { data =>
+        try {
+          waitForIdle(data.id)
+          data.owner should be (ADMIN)
+          data.proxyUser should be (PROXY)
+          val user = runJob(data.id, new GetCurrentUser(), headers = adminHeaders)
+          user should be (PROXY)
+
+          // Test that files are uploaded to a new session directory.
+          assert(tempDir.listFiles().length === 0)
+          testResourceUpload("file", data.id)
+        } finally {
+          deleteSession(data.id)
+          assert(tempDir.listFiles().length === 0)
+        }
+      }
+    }
+
+    it("should respect config black list") {
+      jpost[SessionInfo]("/", createRequest(extraConf = BLACKLISTED_CONFIG),
+        expectedStatus = SC_BAD_REQUEST) { _ => }
+    }
+  }
+
+  private def waitForIdle(id: Int): Unit = {
+    eventually(timeout(1 minute), interval(100 millis)) {
+      jget[SessionInfo](s"/$id") { status =>
+        status.state should be (SessionState.Idle().toString())
+      }
+    }
+  }
+
+  private def deleteSession(id: Int): Unit = {
+    jdelete[Map[String, Any]](s"/$id", headers = adminHeaders) { _ => }
+  }
+
+  /** Uploads a temp file via `upload-$cmd` and checks its copy in the session staging dir. */
+  private def testResourceUpload(cmd: String, sessionId: Int): Unit = {
+    val f = File.createTempFile("uploadTestFile", cmd)
+    f.deleteOnExit()
+    Files.write(Paths.get(f.getAbsolutePath), "Test data".getBytes())
+
+    jupload[Unit](s"/$sessionId/upload-$cmd", Map(cmd -> f), expectedStatus = SC_OK) { _ =>
+      // There should be a single directory under the staging dir.
+      val subdirs = tempDir.listFiles()
+      assert(subdirs.length === 1)
+      val stagingDir = subdirs(0).toURI().toString()
+
+      val resultFile = new File(new URI(s"$stagingDir/${f.getName}"))
+      resultFile.deleteOnExit()
+      resultFile.exists() should be(true)
+      val source = Source.fromFile(resultFile)
+      try source.mkString should be("Test data") finally source.close()
+    }
+  }
+
+  private def testJobSubmission(sid: Int, sync: Boolean): Unit = {
+    val result = runJob(sid, new Echo(42), sync = sync)
+    result should be (42)
+  }
+
+  private def runJob[T](
+      sid: Int,
+      job: Job[T],
+      sync: Boolean = false,
+      headers: Map[String, String] = defaultHeaders): T = {
+    val ser = new Serializer()
+    val jobData = BufferUtils.toByteArray(ser.serialize(job))
+    // "run-job" is the synchronous route; "submit-job" is the asynchronous one (see above).
+    val route = if (sync) s"/$sid/run-job" else s"/$sid/submit-job"
+    var jobId: Long = -1L
+    jpost[JobStatus](route, new SerializedJob(jobData), headers = headers) { data =>
+      jobId = data.id
+    }
+
+    var result: Option[T] = None
+    eventually(timeout(1 minute), interval(100 millis)) {
+      jget[JobStatus](s"/$sid/jobs/$jobId") { status =>
+        status.id should be (jobId)
+        status.state should be (JobHandle.State.SUCCEEDED)
+        result = Some(ser.deserialize(ByteBuffer.wrap(status.result)).asInstanceOf[T])
+      }
+    }
+    result.getOrElse(throw new IllegalStateException(s"No result recorded for job $jobId."))
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/server/src/test/scala/org/apache/livy/server/interactive/SessionHeartbeatSpec.scala
----------------------------------------------------------------------
diff --git a/server/src/test/scala/org/apache/livy/server/interactive/SessionHeartbeatSpec.scala b/server/src/test/scala/org/apache/livy/server/interactive/SessionHeartbeatSpec.scala
new file mode 100644
index 0000000..12c8bbb
--- /dev/null
+++ b/server/src/test/scala/org/apache/livy/server/interactive/SessionHeartbeatSpec.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.server.interactive
+
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.mockito.Mockito.{never, verify, when}
+import org.scalatest.{FunSpec, Matchers}
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.mock.MockitoSugar.mock
+
+import org.apache.livy.LivyConf
+import org.apache.livy.server.recovery.SessionStore
+import org.apache.livy.sessions.{Session, SessionManager}
+import org.apache.livy.sessions.Session.RecoveryMetadata
+
+class SessionHeartbeatSpec extends FunSpec with Matchers {
+  describe("SessionHeartbeat") {
+    class TestHeartbeat(override val heartbeatTimeout: FiniteDuration) extends SessionHeartbeat {}
+
+    it("should not expire if heartbeat was never called.") {
+      val t = new TestHeartbeat(Duration.Zero)
+      t.heartbeatExpired shouldBe false
+    }
+
+    it("should expire if time has elapsed.") {
+      val t = new TestHeartbeat(Duration.fromNanos(1))
+      t.heartbeat()
+      // Poll with a realistic window; a 2ns timeout allowed effectively a single attempt.
+      eventually(timeout(10 seconds), interval(10 millis)) {
+        t.heartbeatExpired shouldBe true
+      }
+    }
+
+    it("should not expire if time hasn't elapsed.") {
+      val t = new TestHeartbeat(Duration.create(1, DAYS))
+      t.heartbeat()
+      t.heartbeatExpired shouldBe false
+    }
+  }
+
+  describe("SessionHeartbeatWatchdog") {
+    abstract class TestSession extends Session(0, null, null) with SessionHeartbeat {}
+    class TestWatchdog(conf: LivyConf)
+      extends SessionManager[TestSession, RecoveryMetadata](
+        conf,
+        // Session recovery is not exercised by this spec; fail loudly if it is ever invoked.
+        { _ => throw new UnsupportedOperationException("Unexpected session recovery.") },
+        mock[SessionStore], "test", Some(Seq.empty))
+      with SessionHeartbeatWatchdog[TestSession, RecoveryMetadata] {}
+
+    it("should delete only expired sessions") {
+      val expiredSession: TestSession = mock[TestSession]
+      when(expiredSession.id).thenReturn(0)
+      when(expiredSession.heartbeatExpired).thenReturn(true)
+
+      val nonExpiredSession: TestSession = mock[TestSession]
+      when(nonExpiredSession.id).thenReturn(1)
+      when(nonExpiredSession.heartbeatExpired).thenReturn(false)
+
+      val n = new TestWatchdog(new LivyConf())
+
+      n.register(expiredSession)
+      n.register(nonExpiredSession)
+      n.deleteExpiredSessions()
+
+      verify(expiredSession).stop()
+      verify(nonExpiredSession, never()).stop()
+    }
+  }
+}