http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1ff288c..e401abe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -57,7 +57,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("show functions") {
     def getFunctions(pattern: String): Seq[Row] = {
-      StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern)
+      StringUtils.filterPattern(spark.sessionState.functionRegistry.listFunction(), pattern)
        .map(Row(_))
     }
     checkAnswer(sql("SHOW functions"), getFunctions("*"))
@@ -88,7 +88,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-14415: All functions should have own descriptions") {
-    for (f <- sqlContext.sessionState.functionRegistry.listFunction()) {
+    for (f <- spark.sessionState.functionRegistry.listFunction()) {
       if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) {
         checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.")
       }
@@ -102,7 +102,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       (43, 81, 24)
     ).toDF("a", "b", "c").registerTempTable("cachedData")

-    sqlContext.cacheTable("cachedData")
+    spark.catalog.cacheTable("cachedData")
     checkAnswer(
       sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
       Row(0) :: Row(81) :: Nil)
@@ -193,7 +193,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("grouping on nested fields") {
-    sqlContext.read.json(sparkContext.parallelize(
+    spark.read.json(sparkContext.parallelize(
       """{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
      .registerTempTable("rows")
@@ -211,7 +211,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-6201 IN type conversion") {
-    sqlContext.read.json(
+    spark.read.json(
       sparkContext.parallelize(
         Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}")))
      .registerTempTable("d")
@@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-11226 Skip empty line in json file") {
-    sqlContext.read.json(
+    spark.read.json(
       sparkContext.parallelize(
         Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", "")))
      .registerTempTable("d")
@@ -258,9 +258,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("aggregation with codegen") {
     // Prepare a table that we can group some rows.
-    sqlContext.table("testData")
-      .union(sqlContext.table("testData"))
-      .union(sqlContext.table("testData"))
+    spark.table("testData")
+      .union(spark.table("testData"))
+      .union(spark.table("testData"))
      .registerTempTable("testData3x")

     try {
@@ -333,7 +333,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         "SELECT sum('a'), avg('a'), count(null) FROM testData",
         Row(null, null, 0) :: Nil)
     } finally {
-      sqlContext.dropTempTable("testData3x")
+      spark.catalog.dropTempTable("testData3x")
     }
   }

@@ -1041,7 +1041,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SET commands semantics using sql()") {
-    sqlContext.conf.clear()
+    spark.wrapped.conf.clear()
     val testKey = "test.key.0"
     val testVal = "test.val.0"
     val nonexistentKey = "nonexistent"
@@ -1082,17 +1082,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       sql(s"SET $nonexistentKey"),
       Row(nonexistentKey, "<undefined>")
     )
-    sqlContext.conf.clear()
+    spark.wrapped.conf.clear()
   }

   test("SET commands with illegal or inappropriate argument") {
-    sqlContext.conf.clear()
+    spark.wrapped.conf.clear()
     // Set negative mapred.reduce.tasks for automatically determining
     // the number of reducers is not supported
     intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1"))
     intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01"))
     intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2"))
-    sqlContext.conf.clear()
+    spark.wrapped.conf.clear()
   }

   test("apply schema") {
@@ -1110,7 +1110,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(values(0).toInt, values(1), values(2).toBoolean, v4)
     }

-    val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
+    val df1 = spark.createDataFrame(rowRDD1, schema1)
     df1.registerTempTable("applySchema1")
     checkAnswer(
       sql("SELECT * FROM applySchema1"),
@@ -1140,7 +1140,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
     }

-    val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
+    val df2 = spark.createDataFrame(rowRDD2, schema2)
     df2.registerTempTable("applySchema2")
     checkAnswer(
       sql("SELECT * FROM applySchema2"),
@@ -1165,7 +1165,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
     }

-    val df3 = sqlContext.createDataFrame(rowRDD3, schema2)
+    val df3 = spark.createDataFrame(rowRDD3, schema2)
     df3.registerTempTable("applySchema3")

     checkAnswer(
@@ -1210,7 +1210,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       .build()
     val schemaWithMeta = new StructType(Array(
       schema("id"), schema("name").copy(metadata = metadata), schema("age")))
-    val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta)
+    val personWithMeta = spark.createDataFrame(person.rdd, schemaWithMeta)
     def validateMetadata(rdd: DataFrame): Unit = {
       assert(rdd.schema("name").metadata.getString(docKey) == docValue)
     }
@@ -1226,7 +1226,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-3371 Renaming a function expression with group by gives error") {
-    sqlContext.udf.register("len", (s: String) => s.length)
+    spark.udf.register("len", (s: String) => s.length)
     checkAnswer(
       sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"),
       Row(1))
@@ -1409,7 +1409,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("SPARK-3483 Special chars in column names") {
     val data =
       sparkContext.parallelize(
         Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
-    sqlContext.read.json(data).registerTempTable("records")
+    spark.read.json(data).registerTempTable("records")
     sql("SELECT `key?number1`, `key.number2` FROM records")
   }

@@ -1450,15 +1450,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-4322 Grouping field with struct field as sub expression") {
-    sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil))
+    spark.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil))
      .registerTempTable("data")
     checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1))
-    sqlContext.dropTempTable("data")
+    spark.catalog.dropTempTable("data")

-    sqlContext.read.json(
+    spark.read.json(
       sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
     checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2))
-    sqlContext.dropTempTable("data")
+    spark.catalog.dropTempTable("data")
   }

   test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") {
@@ -1504,7 +1504,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-6145: ORDER BY test for nested fields") {
-    sqlContext.read.json(sparkContext.makeRDD(
+    spark.read.json(sparkContext.makeRDD(
       """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil))
      .registerTempTable("nestedOrder")
@@ -1517,14 +1517,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-6145: special cases") {
-    sqlContext.read.json(sparkContext.makeRDD(
+    spark.read.json(sparkContext.makeRDD(
       """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
     checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
     checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
   }

   test("SPARK-6898: complete support for special chars in column names") {
-    sqlContext.read.json(sparkContext.makeRDD(
+    spark.read.json(sparkContext.makeRDD(
       """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
      .registerTempTable("t")
@@ -1628,7 +1628,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("SPARK-7067: order by queries for complex ExtractValue chain") {
     withTempTable("t") {
-      sqlContext.read.json(sparkContext.makeRDD(
+      spark.read.json(sparkContext.makeRDD(
         """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
       checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
     }
@@ -1776,7 +1776,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     // We don't support creating a temporary table while specifying a database
     intercept[AnalysisException] {
-      sqlContext.sql(
+      spark.sql(
         s"""
           |CREATE TEMPORARY TABLE db.t
           |USING parquet
@@ -1787,7 +1787,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }.getMessage

     // If you use backticks to quote the name then it's OK.
-    sqlContext.sql(
+    spark.sql(
       s"""
         |CREATE TEMPORARY TABLE `db.t`
         |USING parquet
@@ -1795,7 +1795,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         |  path '$path'
         |)
       """.stripMargin)
-    checkAnswer(sqlContext.table("`db.t`"), df)
+    checkAnswer(spark.table("`db.t`"), df)
   }
 }

@@ -1818,7 +1818,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("run sql directly on files") {
-    val df = sqlContext.range(100).toDF()
+    val df = spark.range(100).toDF()
     withTempPath(f => {
       df.write.json(f.getCanonicalPath)
       checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"),
@@ -1880,7 +1880,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-11303: filter should not be pushed down into sample") {
-    val df = sqlContext.range(100)
+    val df = spark.range(100)
     List(true, false).foreach { withReplacement =>
       val sampled = df.sample(withReplacement, 0.1, 1)
       val sampledOdd = sampled.filter("id % 2 != 0")
@@ -2059,7 +2059,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     // Identity udf that tracks the number of times it is called.
     val countAcc = sparkContext.accumulator(0, "CallCount")
-    sqlContext.udf.register("testUdf", (x: Int) => {
+    spark.udf.register("testUdf", (x: Int) => {
       countAcc.++=(1)
       x
     })
@@ -2093,9 +2093,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"),
       Row(4, 2), 1)

     // Try disabling it via configuration.
-    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+    spark.conf.set("spark.sql.subexpressionElimination.enabled", "false")
     verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
-    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+    spark.conf.set("spark.sql.subexpressionElimination.enabled", "true")
     verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
   }
 }
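[Editor's note] Every hunk above follows the same substitution: the SQLContext entry point gives way to SparkSession, with catalog operations moving under spark.catalog, configuration under spark.conf, and UDF registration under spark.udf. Below is a minimal sketch of the pattern, using only calls that appear in this diff; the builder settings and the table name "t" are illustrative, not taken from the commit:

    import org.apache.spark.sql.SparkSession

    // Obtain (or create) the session that replaces `new SQLContext(sc)`.
    val spark = SparkSession.builder
      .master("local[*]")            // illustrative settings
      .appName("migration-sketch")
      .getOrCreate()

    val df = spark.range(10).toDF("id")  // was: sqlContext.range(10).toDF()
    df.registerTempTable("t")            // registration API unchanged by this diff
    spark.catalog.cacheTable("t")        // was: sqlContext.cacheTable("t")
    spark.catalog.dropTempTable("t")     // was: sqlContext.dropTempTable("t")

    // Configuration moves to spark.conf; was: sqlContext.setConf(...)
    spark.conf.set("spark.sql.subexpressionElimination.enabled", "false")

    // UDF registration moves to spark.udf; was: sqlContext.udf.register(...)
    spark.udf.register("len", (s: String) => s.length)

Where a test still needs the old SQLContext surface (serialization, registerDataFrameAsTable), the diff reaches it through spark.wrapped, as seen in the SerializationSuite and PlannerSuite hunks below.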
http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index ddab918..b489b74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext
 class SerializationSuite extends SparkFunSuite with SharedSQLContext {

   test("[SPARK-5235] SQLContext should be serializable") {
-    val _sqlContext = new SQLContext(sparkContext)
-    new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
+    val spark = SparkSession.builder.getOrCreate()
+    new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 6fb1aca..1ab562f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -290,7 +290,7 @@ trait StreamTest extends QueryTest with Timeouts {
             verify(currentStream == null, "stream already running")
             lastStream = currentStream
             currentStream =
-              sqlContext
+              spark
                 .streams
                 .startQuery(
                   StreamExecution.nextName,

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 6809f26..c7b95c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -281,7 +281,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
   }

   test("number format function") {
-    val df = sqlContext.range(1)
+    val df = spark.range(1)

     checkAnswer(
       df.select(format_number(lit(5L), 4)),

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index ec95033..427f24a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

   test("built-in fixed arity expressions") {
-    val df = sqlContext.emptyDataFrame
+    val df = spark.emptyDataFrame
     df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
   }

@@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
     df.registerTempTable("tmp_table")
     checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
-    sqlContext.dropTempTable("tmp_table")
+    spark.catalog.dropTempTable("tmp_table")
   }

   test("SPARK-8005 input_file_name") {
     withTempPath { dir =>
       val data = sparkContext.parallelize(0 to 10, 2).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
-      sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
+      spark.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
       val answer = sql("select input_file_name() from test_table").head().getString(0)
       assert(answer.contains(dir.getCanonicalPath))
       assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
-      sqlContext.dropTempTable("test_table")
+      spark.catalog.dropTempTable("test_table")
     }
   }

   test("error reporting for incorrect number of arguments") {
-    val df = sqlContext.emptyDataFrame
+    val df = spark.emptyDataFrame
     val e = intercept[AnalysisException] {
       df.selectExpr("substr('abcd', 2, 3, 4)")
     }
@@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("error reporting for undefined functions") {
-    val df = sqlContext.emptyDataFrame
+    val df = spark.emptyDataFrame
     val e = intercept[AnalysisException] {
       df.selectExpr("a_function_that_does_not_exist()")
     }
@@ -88,22 +88,22 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("Simple UDF") {
-    sqlContext.udf.register("strLenScala", (_: String).length)
+    spark.udf.register("strLenScala", (_: String).length)
     assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
   }

   test("ZeroArgument UDF") {
-    sqlContext.udf.register("random0", () => { Math.random()})
+    spark.udf.register("random0", () => { Math.random()})
     assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
   }

   test("TwoArgument UDF") {
-    sqlContext.udf.register("strLenScala", (_: String).length + (_: Int))
+    spark.udf.register("strLenScala", (_: String).length + (_: Int))
     assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
   }

   test("UDF in a WHERE") {
-    sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 })
+    spark.udf.register("oneArgFilter", (n: Int) => { n > 80 })

     val df = sparkContext.parallelize(
       (1 to 100).map(i => TestData(i, i.toString))).toDF()
@@ -115,7 +115,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("UDF in a HAVING") {
-    sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 })
+    spark.udf.register("havingFilter", (n: Long) => { n > 5 })

     val df = Seq(("red", 1), ("red", 2), ("blue", 10),
       ("green", 100), ("green", 200)).toDF("g", "v")
@@ -134,7 +134,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("UDF in a GROUP BY") {
-    sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 })
+    spark.udf.register("groupFunction", (n: Int) => { n > 10 })

     val df = Seq(("red", 1), ("red", 2), ("blue", 10),
       ("green", 100), ("green", 200)).toDF("g", "v")
@@ -151,10 +151,10 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("UDFs everywhere") {
-    sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 })
-    sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 })
-    sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 })
-    sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 })
+    spark.udf.register("groupFunction", (n: Int) => { n > 10 })
+    spark.udf.register("havingFilter", (n: Long) => { n > 2000 })
+    spark.udf.register("whereFilter", (n: Int) => { n < 150 })
+    spark.udf.register("timesHundred", (n: Long) => { n * 100 })

     val df = Seq(("red", 1), ("red", 2), ("blue", 10),
       ("green", 100), ("green", 200)).toDF("g", "v")
@@ -173,7 +173,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("struct UDF") {
-    sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+    spark.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))

     val result =
       sql("SELECT returnStruct('test', 'test2') as ret")
@@ -182,27 +182,27 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("udf that is transformed") {
-    sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+    spark.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
     // 1 + 1 is constant folded causing a transformation.
     assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
   }

   test("type coercion for udf inputs") {
-    sqlContext.udf.register("intExpected", (x: Int) => x)
+    spark.udf.register("intExpected", (x: Int) => x)
     // pass a decimal to intExpected.
     assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
   }

   test("udf in different types") {
-    sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) })
-    sqlContext.udf.register("decimalDataFunc",
+    spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) })
+    spark.udf.register("decimalDataFunc",
       (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) })
-    sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) })
-    sqlContext.udf.register("arrayDataFunc",
+    spark.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) })
+    spark.udf.register("arrayDataFunc",
       (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) })
-    sqlContext.udf.register("mapDataFunc",
+    spark.udf.register("mapDataFunc",
       (data: scala.collection.Map[Int, String]) => { data })
-    sqlContext.udf.register("complexDataFunc",
+    spark.udf.register("complexDataFunc",
       (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } )

     checkAnswer(
@@ -235,7 +235,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") {
-    val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) })
+    val myUDF = spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) })

     // Without the fix, this will fail because we fail to cast data type of b to string
     // because myUDF does not know its input data type. With the fix, this query should not fail.

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index a49aaa8..3057e01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -94,7 +94,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
   }

   test("UDTs and UDFs") {
-    sqlContext.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector])
+    spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector])
     pointsRDD.registerTempTable("points")
     checkAnswer(
       sql("SELECT testType(features) from points"),
@@ -106,7 +106,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
       val path = dir.getCanonicalPath
       pointsRDD.write.parquet(path)
       checkAnswer(
-        sqlContext.read.parquet(path),
+        spark.read.parquet(path),
         Seq(
           Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
           Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
@@ -118,7 +118,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
       val path = dir.getCanonicalPath
       pointsRDD.repartition(1).write.parquet(path)
       checkAnswer(
-        sqlContext.read.parquet(path),
+        spark.read.parquet(path),
         Seq(
           Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
           Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
@@ -146,7 +146,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     ))
     val stringRDD = sparkContext.parallelize(data)
-    val jsonRDD = sqlContext.read.schema(schema).json(stringRDD)
+    val jsonRDD = spark.read.schema(schema).json(stringRDD)
     checkAnswer(
       jsonRDD,
       Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) ::
@@ -167,7 +167,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     ))
     val stringRDD = sparkContext.parallelize(data)
-    val jsonDataset = sqlContext.read.schema(schema).json(stringRDD)
+    val jsonDataset = spark.read.schema(schema).json(stringRDD)
       .as[(Int, UDT.MyDenseVector)]
     checkDataset(
       jsonDataset,

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 01d485c..70a00a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.execution

 import org.scalatest.BeforeAndAfterAll

-import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite}
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.TestSQLContext

 class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
@@ -251,7 +250,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
   }

   def withSQLContext(
-      f: SQLContext => Unit,
+      f: SparkSession => Unit,
       targetNumPostShufflePartitions: Int,
       minNumPostShufflePartitions: Option[Int]): Unit = {
     val sparkConf =
@@ -272,9 +271,11 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       case None =>
         sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1")
     }
-    val sparkContext = new SparkContext(sparkConf)
-    val sqlContext = new TestSQLContext(sparkContext)
-    try f(sqlContext) finally sparkContext.stop()
+
+    val spark = SparkSession.builder
+      .config(sparkConf)
+      .getOrCreate()
+    try f(spark) finally spark.stop()
   }

   Seq(Some(3), None).foreach { minNumPostShufflePartitions =>
@@ -284,9 +285,9 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
     }

     test(s"determining the number of reducers: aggregate operator$testNameNote") {
-      val test = { sqlContext: SQLContext =>
+      val test = { spark: SparkSession =>
         val df =
-          sqlContext
+          spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 20 as key", "id as value")
         val agg = df.groupBy("key").count
@@ -294,7 +295,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {

         // Check the answer first.
         checkAnswer(
           agg,
-          sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect())
+          spark.range(0, 20).selectExpr("id", "50 as cnt").collect())

         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
@@ -325,13 +326,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
     }

     test(s"determining the number of reducers: join operator$testNameNote") {
-      val test = { sqlContext: SQLContext =>
+      val test = { spark: SparkSession =>
         val df1 =
-          sqlContext
+          spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key1", "id as value1")
         val df2 =
-          sqlContext
+          spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key2", "id as value2")
@@ -339,10 +340,10 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {

         // Check the answer first.
         val expectedAnswer =
-          sqlContext
+          spark
             .range(0, 1000)
             .selectExpr("id % 500 as key", "id as value")
-            .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+            .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
         checkAnswer(
           join,
           expectedAnswer.collect())
@@ -376,16 +377,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
     }

     test(s"determining the number of reducers: complex query 1$testNameNote") {
-      val test = { sqlContext: SQLContext =>
+      val test = { spark: SparkSession =>
         val df1 =
-          sqlContext
+          spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key1", "id as value1")
             .groupBy("key1")
             .count
             .toDF("key1", "cnt1")
         val df2 =
-          sqlContext
+          spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key2", "id as value2")
             .groupBy("key2")
@@ -396,7 +397,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {

         // Check the answer first.
         val expectedAnswer =
-          sqlContext
+          spark
             .range(0, 500)
             .selectExpr("id", "2 as cnt")
         checkAnswer(
@@ -428,16 +429,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
     }

     test(s"determining the number of reducers: complex query 2$testNameNote") {
-      val test = { sqlContext: SQLContext =>
+      val test = { spark: SparkSession =>
         val df1 =
-          sqlContext
+          spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key1", "id as value1")
             .groupBy("key1")
             .count
             .toDF("key1", "cnt1")
         val df2 =
-          sqlContext
+          spark
             .range(0, 1000, 1, numInputPartitions)
             .selectExpr("id % 500 as key2", "id as value2")
@@ -448,7 +449,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {

         // Check the answer first.
         val expectedAnswer =
-          sqlContext
+          spark
             .range(0, 1000)
             .selectExpr("id % 500 as key", "2 as cnt", "id as value")
         checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index ba16810..36cde32 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -50,7 +50,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
   }

   test("BroadcastExchange same result") {
-    val df = sqlContext.range(10)
+    val df = spark.range(10)
     val plan = df.queryExecution.executedPlan
     val output = plan.output
     assert(plan sameResult plan)
@@ -75,7 +75,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
   }

   test("ShuffleExchange same result") {
-    val df = sqlContext.range(10)
+    val df = spark.range(10)
     val plan = df.queryExecution.executedPlan
     val output = plan.output
     assert(plan sameResult plan)

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3b2911d..d2e1ea1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext {
   setupTestData()

   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
-    val planner = sqlContext.sessionState.planner
+    val planner = spark.sessionState.planner
     import planner._
     val plannedOption = Aggregation(query).headOption
     val planned =
@@ -78,7 +78,7 @@ class PlannerSuite extends SharedSQLContext {
     val schema = StructType(fields)
     val row = Row.fromSeq(Seq.fill(fields.size)(null))
     val rowRDD = sparkContext.parallelize(row :: Nil)
-    sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
+    spark.createDataFrame(rowRDD, schema).registerTempTable("testLimit")

     val planned = sql(
       """
@@ -136,7 +136,7 @@ class PlannerSuite extends SharedSQLContext {
       sql("CACHE TABLE tiny")

       val a = testData.as("a")
-      val b = sqlContext.table("tiny").as("b")
+      val b = spark.table("tiny").as("b")
       val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan

       val broadcastHashJoins =
         planned.collect { case join: BroadcastHashJoinExec => join }
@@ -145,7 +145,7 @@
       assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
       assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join")

-      sqlContext.clearCache()
+      spark.catalog.clearCache()
     }
   }
 }

@@ -154,8 +154,8 @@
     withTempPath { file =>
       val path = file.getCanonicalPath
       testData.write.parquet(path)
-      val df = sqlContext.read.parquet(path)
-      sqlContext.registerDataFrameAsTable(df, "testPushed")
+      val df = spark.read.parquet(path)
+      spark.wrapped.registerDataFrameAsTable(df, "testPushed")

       withTempTable("testPushed") {
         val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
@@ -295,7 +295,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
@@ -315,7 +315,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
   }

@@ -333,7 +333,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
@@ -353,7 +353,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
       fail(s"Exchange should not have been added:\n$outputPlan")
@@ -376,7 +376,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(outputOrdering, outputOrdering)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
       fail(s"No Exchanges should have been added:\n$outputPlan")
@@ -392,7 +392,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildOrdering = Seq(Seq(orderingB)),
       requiredChildDistribution = Seq(UnspecifiedDistribution)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case s: SortExec => true }.isEmpty) {
       fail(s"Sort should have been added:\n$outputPlan")
@@ -408,7 +408,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildOrdering = Seq(Seq(orderingA)),
       requiredChildDistribution = Seq(UnspecifiedDistribution)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case s: SortExec => true }.nonEmpty) {
       fail(s"No sorts should have been added:\n$outputPlan")
@@ -425,7 +425,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildOrdering = Seq(Seq(orderingA, orderingB)),
       requiredChildDistribution = Seq(UnspecifiedDistribution)
     )
-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case s: SortExec => true }.isEmpty) {
       fail(s"Sort should have been added:\n$outputPlan")
@@ -444,7 +444,7 @@ class PlannerSuite extends SharedSQLContext {
         requiredChildOrdering = Seq(Seq.empty)),
       None)

-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) {
       fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
@@ -464,7 +464,7 @@ class PlannerSuite extends SharedSQLContext {
         requiredChildOrdering = Seq(Seq.empty)),
       None)

-    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) {
       fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
@@ -493,7 +493,7 @@ class PlannerSuite extends SharedSQLContext {
       shuffle,
       shuffle)

-    val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan)
+    val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan)
     if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) {
       fail(s"Should re-use the shuffle:\n$outputPlan")
     }
@@ -510,7 +510,7 @@ class PlannerSuite extends SharedSQLContext {
       ShuffleExchange(finalPartitioning, inputPlan),
       ShuffleExchange(finalPartitioning, inputPlan))

-    val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2)
+    val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2)
     if (outputPlan2.collect { case e: ReusedExchangeExec => true }.size != 2) {
       fail(s"Should re-use the two shuffles:\n$outputPlan2")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index c9f517c..ad41111 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import java.util.Properties

 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession

 class SQLExecutionSuite extends SparkFunSuite {
@@ -50,16 +50,19 @@ class SQLExecutionSuite extends SparkFunSuite {
   }

   test("concurrent query execution with fork-join pool (SPARK-13747)") {
-    val sc = new SparkContext("local[*]", "test")
-    val sqlContext = new SQLContext(sc)
-    import sqlContext.implicits._
+    val spark = SparkSession.builder
+      .master("local[*]")
+      .appName("test")
+      .getOrCreate()
+
+    import spark.implicits._
     try {
       // Should not throw IllegalArgumentException
       (1 to 100).par.foreach { _ =>
-        sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count()
+        spark.sparkContext.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count()
       }
     } finally {
-      sc.stop()
+      spark.sparkContext.stop()
     }
   }

@@ -67,8 +70,8 @@ class SQLExecutionSuite extends SparkFunSuite {
   * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently.
   */
  private def testConcurrentQueryExecution(sc: SparkContext): Unit = {
-    val sqlContext = new SQLContext(sc)
-    import sqlContext.implicits._
+    val spark = SparkSession.builder.getOrCreate()
+    import spark.implicits._

     // Initialize local properties. This is necessary for the test to pass.
     sc.getLocalProperties

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 073e0b3..d7eae21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -21,7 +21,7 @@ import scala.language.implicitConversions
 import scala.util.control.NonFatal

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.test.SQLTestUtils
@@ -30,7 +30,7 @@
  * class's test helper methods can be used, see [[SortSuite]].
  */
 private[sql] abstract class SparkPlanTest extends SparkFunSuite {
-  protected def sqlContext: SQLContext
+  protected def spark: SparkSession

   /**
    * Runs the plan and makes sure the answer matches the expected result.
@@ -90,9 +90,10 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
       sortAnswers: Boolean = true): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
+    SparkPlanTest
+      .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match {
+        case Some(errorMessage) => fail(errorMessage)
+        case None =>
     }
   }

@@ -114,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
       expectedPlanFunction: SparkPlan => SparkPlan,
       sortAnswers: Boolean = true): Unit = {
     SparkPlanTest.checkAnswer(
-      input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
+      input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -141,13 +142,13 @@
       planFunction: SparkPlan => SparkPlan,
       expectedPlanFunction: SparkPlan => SparkPlan,
       sortAnswers: Boolean,
-      sqlContext: SQLContext): Option[String] = {
+      spark: SQLContext): Option[String] = {
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
     val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)

     val expectedAnswer: Seq[Row] = try {
-      executePlan(expectedOutputPlan, sqlContext)
+      executePlan(expectedOutputPlan, spark)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -162,7 +163,7 @@
     }

     val actualAnswer: Seq[Row] = try {
-      executePlan(outputPlan, sqlContext)
+      executePlan(outputPlan, spark)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -202,12 +203,12 @@
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
       sortAnswers: Boolean,
-      sqlContext: SQLContext): Option[String] = {
+      spark: SQLContext): Option[String] = {
     val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))

     val sparkAnswer: Seq[Row] = try {
-      executePlan(outputPlan, sqlContext)
+      executePlan(outputPlan, spark)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -230,8 +231,8 @@
     }
   }

-  private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
-    val execution = new QueryExecution(sqlContext.sparkSession, null) {
+  private def executePlan(outputPlan: SparkPlan, spark: SQLContext): Seq[Row] = {
+    val execution = new QueryExecution(spark.sparkSession, null) {
       override lazy val sparkPlan: SparkPlan = outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 233104a..ada60f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -28,14 +28,14 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

   test("range/filter should be combined") {
-    val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
+    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
     val plan = df.queryExecution.executedPlan
     assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
     assert(df.collect() === Array(Row(2)))
   }

   test("Aggregate should be included in WholeStageCodegen") {
-    val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
@@ -44,7 +44,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
   }

   test("Aggregate with grouping keys should be included in WholeStageCodegen") {
-    val df = sqlContext.range(3).groupBy("id").count().orderBy("id")
+    val df = spark.range(3).groupBy("id").count().orderBy("id")
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
@@ -53,10 +53,10 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
   }

   test("BroadcastHashJoin should be included in WholeStageCodegen") {
-    val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
+    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
     val schema = new StructType().add("k", IntegerType).add("v", StringType)
-    val smallDF = sqlContext.createDataFrame(rdd, schema)
-    val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
+    val smallDF = spark.createDataFrame(rdd, schema)
+    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
     assert(df.queryExecution.executedPlan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
       p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
@@ -64,7 +64,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
   }

   test("Sort should be included in WholeStageCodegen") {
-    val df = sqlContext.range(3, 0, -1).toDF().sort(col("id"))
+    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
@@ -75,7 +75,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

   test("MapElements should be included in WholeStageCodegen") {
     import testImplicits._

-    val ds = sqlContext.range(10).map(_.toString)
+    val ds = spark.range(10).map(_.toString)
     val plan = ds.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
@@ -84,7 +84,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
   }

   test("typed filter should be included in WholeStageCodegen") {
-    val ds = sqlContext.range(10).filter(_ % 2 == 0)
+    val ds = spark.range(10).filter(_ % 2 == 0)
     val plan = ds.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
@@ -93,7 +93,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
   }

   test("back-to-back typed filter should be included in WholeStageCodegen") {
-    val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
+    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
     val plan = ds.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 50c8745..88269a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}

 import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
@@ -32,7 +33,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   setupTestData()

   test("simple columnar query") {
-    val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
+    val plan = spark.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

     checkAnswer(scan, testData.collect().toSeq)
@@ -42,14 +43,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     // TODO: Improve this test when we have better statistics
     sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
       .toDF().registerTempTable("sizeTst")
-    sqlContext.cacheTable("sizeTst")
+    spark.catalog.cacheTable("sizeTst")
     assert(
-      sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
-        sqlContext.conf.autoBroadcastJoinThreshold)
+      spark.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
+        spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
   }

   test("projection") {
-    val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan
+    val plan = spark.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

     checkAnswer(scan, testData.collect().map {
@@ -58,7 +59,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
-    val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
+    val plan = spark.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

     checkAnswer(scan, testData.collect().toSeq)
@@ -70,7 +71,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       sql("SELECT * FROM repeatedData"),
       repeatedData.collect().toSeq.map(Row.fromTuple))

-    sqlContext.cacheTable("repeatedData")
+    spark.catalog.cacheTable("repeatedData")

     checkAnswer(
       sql("SELECT * FROM repeatedData"),
@@ -82,7 +83,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       sql("SELECT * FROM nullableRepeatedData"),
       nullableRepeatedData.collect().toSeq.map(Row.fromTuple))

-    sqlContext.cacheTable("nullableRepeatedData")
+    spark.catalog.cacheTable("nullableRepeatedData")

     checkAnswer(
       sql("SELECT * FROM nullableRepeatedData"),
@@ -97,7 +98,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       sql("SELECT time FROM timestamps"),
       timestamps.collect().toSeq)

-    sqlContext.cacheTable("timestamps")
+    spark.catalog.cacheTable("timestamps")

     checkAnswer(
       sql("SELECT time FROM timestamps"),
@@ -109,7 +110,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       sql("SELECT * FROM withEmptyParts"),
       withEmptyParts.collect().toSeq.map(Row.fromTuple))

-    sqlContext.cacheTable("withEmptyParts")
+    spark.catalog.cacheTable("withEmptyParts")

     checkAnswer(
       sql("SELECT * FROM withEmptyParts"),
@@ -178,35 +179,35 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
         (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
         Row((i - 0.25).toFloat, Seq(true, false, null)))
     }
-    sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
+    spark.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
     // Cache the table.
     sql("cache table InMemoryCache_different_data_types")
     // Make sure the table is indeed cached.
-    sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan
+    spark.table("InMemoryCache_different_data_types").queryExecution.executedPlan
     assert(
-      sqlContext.isCached("InMemoryCache_different_data_types"),
+      spark.catalog.isCached("InMemoryCache_different_data_types"),
       "InMemoryCache_different_data_types should be cached.")
     // Issue a query and check the results.
     checkAnswer(
       sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
-      sqlContext.table("InMemoryCache_different_data_types").collect())
-    sqlContext.dropTempTable("InMemoryCache_different_data_types")
+      spark.table("InMemoryCache_different_data_types").collect())
+    spark.catalog.dropTempTable("InMemoryCache_different_data_types")
   }

   test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") {
-    val df = sqlContext.range(1, 100).selectExpr("id % 10 as id")
+    val df = spark.range(1, 100).selectExpr("id % 10 as id")
       .rdd.map(id => Tuple1(s"str_$id")).toDF("i")
     val cached = df.cache()
     // count triggers the caching action. It should not throw.
     cached.count()

     // Make sure, the DataFrame is indeed cached.
-    assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty)
+    assert(spark.cacheManager.lookupCachedData(cached).nonEmpty)

     // Check result.
     checkAnswer(
       cached,
-      sqlContext.range(1, 100).selectExpr("id % 10 as id")
+      spark.range(1, 100).selectExpr("id % 10 as id")
         .rdd.map(id => Tuple1(s"str_$id")).toDF("i")
     )
@@ -215,7 +216,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-10859: Predicates pushed to InMemoryColumnarTableScan are not evaluated correctly") {
-    val data = sqlContext.range(10).selectExpr("id", "cast(id as string) as s")
+    val data = spark.range(10).selectExpr("id", "cast(id as string) as s")
     data.cache()
     assert(data.count() === 10)
     assert(data.filter($"s" === "3").count() === 1)

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 9164074..48c7989 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -32,23 +32,24 @@ class PartitionBatchPruningSuite
 
   import testImplicits._
 
-  private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize
-  private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning
+  private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
+  private lazy val originalInMemoryPartitionPruning =
+    spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
 
   override protected def beforeAll(): Unit = {
     super.beforeAll()
     // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
-    sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
+    spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, 10)
     // Enable in-memory partition pruning
-    sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
+    spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, true)
     // Enable in-memory table scan accumulators
-    sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
+    spark.conf.set("spark.sql.inMemoryTableScanStatistics.enable", "true")
   }
 
   override protected def afterAll(): Unit = {
     try {
-      sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
-      sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+      spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, originalColumnBatchSize)
+      spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, originalInMemoryPartitionPruning)
     } finally {
       super.afterAll()
     }
@@ -63,12 +64,12 @@ class PartitionBatchPruningSuite
       TestData(key, string)
     }, 5).toDF()
     pruningData.registerTempTable("pruningData")
-    sqlContext.cacheTable("pruningData")
+    spark.catalog.cacheTable("pruningData")
   }
 
   override protected def afterEach(): Unit = {
     try {
-      sqlContext.uncacheTable("pruningData")
+      spark.catalog.uncacheTable("pruningData")
     } finally {
       super.afterEach()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 3586ddf..5fbab23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -37,7 +37,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   override def afterEach(): Unit = {
     try {
       // drop all databases, tables and functions after each test
-      sqlContext.sessionState.catalog.reset()
+      spark.sessionState.catalog.reset()
     } finally {
       super.afterEach()
     }
@@ -66,7 +66,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
 
   private def createDatabase(catalog: SessionCatalog, name: String): Unit = {
     catalog.createDatabase(
-      CatalogDatabase(name, "", sqlContext.conf.warehousePath, Map()), ignoreIfExists = false)
+      CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()),
+      ignoreIfExists = false)
   }
 
   private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = {
@@ -111,7 +112,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("the qualified path of a database is stored in the catalog") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
 
     withTempDir { tmpDir =>
       val path = tmpDir.toString
@@ -274,7 +275,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
     databaseNames.foreach { dbName =>
       val dbNameWithoutBackTicks = cleanIdentifier(dbName)
-      assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks))
+      assert(!spark.sessionState.catalog.databaseExists(dbNameWithoutBackTicks))
 
       var message = intercept[AnalysisException] {
         sql(s"DROP DATABASE $dbName")
@@ -334,7 +335,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("create table in default db") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent1 = TableIdentifier("tab1", None)
     createTable(catalog, tableIdent1)
     val expectedTableIdent = tableIdent1.copy(database = Some("default"))
@@ -343,7 +344,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("create table in a specific db") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     createDatabase(catalog, "dbx")
     val tableIdent1 = TableIdentifier("tab1", Some("dbx"))
     createTable(catalog, tableIdent1)
@@ -352,7 +353,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("alter table: rename") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent1 = TableIdentifier("tab1", Some("dbx"))
     val tableIdent2 = TableIdentifier("tab2", Some("dbx"))
     createDatabase(catalog, "dbx")
@@ -444,7 +445,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("alter table: set properties") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     createDatabase(catalog, "dbx")
     createTable(catalog, tableIdent)
@@ -471,7 +472,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("alter table: unset properties") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     createDatabase(catalog, "dbx")
     createTable(catalog, tableIdent)
@@ -512,7 +513,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("alter table: bucketing is not supported") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     createDatabase(catalog, "dbx")
     createTable(catalog, tableIdent)
@@ -523,7 +524,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("alter table: skew is not supported") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     createDatabase(catalog, "dbx")
     createTable(catalog, tableIdent)
@@ -560,7 +561,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("alter table: rename partition") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     val part1 = Map("a" -> "1")
     val part2 = Map("b" -> "2")
@@ -661,7 +662,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   test("drop table - temporary table") {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     sql(
       """
         |CREATE TEMPORARY TABLE tab1
@@ -686,7 +687,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   private def testDropTable(isDatasourceTable: Boolean): Unit = {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     createDatabase(catalog, "dbx")
     createTable(catalog, tableIdent)
@@ -705,7 +706,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
     // SQLContext does not support create view. Log an error message, if tab1 does not exists
     sql("DROP VIEW tab1")
 
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     createDatabase(catalog, "dbx")
     createTable(catalog, tableIdent)
@@ -726,7 +727,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   private def testSetLocation(isDatasourceTable: Boolean): Unit = {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     val partSpec = Map("a" -> "1")
     createDatabase(catalog, "dbx")
@@ -784,7 +785,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   private def testSetSerde(isDatasourceTable: Boolean): Unit = {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     createDatabase(catalog, "dbx")
     createTable(catalog, tableIdent)
@@ -830,7 +831,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   private def testAddPartitions(isDatasourceTable: Boolean): Unit = {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     val part1 = Map("a" -> "1")
     val part2 = Map("b" -> "2")
@@ -880,7 +881,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
   }
 
   private def testDropPartitions(isDatasourceTable: Boolean): Unit = {
-    val catalog = sqlContext.sessionState.catalog
+    val catalog = spark.sessionState.catalog
     val tableIdent = TableIdentifier("tab1", Some("dbx"))
     val part1 = Map("a" -> "1")
     val part2 = Map("b" -> "2")

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index ac2af77..52dda8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -281,7 +281,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     ))
 
     val fakeRDD = new FileScanRDD(
-      sqlContext.sparkSession,
+      spark,
       (file: PartitionedFile) => Iterator.empty,
       Seq(partition)
     )
@@ -399,7 +399,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
       util.stringToFile(file, "*" * size)
     }
 
-    val df = sqlContext.read
+    val df = spark.read
      .format(classOf[TestFileFormat].getName)
      .load(tempDir.getCanonicalPath)
 
@@ -409,7 +409,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
        l.copy(relation =
          r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil))))
      }
-      Dataset.ofRows(sqlContext.sparkSession, bucketed)
+      Dataset.ofRows(spark, bucketed)
     } else {
       df
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
index 297731c..89d5765 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
@@ -27,7 +27,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
   test("sizeInBytes should be the total size of all files") {
     withTempDir{ dir =>
       dir.delete()
-      sqlContext.range(1000).write.parquet(dir.toString)
+      spark.range(1000).write.parquet(dir.toString)
       // ignore hidden files
       val allFiles = dir.listFiles(new FilenameFilter {
         override def accept(dir: File, name: String): Boolean = {
@@ -35,7 +35,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
         }
       })
       val totalSize = allFiles.map(_.length()).sum
-      val df = sqlContext.read.parquet(dir.toString)
+      val df = spark.read.parquet(dir.toString)
       assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize))
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 28e5905..b6cdc8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -91,7 +91,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("simple csv test") {
-    val cars = sqlContext
+    val cars = spark
       .read
       .format("csv")
       .option("header", "false")
@@ -101,7 +101,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("simple csv test with calling another function to load") {
-    val cars = sqlContext
+    val cars = spark
       .read
       .option("header", "false")
       .csv(testFile(carsFile))
@@ -110,7 +110,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("simple csv test with type inference") {
-    val cars = sqlContext
+    val cars = spark
       .read
       .format("csv")
       .option("header", "true")
@@ -121,7 +121,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("test inferring booleans") {
-    val result = sqlContext.read
+    val result = spark.read
       .format("csv")
       .option("header", "true")
       .option("inferSchema", "true")
@@ -133,7 +133,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("test with alternative delimiter and quote") {
-    val cars = sqlContext.read
+    val cars = spark.read
       .format("csv")
       .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true"))
       .load(testFile(carsAltFile))
@@ -142,7 +142,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("parse unescaped quotes with maxCharsPerColumn") {
-    val rows = sqlContext.read
+    val rows = spark.read
       .format("csv")
       .option("maxCharsPerColumn", "4")
       .load(testFile(unescapedQuotesFile))
@@ -154,7 +154,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
   test("bad encoding name") {
     val exception = intercept[UnsupportedCharsetException] {
-      sqlContext
+      spark
         .read
.format("csv") .option("charset", "1-9588-osi") @@ -166,7 +166,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test different encoding") { // scalastyle:off - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsFile8859)}", header "true", @@ -174,12 +174,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) // scalastyle:on - verifyCars(sqlContext.table("carsTable"), withHeader = true) + verifyCars(spark.table("carsTable"), withHeader = true) } test("test aliases sep and encoding for delimiter and charset") { // scalastyle:off - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "true") @@ -192,17 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with tab separated file") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") """.stripMargin.replaceAll("\n", " ")) - verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) + verifyCars(spark.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) } test("DDL test parsing decimal type") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, priceTag decimal, @@ -212,11 +212,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) assert( - sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) + spark.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) } test("test for DROPMALFORMED parsing mode") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -226,7 +226,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test for FAILFAST parsing mode") { val exception = intercept[SparkException]{ - sqlContext.read + spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() @@ -236,7 +236,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for tokens more than the fields in the schema") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "false") @@ -247,7 +247,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with null quote character") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .option("quote", "") @@ -258,7 +258,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with empty file and known schema") { - val result = sqlContext.read + val result = spark.read .format("csv") .schema(StructType(List(StructField("column", StringType, false)))) .load(testFile(emptyFile)) @@ -268,25 +268,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with empty file") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, grp string) |USING csv |OPTIONS (path "${testFile(emptyFile)}", header "false") """.stripMargin.replaceAll("\n", " 
")) - assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) + assert(spark.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) } test("DDL test with schema") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, blank string) |USING csv |OPTIONS (path "${testFile(carsFile)}", header "true") """.stripMargin.replaceAll("\n", " ")) - val cars = sqlContext.table("carsTable") + val cars = spark.table("carsTable") verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false) assert( cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank")) @@ -295,7 +295,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -304,7 +304,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .csv(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -316,7 +316,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv with quote") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -327,7 +327,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("quote", "\"") .save(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .option("quote", "\"") @@ -338,7 +338,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false")) .load(testFile(commentsFile)) @@ -353,7 +353,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("inferring schema with commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) .load(testFile(commentsFile)) @@ -372,7 +372,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { "header" -> "true", "inferSchema" -> "true", "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + val results = spark.read .format("csv") .options(options) .load(testFile(datesFile)) @@ -393,7 +393,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { "header" -> "true", "inferSchema" -> "false", "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + val results = spark.read .format("csv") .options(options) .schema(customSchema) @@ -416,7 +416,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("setting comment to null disables comment support") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "", "header" -> "false")) .load(testFile(disableCommentsFile)) @@ -439,7 +439,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("model", StringType, nullable = false), 
StructField("comment", StringType, nullable = true), StructField("blank", StringType, nullable = true))) - val cars = sqlContext.read + val cars = spark.read .format("csv") .schema(dataSchema) .options(Map("header" -> "true", "nullValue" -> "null")) @@ -454,7 +454,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv with compression codec option") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -468,7 +468,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() assert(compressedFiles.exists(_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -486,7 +486,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { ) withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -502,7 +502,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -513,7 +513,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Schema inference correctly identifies the datatype when data is sparse.") { - val df = sqlContext.read + val df = spark.read .format("csv") .option("header", "true") .option("inferSchema", "true") @@ -525,7 +525,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("old csv data source name works") { - val cars = sqlContext + val cars = spark .read .format("com.databricks.spark.csv") .option("header", "false") @@ -535,7 +535,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("nulls, NaNs and Infinity values can be parsed") { - val numbers = sqlContext + val numbers = spark .read .format("csv") .schema(StructType(List( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org