Repository: spark
Updated Branches:
  refs/heads/master 4943ea598 -> 576c43fb4
http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/repl/src/main/scala/org/apache/spark/repl/Main.scala
----------------------------------------------------------------------
diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala
new file mode 100644
index 0000000..cc76a70
--- /dev/null
+++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import java.io.File
+import java.net.URI
+import java.util.Locale
+
+import scala.tools.nsc.GenericRunnerSettings
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.util.Utils
+
+object Main extends Logging {
+
+  initializeLogIfNecessary(true)
+  Signaling.cancelOnInterrupt()
+
+  val conf = new SparkConf()
+  val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf))
+  val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl")
+
+  var sparkContext: SparkContext = _
+  var sparkSession: SparkSession = _
+  // this is a public var because tests reset it.
+  var interp: SparkILoop = _
+
+  private var hasErrors = false
+
+  private def scalaOptionError(msg: String): Unit = {
+    hasErrors = true
+    // scalastyle:off println
+    Console.err.println(msg)
+    // scalastyle:on println
+  }
+
+  def main(args: Array[String]) {
+    doMain(args, new SparkILoop)
+  }
+
+  // Visible for testing
+  private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = {
+    interp = _interp
+    val jars = Utils.getLocalUserJarsForShell(conf)
+      // Remove file:///, file:// or file:/ scheme if exists for each jar
+      .map { x => if (x.startsWith("file:")) new File(new URI(x)).getPath else x }
+      .mkString(File.pathSeparator)
+    val interpArguments = List(
+      "-Yrepl-class-based",
+      "-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
+      "-classpath", jars
+    ) ++ args.toList
+
+    val settings = new GenericRunnerSettings(scalaOptionError)
+    settings.processArguments(interpArguments, true)
+
+    if (!hasErrors) {
+      interp.process(settings) // Repl starts and goes in loop of R.E.P.L
+      Option(sparkContext).foreach(_.stop)
+    }
+  }
+
+  def createSparkSession(): SparkSession = {
+    val execUri = System.getenv("SPARK_EXECUTOR_URI")
+    conf.setIfMissing("spark.app.name", "Spark shell")
+    // SparkContext will detect this configuration and register it with the RpcEnv's
+    // file server, setting spark.repl.class.uri to the actual URI for executors to
+    // use. This is sort of ugly but since executors are started as part of SparkContext
+    // initialization in certain cases, there's an initialization order issue that prevents
+    // this from being set after SparkContext is instantiated.
+    conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
+    if (execUri != null) {
+      conf.set("spark.executor.uri", execUri)
+    }
+    if (System.getenv("SPARK_HOME") != null) {
+      conf.setSparkHome(System.getenv("SPARK_HOME"))
+    }
+
+    val builder = SparkSession.builder.config(conf)
+    if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") {
+      if (SparkSession.hiveClassesArePresent) {
+        // In the case that the property is not set at all, builder's config
+        // does not have this value set to 'hive' yet. The original default
+        // behavior is that when there are hive classes, we use hive catalog.
+        sparkSession = builder.enableHiveSupport().getOrCreate()
+        logInfo("Created Spark session with Hive support")
+      } else {
+        // Need to change it back to 'in-memory' if no hive classes are found
+        // in the case that the property is set to hive in spark-defaults.conf
+        builder.config(CATALOG_IMPLEMENTATION.key, "in-memory")
+        sparkSession = builder.getOrCreate()
+        logInfo("Created Spark session")
+      }
+    } else {
+      // In the case that the property is set but not to 'hive', the internal
+      // default is 'in-memory'. So the sparkSession will use in-memory catalog.
+      sparkSession = builder.getOrCreate()
+      logInfo("Created Spark session")
+    }
+    sparkContext = sparkSession.sparkContext
+    sparkSession
+  }
+
+}
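Note on the catalog fallback in createSparkSession() above: the session only enables Hive support when spark.sql.catalogImplementation resolves to "hive" and the Hive classes are actually on the classpath; otherwise it quietly falls back to the in-memory catalog. A quick way to see which branch was taken from a running shell (a minimal sketch, not part of this commit; the key name comes from the StaticSQLConf.CATALOG_IMPLEMENTATION import above):

    // In spark-shell: inspect the effective catalog implementation.
    // Expected values are "hive" or "in-memory".
    val impl = spark.conf.get("spark.sql.catalogImplementation")
    println(s"catalog implementation: $impl")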
http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
----------------------------------------------------------------------
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
new file mode 100644
index 0000000..c7ae194
--- /dev/null
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import java.io._
+import java.net.URLClassLoader
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.log4j.{Level, LogManager}
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+
+class ReplSuite extends SparkFunSuite {
+
+  def runInterpreter(master: String, input: String): String = {
+    val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
+
+    val in = new BufferedReader(new StringReader(input + "\n"))
+    val out = new StringWriter()
+    val cl = getClass.getClassLoader
+    var paths = new ArrayBuffer[String]
+    if (cl.isInstanceOf[URLClassLoader]) {
+      val urlLoader = cl.asInstanceOf[URLClassLoader]
+      for (url <- urlLoader.getURLs) {
+        if (url.getProtocol == "file") {
+          paths += url.getFile
+        }
+      }
+    }
+    val classpath = paths.map(new File(_).getAbsolutePath).mkString(File.pathSeparator)
+
+    val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH)
+    System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath)
+    Main.sparkContext = null
+    Main.sparkSession = null // causes recreation of SparkContext for each test.
+    Main.conf.set("spark.master", master)
+    Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out)))
+
+    if (oldExecutorClasspath != null) {
+      System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath)
+    } else {
+      System.clearProperty(CONF_EXECUTOR_CLASSPATH)
+    }
+    return out.toString
+  }
+
+  // Simulate the paste mode in Scala REPL.
+  def runInterpreterInPasteMode(master: String, input: String): String =
+    runInterpreter(master, ":paste\n" + input + 4.toChar) // 4 is the ascii code of CTRL + D
+
+  def assertContains(message: String, output: String) {
+    val isContain = output.contains(message)
+    assert(isContain,
+      "Interpreter output did not contain '" + message + "':\n" + output)
+  }
+
+  def assertDoesNotContain(message: String, output: String) {
+    val isContain = output.contains(message)
+    assert(!isContain,
+      "Interpreter output contained '" + message + "':\n" + output)
+  }
+
+  test("propagation of local properties") {
+    // A mock ILoop that doesn't install the SIGINT handler.
+    class ILoop(out: PrintWriter) extends SparkILoop(None, out) {
+      settings = new scala.tools.nsc.Settings
+      settings.usejavacp.value = true
+      org.apache.spark.repl.Main.interp = this
+    }
+
+    val out = new StringWriter()
+    Main.interp = new ILoop(new PrintWriter(out))
+    Main.sparkContext = new SparkContext("local", "repl-test")
+    Main.interp.createInterpreter()
+
+    Main.sparkContext.setLocalProperty("someKey", "someValue")
+
+    // Make sure the value we set in the caller to interpret is propagated in the thread that
+    // interprets the command.
+    Main.interp.interpret("org.apache.spark.repl.Main.sparkContext.getLocalProperty(\"someKey\")")
+    assert(out.toString.contains("someValue"))
+
+    Main.sparkContext.stop()
+    System.clearProperty("spark.driver.port")
+  }
+
+  test("SPARK-15236: use Hive catalog") {
+    // turn on the INFO log so that it is possible the code will dump INFO
+    // entry for using "HiveMetastore"
+    val rootLogger = LogManager.getRootLogger()
+    val logLevel = rootLogger.getLevel
+    rootLogger.setLevel(Level.INFO)
+    try {
+      Main.conf.set(CATALOG_IMPLEMENTATION.key, "hive")
+      val output = runInterpreter("local",
+        """
+          |spark.sql("drop table if exists t_15236")
+        """.stripMargin)
+      assertDoesNotContain("error:", output)
+      assertDoesNotContain("Exception", output)
+      // only when the config is set to hive and
+      // hive classes are built, we will use hive catalog.
+      // Then log INFO entry will show things using HiveMetastore
+      if (SparkSession.hiveClassesArePresent) {
+        assertContains("HiveMetaStore", output)
+      } else {
+        // If hive classes are not built, in-memory catalog will be used
+        assertDoesNotContain("HiveMetaStore", output)
+      }
+    } finally {
+      rootLogger.setLevel(logLevel)
+    }
+  }
+
+  test("SPARK-15236: use in-memory catalog") {
+    val rootLogger = LogManager.getRootLogger()
+    val logLevel = rootLogger.getLevel
+    rootLogger.setLevel(Level.INFO)
+    try {
+      Main.conf.set(CATALOG_IMPLEMENTATION.key, "in-memory")
+      val output = runInterpreter("local",
+        """
+          |spark.sql("drop table if exists t_16236")
+        """.stripMargin)
+      assertDoesNotContain("error:", output)
+      assertDoesNotContain("Exception", output)
+      assertDoesNotContain("HiveMetaStore", output)
+    } finally {
+      rootLogger.setLevel(logLevel)
+    }
+  }
+
+  test("broadcast vars") {
+    // Test that the value that a broadcast var had when it was created is used,
+    // even if that variable is then modified in the driver program
+    // TODO: This doesn't actually work for arrays when we run in local mode!
+    val output = runInterpreter("local",
+      """
+        |var array = new Array[Int](5)
+        |val broadcastArray = sc.broadcast(array)
+        |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+        |array(0) = 5
+        |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+    assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
+  }
+
+  if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) {
+    test("running on Mesos") {
+      val output = runInterpreter("localquiet",
+        """
+          |var v = 7
+          |def getV() = v
+          |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
+          |v = 10
+          |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
+          |var array = new Array[Int](5)
+          |val broadcastArray = sc.broadcast(array)
+          |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+          |array(0) = 5
+          |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+        """.stripMargin)
+      assertDoesNotContain("error:", output)
+      assertDoesNotContain("Exception", output)
+      assertContains("res0: Int = 70", output)
+      assertContains("res1: Int = 100", output)
+      assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+      assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+    }
+  }
+
+  test("line wrapper only initialized once when used as encoder outer scope") {
+    val output = runInterpreter("local",
+      """
+        |val fileName = "repl-test-" + System.currentTimeMillis
+        |val tmpDir = System.getProperty("java.io.tmpdir")
+        |val file = new java.io.File(tmpDir, fileName)
+        |def createFile(): Unit = file.createNewFile()
+        |
+        |createFile();case class TestCaseClass(value: Int)
+        |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
+        |
+        |file.delete()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("define case class and create Dataset together with paste mode") {
+    val output = runInterpreterInPasteMode("local-cluster[1,1,1024]",
+      """
+        |import spark.implicits._
+        |case class TestClass(value: Int)
+        |Seq(TestClass(1)).toDS()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+}
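Note on runInterpreterInPasteMode above: character 4 is the ASCII EOT control code, which a terminal sends for Ctrl-D, so appending it terminates the :paste block just as a user would interactively. A minimal sketch of the same idea, assuming a helper name of our own choosing (not from the commit):

    // Build a :paste-mode script terminated by ASCII EOT (0x04, i.e. Ctrl-D).
    val EOT: Char = 4.toChar
    def pasteScript(code: String): String = ":paste\n" + code + EOT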
http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala
----------------------------------------------------------------------
diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala
new file mode 100644
index 0000000..ec3d790
--- /dev/null
+++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import java.io._
+import java.net.URLClassLoader
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.commons.lang3.StringEscapeUtils
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.util.Utils
+
+/**
+ * A special test suite for REPL that all test cases share one REPL instance.
+ */
+class SingletonReplSuite extends SparkFunSuite {
+
+  private val out = new StringWriter()
+  private val in = new PipedOutputStream()
+  private var thread: Thread = _
+
+  private val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
+  private val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH)
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val cl = getClass.getClassLoader
+    var paths = new ArrayBuffer[String]
+    if (cl.isInstanceOf[URLClassLoader]) {
+      val urlLoader = cl.asInstanceOf[URLClassLoader]
+      for (url <- urlLoader.getURLs) {
+        if (url.getProtocol == "file") {
+          paths += url.getFile
+        }
+      }
+    }
+    val classpath = paths.map(new File(_).getAbsolutePath).mkString(File.pathSeparator)
+
+    System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath)
+    Main.conf.set("spark.master", "local-cluster[2,1,1024]")
+    val interp = new SparkILoop(
+      new BufferedReader(new InputStreamReader(new PipedInputStream(in))),
+      new PrintWriter(out))
+
+    // Forces to create new SparkContext
+    Main.sparkContext = null
+    Main.sparkSession = null
+
+    // Starts a new thread to run the REPL interpreter, so that we won't block.
+    thread = new Thread(new Runnable {
+      override def run(): Unit = Main.doMain(Array("-classpath", classpath), interp)
+    })
+    thread.setDaemon(true)
+    thread.start()
+
+    waitUntil(() => out.toString.contains("Type :help for more information"))
+  }
+
+  override def afterAll(): Unit = {
+    in.close()
+    thread.join()
+    if (oldExecutorClasspath != null) {
+      System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath)
+    } else {
+      System.clearProperty(CONF_EXECUTOR_CLASSPATH)
+    }
+    super.afterAll()
+  }
+
+  private def waitUntil(cond: () => Boolean): Unit = {
+    import scala.concurrent.duration._
+    import org.scalatest.concurrent.Eventually._
+
+    eventually(timeout(50.seconds), interval(500.millis)) {
+      assert(cond(), "current output: " + out.toString)
+    }
+  }
+
+  /**
+   * Run the given commands string in a globally shared interpreter instance. Note that the given
+   * commands should not crash the interpreter, to not affect other test cases.
+   */
+  def runInterpreter(input: String): String = {
+    val currentOffset = out.getBuffer.length()
+    // append a special statement to the end of the given code, so that we can know what's
+    // the final output of this code snippet and rely on it to wait until the output is ready.
+    val timestamp = System.currentTimeMillis()
+    in.write((input + s"\nval _result_$timestamp = 1\n").getBytes)
+    in.flush()
+    val stopMessage = s"_result_$timestamp: Int = 1"
+    waitUntil(() => out.getBuffer.substring(currentOffset).contains(stopMessage))
+    out.getBuffer.substring(currentOffset)
+  }
+
+  def assertContains(message: String, output: String) {
+    val isContain = output.contains(message)
+    assert(isContain,
+      "Interpreter output did not contain '" + message + "':\n" + output)
+  }
+
+  def assertDoesNotContain(message: String, output: String) {
+    val isContain = output.contains(message)
+    assert(!isContain,
+      "Interpreter output contained '" + message + "':\n" + output)
+  }
+
+  test("simple foreach with accumulator") {
+    val output = runInterpreter(
+      """
+        |val accum = sc.longAccumulator
+        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
+        |val res = accum.value
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res: Long = 55", output)
+  }
+
+  test("external vars") {
+    val output = runInterpreter(
+      """
+        |var v = 7
+        |val res1 = sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_)
+        |v = 10
+        |val res2 = sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_)
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res1: Int = 70", output)
+    assertContains("res2: Int = 100", output)
+  }
+
+  test("external classes") {
+    val output = runInterpreter(
+      """
+        |class C {
+        |def foo = 5
+        |}
+        |val res = sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_)
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res: Int = 50", output)
+  }
+
+  test("external functions") {
+    val output = runInterpreter(
+      """
+        |def double(x: Int) = x + x
+        |val res = sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_)
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res: Int = 110", output)
+  }
+
+  test("external functions that access vars") {
+    val output = runInterpreter(
+      """
+        |var v = 7
+        |def getV() = v
+        |val res1 = sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
+        |v = 10
+        |val res2 = sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res1: Int = 70", output)
+    assertContains("res2: Int = 100", output)
+  }
+
+  test("broadcast vars") {
+    // Test that the value that a broadcast var had when it was created is used,
+    // even if that variable is then modified in the driver program
+    val output = runInterpreter(
+      """
+        |var array = new Array[Int](5)
+        |val broadcastArray = sc.broadcast(array)
+        |val res1 = sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+        |array(0) = 5
+        |val res2 = sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res1: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+    assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+  }
+
+  test("interacting with files") {
+    val tempDir = Utils.createTempDir()
+    val out = new FileWriter(tempDir + "/input")
+    out.write("Hello world!\n")
+    out.write("What's up?\n")
+    out.write("Goodbye\n")
+    out.close()
+    val output = runInterpreter(
+      """
+        |var file = sc.textFile("%s").cache()
+        |val res1 = file.count()
+        |val res2 = file.count()
+        |val res3 = file.count()
+      """.stripMargin.format(StringEscapeUtils.escapeJava(
+        tempDir.getAbsolutePath + File.separator + "input")))
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res1: Long = 3", output)
+    assertContains("res2: Long = 3", output)
+    assertContains("res3: Long = 3", output)
+    Utils.deleteRecursively(tempDir)
+  }
+
+  test("local-cluster mode") {
+    val output = runInterpreter(
+      """
+        |var v = 7
+        |def getV() = v
+        |val res1 = sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
+        |v = 10
+        |val res2 = sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
+        |var array = new Array[Int](5)
+        |val broadcastArray = sc.broadcast(array)
+        |val res3 = sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+        |array(0) = 5
+        |val res4 = sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res1: Int = 70", output)
+    assertContains("res2: Int = 100", output)
+    assertContains("res3: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+    assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+  }
+
+  test("SPARK-1199 two instances of same class don't type check.") {
+    val output = runInterpreter(
+      """
+        |case class Sum(exp: String, exp2: String)
+        |val a = Sum("A", "B")
+        |def b(a: Sum): String = a match { case Sum(_, _) => "Found Sum" }
+        |b(a)
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("SPARK-2452 compound statements.") {
+    val output = runInterpreter(
+      """
+        |val x = 4 ; def f() = x
+        |f()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("SPARK-2576 importing implicits") {
+    // We need to use local-cluster to test this case.
+    val output = runInterpreter(
+      """
+        |import spark.implicits._
+        |case class TestCaseClass(value: Int)
+        |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
+        |
+        |// Test Dataset Serialization in the REPL
+        |Seq(TestCaseClass(1)).toDS().collect()
+      """.stripMargin)
    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("Datasets and encoders") {
+    val output = runInterpreter(
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.{Encoder, Encoders}
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |val simpleSum = new Aggregator[Int, Int, Int] {
+        |  def zero: Int = 0                     // The initial value.
+        |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
+        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
+        |  def finish(b: Int) = b                // Return the final result.
+        |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+        |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
+        |}.toColumn
+        |
+        |val ds = Seq(1, 2, 3, 4).toDS()
+        |ds.select(simpleSum).collect
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("SPARK-2632 importing a method from non serializable class and not using it.") {
+    val output = runInterpreter(
+      """
+        |class TestClass() { def testMethod = 3 }
+        |val t = new TestClass
+        |import t.testMethod
+        |case class TestCaseClass(value: Int)
+        |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("collecting objects of class defined in repl") {
+    val output = runInterpreter(
+      """
+        |case class Foo(i: Int)
+        |val res = sc.parallelize((1 to 100).map(Foo), 10).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res: Array[Foo] = Array(Foo(1),", output)
+  }
+
+  test("collecting objects of class defined in repl - shuffling") {
+    val output = runInterpreter(
+      """
+        |case class Foo(i: Int)
+        |val list = List((1, Foo(1)), (1, Foo(2)))
+        |val res = sc.parallelize(list).groupByKey().collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res: Array[(Int, Iterable[Foo])] = Array((1,", output)
+  }
+
+  test("replicating blocks of object with class defined in repl") {
+    val output = runInterpreter(
+      """
+        |val timeout = 60000 // 60 seconds
+        |val start = System.currentTimeMillis
+        |while(sc.getExecutorStorageStatus.size != 3 &&
+        |    (System.currentTimeMillis - start) < timeout) {
+        |  Thread.sleep(10)
+        |}
+        |if (System.currentTimeMillis - start >= timeout) {
+        |  throw new java.util.concurrent.TimeoutException("Executors were not up in 60 seconds")
+        |}
+        |import org.apache.spark.storage.StorageLevel._
+        |case class Foo(i: Int)
+        |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2)
+        |ret.count()
+        |val res = sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+    assertContains("res: Int = 20", output)
+  }
+
+  test("should clone and clean line object in ClosureCleaner") {
+    val output = runInterpreter(
+      """
+        |import org.apache.spark.rdd.RDD
+        |
+        |val lines = sc.textFile("pom.xml")
+        |case class Data(s: String)
+        |val dataRDD = lines.map(line => Data(line.take(3)))
+        |dataRDD.cache.count
+        |val repartitioned = dataRDD.repartition(dataRDD.partitions.size)
+        |repartitioned.cache.count
+        |
+        |def getCacheSize(rdd: RDD[_]) = {
+        |  sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum
+        |}
+        |val cacheSize1 = getCacheSize(dataRDD)
+        |val cacheSize2 = getCacheSize(repartitioned)
+        |
+        |// The cache size of dataRDD and the repartitioned one should be similar.
+        |val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1
+        |assert(deviation < 0.2,
+        |  s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2")
+      """.stripMargin)
+    assertDoesNotContain("AssertionError", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("newProductSeqEncoder with REPL defined class") {
+    val output = runInterpreter(
+      """
+        |case class Click(id: Int)
+        |spark.implicits.newProductSeqEncoder[Click]
+      """.stripMargin)
+
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+}
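The shared-interpreter design above hinges on one synchronization trick: runInterpreter appends a uniquely named sentinel binding to every snippet and polls the captured output until the REPL echoes it back. The pattern in isolation (a sketch with illustrative names, assuming a REPL that echoes val definitions as `name: Type = value`):

    // Append a sentinel, then poll the captured output until it appears.
    def runAndWait(send: String => Unit, output: () => String, code: String): String = {
      val offset = output().length
      val stamp = System.currentTimeMillis()
      send(code + s"\nval _probe_$stamp = 1\n")
      val sentinel = s"_probe_$stamp: Int = 1"
      while (!output().substring(offset).contains(sentinel)) Thread.sleep(100)
      output().substring(offset)
    }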
http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 8761ae4..4894036 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -179,7 +179,7 @@ case class Percentile(
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
       child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
-    val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
+    val accumlatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index 3aa4bf6..352fb54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -177,7 +177,7 @@ object Metadata {
   private def toJsonValue(obj: Any): JValue = {
     obj match {
       case map: Map[_, _] =>
-        val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) }
+        val fields = map.toList.map { case (k, v) => (k.toString, toJsonValue(v)) }
         JObject(fields)
       case arr: Array[_] =>
         val values = arr.toList.map(toJsonValue)
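Both hunks above are Scala 2.12 compatibility fixes rather than behavior changes. scanLeft takes a single initial-value parameter, and 2.11 would quietly adapt the two arguments (sortedCounts.head._1, 0L) into a tuple, which 2.12 no longer accepts silently, so the tuple is now written explicitly. Likewise, matching case (k: String, v) over a Map[_, _] is an unchecked type test that 2.12 flags, so the match now binds untyped and converts with toString. A condensed sketch of the scanLeft issue (illustrative values only):

    val sortedCounts = Seq(("a", 2L), ("b", 3L))
    // 2.11: scanLeft(sortedCounts.head._1, 0L) was adapted into a tuple.
    // 2.12: the initial value must be an explicit tuple.
    val accumulated = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
      case ((_, count1), (key2, count2)) => (key2, count1 + count2)
    }.tail // Seq(("a",2), ("b",5)): running totals keyed by the current element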
http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index c35e563..65ca374 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -96,7 +96,7 @@ case class GenerateExec(
         } else {
           outputRows.map(joinedRow.withRight)
         }
-      } ++ LazyIterator(boundGenerator.terminate).map { row =>
+      } ++ LazyIterator(() => boundGenerator.terminate()).map { row =>
         // we leave the left side as the last element of its child output
         // keep it the same as Hive does
         joinedRow.withRight(row)
@@ -109,7 +109,7 @@ case class GenerateExec(
         } else {
           outputRows
         }
-      } ++ LazyIterator(boundGenerator.terminate)
+      } ++ LazyIterator(() => boundGenerator.terminate())
     }
 
     // Convert the rows to unsafe rows.

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
index d74aae3..203d449 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
@@ -119,6 +119,7 @@ class InMemoryFileIndex(
         case None => pathsToFetch += path
       }
+      Unit // for some reasons scalac 2.12 needs this; return type doesn't matter
     }
     val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass))
     val discovered = InMemoryFileIndex.bulkListLeafFiles(

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 0e41f3c..7d6d7e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -205,7 +205,7 @@ class UnivocityParser(
       }
       throw BadRecordException(
         () => getCurrentInput,
-        getPartialResult,
+        () => getPartialResult(),
         new RuntimeException("Malformed CSV record"))
     } else {
       try {

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 1b6a28c..f8058b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -216,7 +216,7 @@ private[joins] class UnsafeHashedRelation(
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-    read(in.readInt, in.readLong, in.readFully)
+    read(() => in.readInt(), () => in.readLong(), in.readFully)
   }
 
   private def read(
@@ -277,7 +277,7 @@ private[joins] class UnsafeHashedRelation(
   }
 
   override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
-    read(in.readInt, in.readLong, in.readBytes)
+    read(() => in.readInt(), () => in.readLong(), in.readBytes)
   }
 
   override def getAverageProbesPerLookup: Double = binaryMap.getAverageProbesPerLookup
@@ -766,11 +766,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   }
 
   override def readExternal(in: ObjectInput): Unit = {
-    read(in.readBoolean, in.readLong, in.readFully)
+    read(() => in.readBoolean(), () => in.readLong(), in.readFully)
   }
 
   override def read(kryo: Kryo, in: Input): Unit = {
-    read(in.readBoolean, in.readLong, in.readBytes)
+    read(() => in.readBoolean(), () => in.readLong(), in.readBytes)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 13b006f..c132cab 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1334,6 +1334,10 @@ public class JavaDatasetSuite implements Serializable {
       return "BeanWithEnum(" + enumField + ", " + regularField + ")";
     }
 
+    public int hashCode() {
+      return Objects.hashCode(enumField, regularField);
+    }
+
     public boolean equals(Object other) {
       if (other instanceof BeanWithEnum) {
         BeanWithEnum beanWithEnum = (BeanWithEnum) other;

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 09502d0..247c30e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -230,11 +230,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
 
     val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon)
     assert(resNaN1.count(_.isNaN) === 0)
-    assert(resNaN1.count(_ == null) === 0)
 
     val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon)
     assert(resNaN2.count(_.isNaN) === 0)
-    assert(resNaN2.count(_ == null) === 0)
 
     val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon)
     assert(resNaN3.isEmpty)
@@ -242,7 +240,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
     val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"),
       Array(q1, q2), epsilon)
     assert(resNaNAll.flatten.count(_.isNaN) === 0)
-    assert(resNaNAll.flatten.count(_ == null) === 0)
 
     assert(resNaN1(0) === resNaNAll(0)(0))
     assert(resNaN1(1) === resNaNAll(0)(1))

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
index 2986b7f..46eec73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
@@ -289,7 +289,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
       }
     }
 
-    AwaitTerminationTester.test(expectedBehavior, awaitTermFunc, testBehaviorFor)
+    AwaitTerminationTester.test(expectedBehavior, () => awaitTermFunc(), testBehaviorFor)
   }
 
   /** Stop a random active query either with `stop()` or with an error */
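The recurring f -> () => f() rewrites in the hunks above (LazyIterator, BadRecordException, the HashedRelation read helpers, AwaitTerminationTester, and stopOnShutdown further below) all address the same Scala 2.12 change: a zero-argument method is no longer implicitly eta-expanded to a Function0 in positions that 2.11 accepted, so the function literal must be spelled out. A minimal sketch of the distinction (illustrative types, not from the commit):

    class Source { def readInt(): Int = 42 }
    def consume(next: () => Int): Int = next()

    val s = new Source
    // 2.11 would often eta-expand the bare reference `s.readInt` here;
    // under 2.12 the explicit literal is the reliable spelling.
    consume(() => s.readInt())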
http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/sql/hive-thriftserver/pom.xml
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index a5a8e26..3135a8a 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -63,6 +63,16 @@
       <groupId>${hive.group}</groupId>
       <artifactId>hive-beeline</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-server</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-servlet</artifactId>
+      <scope>provided</scope>
+    </dependency>
     <!-- Added for selenium: -->
     <dependency>
       <groupId>org.seleniumhq.selenium</groupId>

http://git-wip-us.apache.org/repos/asf/spark/blob/576c43fb/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 8c7418e..0274038 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -596,7 +596,7 @@ class StreamingContext private[streaming] (
     }
     logDebug("Adding shutdown hook") // force eager creation of logger
     shutdownHookRef = ShutdownHookManager.addShutdownHook(
-      StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown)
+      StreamingContext.SHUTDOWN_HOOK_PRIORITY)(() => stopOnShutdown())
     // Registering Streaming Metrics at the start of the StreamingContext
     assert(env.metricsSystem != null)
     env.metricsSystem.registerSource(streamingSource)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org