http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
index 7863177..059c2d9 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
@@ -26,36 +26,30 @@ import scala.Tuple2;
 import org.junit.After;
 import org.junit.Before;
 
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.test.TestSparkSession;
 
 /**
  * Common test base shared across this and Java8DatasetAggregatorSuite.
  */
 public class JavaDatasetAggregatorSuiteBase implements Serializable {
-  protected transient JavaSparkContext jsc;
-  protected transient TestSQLContext context;
+  private transient TestSparkSession spark;
 
   @Before
   public void setUp() {
     // Trigger static initializer of TestData
-    SparkContext sc = new SparkContext("local[*]", "testing");
-    jsc = new JavaSparkContext(sc);
-    context = new TestSQLContext(sc);
-    context.loadTestData();
+    spark = new TestSparkSession();
+    spark.loadTestData();
   }
 
   @After
   public void tearDown() {
-    context.sparkContext().stop();
-    context = null;
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
@@ -66,7 +60,7 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable {
     Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
     List<Tuple2<String, Integer>> data =
       Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
-    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+    Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder);
 
     return ds.groupByKey(
       new MapFunction<Tuple2<String, Integer>, String>() {
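The Java test base above, and JavaSaveLoadSuite below, drop the hand-built SparkContext/JavaSparkContext/TestSQLContext setup in favour of a single SparkSession (here Spark's internal TestSparkSession). Outside Spark's own test utilities the same setup/teardown lifecycle can be expressed with the public builder API; a minimal Scala sketch, where the master, app name, and object name are illustrative and not part of the patch:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

object SessionLifecycleSketch {
  def main(args: Array[String]): Unit = {
    // One SparkSession replaces the separate SparkContext + SQLContext pair.
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("testing")
      .getOrCreate()
    try {
      // Anything that still needs the low-level context gets it from the session.
      val sc: SparkContext = spark.sparkContext
      println(sc.parallelize(1 to 10).sum())
    } finally {
      // Mirrors the @After blocks: stop the session rather than a raw SparkContext.
      spark.stop()
    }
  }
}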
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e65158..d0435e4 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -19,14 +19,16 @@ package test.org.apache.spark.sql.sources;
 
 import java.io.File;
 import java.io.IOException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
@@ -37,8 +39,8 @@ import org.apache.spark.util.Utils;
 
 public class JavaSaveLoadSuite {
 
-  private transient JavaSparkContext sc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   File path;
   Dataset<Row> df;
@@ -52,9 +54,11 @@ public class JavaSaveLoadSuite {
 
   @Before
   public void setUp() throws IOException {
-    SparkContext _sc = new SparkContext("local[*]", "testing");
-    sqlContext = new SQLContext(_sc);
-    sc = new JavaSparkContext(_sc);
+    spark = SparkSession.builder()
+      .master("local[*]")
+      .appName("testing")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
 
     path =
       Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
@@ -66,16 +70,15 @@ public class JavaSaveLoadSuite {
     for (int i = 0; i < 10; i++) {
       jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
     }
-    JavaRDD<String> rdd = sc.parallelize(jsonObjects);
-    df = sqlContext.read().json(rdd);
+    JavaRDD<String> rdd = jsc.parallelize(jsonObjects);
+    df = spark.read().json(rdd);
     df.registerTempTable("jsonTable");
   }
 
   @After
   public void tearDown() {
-    sqlContext.sparkContext().stop();
-    sqlContext = null;
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -83,7 +86,7 @@ public class JavaSaveLoadSuite {
     Map<String, String> options = new HashMap<>();
     options.put("path", path.toString());
     df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
-    Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
+    Dataset<Row> loadedDF = spark.read().format("json").options(options).load();
 
     checkAnswer(loadedDF, df.collectAsList());
   }
@@ -96,8 +99,8 @@ public class JavaSaveLoadSuite {
     List<StructField> fields = new ArrayList<>();
     fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
     StructType schema = DataTypes.createStructType(fields);
-    Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+    Dataset<Row> loadedDF = spark.read().format("json").schema(schema).options(options).load();
 
-    checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
+    checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList());
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
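The CachedTableSuite diff that follows (and several of the later Scala suites) reroutes table cache management from SQLContext methods such as cacheTable, uncacheTable, isCached, dropTempTable, and clearCache onto the session catalog. A minimal Scala sketch of that spark.catalog usage, with the table name and data made up for illustration:

import org.apache.spark.sql.SparkSession

object CatalogCacheSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("catalog-cache-sketch")
      .getOrCreate()
    import spark.implicits._

    // Register a temporary table, then drive caching through spark.catalog
    // instead of the old sqlContext.cacheTable/uncacheTable calls.
    Seq(1, 2, 3).toDF("key").registerTempTable("testData")
    spark.catalog.cacheTable("testData")
    assert(spark.catalog.isCached("testData"))

    // A count materializes the in-memory columnar cache eagerly.
    spark.table("testData").count()

    spark.catalog.uncacheTable("testData")
    spark.catalog.clearCache()               // drops anything still cached
    spark.catalog.dropTempTable("testData")

    spark.stop()
  }
}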
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 5ef2026..800316c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -36,7 +36,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext import testImplicits._ def rddIdOf(tableName: String): Int = { - val plan = sqlContext.table(tableName).queryExecution.sparkPlan + val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => relation.cachedColumnBuffers.id @@ -73,41 +73,41 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - sqlContext.cacheTable("tempTable") + spark.catalog.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + spark.catalog.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != sqlContext.cacheManager.lookupCachedData(testData)) + assert(None != spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + spark.catalog.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - sqlContext.cacheTable("tempTable1") + spark.catalog.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - sqlContext.uncacheTable("tempTable2") + spark.catalog.uncacheTable("tempTable2") // Should this be cached? 
assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -117,101 +117,101 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val data = "*" * 1000 sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(sqlContext.table("bigData").count() === 200000L) - sqlContext.table("bigData").unpersist(blocking = true) + spark.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(spark.table("bigData").count() === 200000L) + spark.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - sqlContext.table("testData").cache() - assertCached(sqlContext.table("testData")) - sqlContext.table("testData").unpersist(blocking = true) + spark.table("testData").cache() + assertCached(spark.table("testData")) + spark.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - sqlContext.table("testData").cache() - sqlContext.table("testData").count() - sqlContext.table("testData").unpersist(blocking = true) - assertCached(sqlContext.table("testData"), 0) + spark.table("testData").cache() + spark.table("testData").count() + spark.table("testData").unpersist(blocking = true) + assertCached(spark.table("testData"), 0) } test("isCached") { - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") - assertCached(sqlContext.table("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + assertCached(spark.table("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - sqlContext.uncacheTable("testData") - assert(!sqlContext.isCached("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + spark.catalog.uncacheTable("testData") + assert(!spark.catalog.isCached("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - sqlContext.cacheTable("testData") - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r }.size } - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("read from cached table and uncache") { - sqlContext.cacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + 
assertCached(spark.table("testData")) - sqlContext.uncacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData"), 0) + spark.catalog.uncacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + assertCached(spark.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - sqlContext.cacheTable("selectStar") + spark.catalog.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - sqlContext.uncacheTable("selectStar") + spark.catalog.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(sqlContext.table("testData")) + assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -219,7 +219,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") + assert(!spark.catalog.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -228,14 +228,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(sqlContext.table("testCacheTable")) + assertCached(spark.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - sqlContext.uncacheTable("testCacheTable") + spark.catalog.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -243,14 +243,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(sqlContext.table("testCacheTable")) + assertCached(spark.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - sqlContext.uncacheTable("testCacheTable") + spark.catalog.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -258,7 +258,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(sqlContext.table("testData")) + 
assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -270,7 +270,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -278,7 +278,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -287,62 +287,62 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("Drops temporary table") { testData.select('key).registerTempTable("t1") - sqlContext.table("t1") - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) + spark.table("t1") + spark.catalog.dropTempTable("t1") + intercept[AnalysisException](spark.table("t1")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - sqlContext.cacheTable("t1") + spark.catalog.cacheTable("t1") - assert(sqlContext.isCached("t1")) - assert(sqlContext.isCached("t2")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) - assert(!sqlContext.isCached("t2")) + spark.catalog.dropTempTable("t1") + intercept[AnalysisException](spark.table("t1")) + assert(!spark.catalog.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") - sqlContext.clearCache() - assert(sqlContext.cacheManager.isEmpty) + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") + spark.catalog.clearCache() + assert(spark.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("Clear CACHE") - assert(sqlContext.cacheManager.isEmpty) + assert(spark.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() - val accId1 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val accId1 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => i.batchStats.id }.head - val accId2 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val accId2 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => 
i.batchStats.id }.head - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") assert(AccumulatorContext.get(accId1).isEmpty) assert(AccumulatorContext.get(accId2).isEmpty) @@ -351,7 +351,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") - sqlContext.cacheTable("abc") + spark.catalog.cacheTable("abc") val sparkPlan = sql( """select a.key, b.key, c.key from @@ -374,15 +374,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext table3x.registerTempTable("testData3x") sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable") - sqlContext.cacheTable("orderedTable") - assertCached(sqlContext.table("orderedTable")) + spark.catalog.cacheTable("orderedTable") + assertCached(spark.table("orderedTable")) // Should not have an exchange as the query is already sorted on the group by key. verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0) checkAnswer( sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) - sqlContext.uncacheTable("orderedTable") - sqlContext.dropTempTable("orderedTable") + spark.catalog.uncacheTable("orderedTable") + spark.catalog.dropTempTable("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. @@ -390,8 +390,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(numPartitions, $"key").registerTempTable("t1") testData2.repartition(numPartitions, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") // Joining them should result in no exchanges. verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) @@ -403,8 +403,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), sql("SELECT count(*) FROM testData GROUP BY key")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } @@ -412,8 +412,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(6, $"key").registerTempTable("t1") testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -421,16 +421,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. Need to shuffle one side. 
withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -438,15 +438,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(12, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -454,8 +454,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. Since the number of partitions of @@ -464,30 +464,30 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 2) checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // repartition's column ordering is different from group by column ordering. // But they use the same set of columns. withTempTable("t1") { testData.repartition(6, $"value", $"key").registerTempTable("t1") - sqlContext.cacheTable("t1") + spark.catalog.cacheTable("t1") val query = sql("SELECT value, key from t1 group by key, value") verifyNumExchanges(query, 0) checkAnswer( query, testData.distinct().select($"value", $"key")) - sqlContext.uncacheTable("t1") + spark.catalog.uncacheTable("t1") } // repartition's column ordering is different from join condition's column ordering. 
@@ -499,8 +499,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext df1.repartition(6, $"value", $"key").registerTempTable("t1") val df2 = testData2.select($"a", $"b".cast("string")) df2.repartition(6, $"a", $"b").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") @@ -509,8 +509,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 19fe29a..a5aecca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { - sqlContext.createDataFrame(sparkContext.parallelize( + spark.createDataFrame(sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -287,7 +287,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("isNaN") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -308,7 +308,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("nanvl") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -351,7 +351,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("=!=") { - val nullData = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData = spark.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -370,7 +370,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { nullData.filter($"a" <=> $"b"), Row(1, 1) :: Row(null, null) :: Nil) - val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData2 = spark.createDataFrame(sparkContext.parallelize( Row("abc") :: Row(null) :: Row("xyz") :: Nil), @@ -596,7 +596,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name()) + val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name()) .head.getString(0) 
assert(answer.contains(dir.getCanonicalPath)) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 63f4b75..8a99866 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -70,7 +70,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0))) ) - val decimalDataWithNulls = sqlContext.sparkContext.parallelize( + val decimalDataWithNulls = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, null) :: DecimalData(2, 1) :: @@ -114,7 +114,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, 113000.0) :: Nil ) - val df0 = sqlContext.sparkContext.parallelize(Seq( + val df0 = spark.sparkContext.parallelize(Seq( Fact(20151123, 18, 35, "room1", 18.6), Fact(20151123, 18, 35, "room2", 22.4), Fact(20151123, 18, 36, "room1", 17.4), @@ -207,12 +207,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, true) } test("agg without groups") { @@ -433,10 +433,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( - sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) checkAnswer( - sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0414fa1..031e66b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -154,7 +154,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // SPARK-12275: no physical plan for BroadcastHint in some condition withTempPath { path => df1.write.parquet(path.getCanonicalPath) - val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + val pf1 = spark.read.parquet(path.getCanonicalPath) assert(df1.join(broadcast(pf1)).count() === 4) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index c6d6751..fa8fa06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -81,11 +81,11 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ } test("pivot max values enforced") { - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, 1) intercept[AnalysisException]( courseSales.groupBy("year").pivot("course") ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) } @@ -104,7 +104,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ // pivot with extra columns to trigger optimization .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString)) .agg(sum($"earnings")) - val queryExecution = sqlContext.executePlan(df.queryExecution.logical) + val queryExecution = spark.executePlan(df.queryExecution.logical) assert(queryExecution.simpleString.contains("pivotfirst")) } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0ea7727..ab7733b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -236,7 +236,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("sampleBy") { - val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) + val df = spark.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), @@ -247,7 +247,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in // `CountMinSketchSuite` in project spark-sketch. test("countMinSketch") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) assert(sketch1.totalCount() === 1000) @@ -279,7 +279,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // This test only verifies some basic requirements, more correctness tests can be found in // `BloomFilterSuite` in project spark-sketch. test("Bloom filter") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val filter1 = df.stat.bloomFilter("id", 1000, 0.03) assert(filter1.expectedFpp() - 0.03 < 1e-3) @@ -304,7 +304,7 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin // Turn on this test if you want to test the performance of approximate quantiles. 
ignore("computing quantiles should not take much longer than describe()") { - val df = sqlContext.range(5000000L).toDF("col1").cache() + val df = spark.range(5000000L).toDF("col1").cache() def seconds(f: => Any): Double = { // Do some warmup logDebug("warmup...") http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 80a93ee..f77403c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -99,8 +99,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) val schema2 = StructType(Array(StructField("label", IntegerType, false), StructField("point", new ExamplePointUDT(), false))) - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + val df1 = spark.createDataFrame(rowRDD1, schema1) + val df2 = spark.createDataFrame(rowRDD2, schema2) checkAnswer( df1.union(df2).orderBy("label"), @@ -109,8 +109,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("empty data frame") { - assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(sqlContext.emptyDataFrame.count() === 0) + assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(spark.emptyDataFrame.count() === 0) } test("head and take") { @@ -369,7 +369,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake checkAnswer( - sqlContext.range(2).toDF().limit(2147483638), + spark.range(2).toDF().limit(2147483638), Row(0) :: Row(1) :: Nil ) } @@ -672,12 +672,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val parquetDir = new File(dir, "parquet").getCanonicalPath df.write.parquet(parquetDir) - val parquetDF = sqlContext.read.parquet(parquetDir) + val parquetDF = spark.read.parquet(parquetDir) assert(parquetDF.inputFiles.nonEmpty) val jsonDir = new File(dir, "json").getCanonicalPath df.write.json(jsonDir) - val jsonDF = sqlContext.read.json(jsonDir) + val jsonDF = spark.read.json(jsonDir) assert(parquetDF.inputFiles.nonEmpty) val unioned = jsonDF.union(parquetDF).inputFiles.sorted @@ -801,7 +801,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = sqlContext.createDataFrame(rowRDD, schema) + val df = spark.createDataFrame(rowRDD, schema) df.rdd.collect() } @@ -818,14 +818,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sparkContext.makeRDD( + val df2 = spark.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( 
df2.select(df2("`a b`.c.d e.f")), @@ -881,53 +881,53 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = sqlContext.range(0, 10, 1, 15).select("id") + val res1 = spark.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = sqlContext.range(3, 15, 3, 2).select("id") + val res2 = spark.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = sqlContext.range(1, -2).select("id") + val res3 = spark.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = sqlContext.range(1, -2, -2, 6).select("id") + val res4 = spark.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = sqlContext.range(-3, -8, -2, 1).select("id") + val res5 = spark.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = sqlContext.range(-8, -4, 2, 1).select("id") + val res6 = spark.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = sqlContext.range(-10, -9, -20, 1).select("id") + val res7 = spark.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = sqlContext.range(10).select("id") + val res10 = spark.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = sqlContext.range(-1).select("id") + val res11 = spark.range(-1).select("id") assert(res11.count == 0) // using the default slice number - val res12 = sqlContext.range(3, 15, 3).select("id") + val res12 = spark.range(3, 15, 3).select("id") assert(res12.count == 4) assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } @@ -993,13 +993,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) + val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) pdf.registerTempTable("parquet_base") insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) + val jdf = spark.read.json(tempJsonFile.getCanonicalPath) jdf.registerTempTable("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") @@ -1019,7 +1019,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { 
assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - Dataset.ofRows(sqlContext.sparkSession, OneRowRelation).registerTempTable("one_row") + Dataset.ofRows(spark, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } @@ -1062,7 +1062,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -1091,10 +1091,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(dir1, dir2), + checkAnswer(spark.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) - checkAnswer(sqlContext.read.format("json").load(dir1), + checkAnswer(spark.read.format("json").load(dir1), Row(1, 22) :: Nil) } } @@ -1116,7 +1116,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val input = spark.read.json(spark.sparkContext.makeRDD( (1 to 10).map(i => s"""{"id": $i}"""))) val df = input.select($"id", rand(0).as('r)) @@ -1185,7 +1185,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { withTempPath { path => Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath) - val df = sqlContext.read.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a")) } } @@ -1244,7 +1244,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { verifyExchangingAgg(testData.repartition($"key", $"value") .groupBy("key").count()) - val data = sqlContext.sparkContext.parallelize( + val data = spark.sparkContext.parallelize( (1 to 100).map(i => TestData2(i % 10, i))).toDF() // Distribute and order by. 
@@ -1308,7 +1308,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withTempPath { path => val p = path.getAbsolutePath Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p) - checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012)) + checkAnswer(spark.read.parquet(p).select("YeaR"), Row(2012)) } } } @@ -1317,7 +1317,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( rdd, new StructType().add("f1", IntegerType).add("f2", IntegerType), needsConversion = false).select($"F1", $"f2".as("f2")) @@ -1344,7 +1344,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) - sqlContext.udf.register("boxedUDF", + spark.udf.register("boxedUDF", (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) @@ -1393,7 +1393,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("reuse exchange") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() val join = df.join(df, "id") val plan = join.queryExecution.executedPlan checkAnswer(join, df) @@ -1415,14 +1415,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("sameResult() on aggregate") { - val df = sqlContext.range(100) + val df = spark.range(100) val agg1 = df.groupBy().count() val agg2 = df.groupBy().count() // two aggregates with different ExprId within them should have same result assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan)) val agg3 = df.groupBy().sum() assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan)) - val df2 = sqlContext.range(101) + val df2 = spark.range(101) val agg4 = df2.groupBy().count() assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan)) } @@ -1454,24 +1454,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("assertAnalyzed shouldn't replace original stack trace") { val e = intercept[AnalysisException] { - sqlContext.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) + spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) } assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) } test("SPARK-13774: Check error message for non existent path without globbed paths") { - val e = intercept[AnalysisException] (sqlContext.read.format("csv"). + val e = intercept[AnalysisException] (spark.read.format("csv"). load("/xyz/file2", "/xyz/file21", "/abc/files555", "a")).getMessage() assert(e.startsWith("Path does not exist")) } test("SPARK-13774: Check error message for not existent globbed paths") { - val e = intercept[AnalysisException] (sqlContext.read.format("text"). + val e = intercept[AnalysisException] (spark.read.format("text"). load( "/xyz/*")).getMessage() assert(e.startsWith("Path does not exist")) - val e1 = intercept[AnalysisException] (sqlContext.read.json("/mnt/*/*-xyz.json").rdd). + val e1 = intercept[AnalysisException] (spark.read.json("/mnt/*/*-xyz.json").rdd). 
getMessage() assert(e1.startsWith("Path does not exist")) } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 06584ec..a957d5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -249,14 +249,14 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B try { f(tableName) } finally { - sqlContext.dropTempTable(tableName) + spark.catalog.dropTempTable(tableName) } } test("time window in SQL with single string expression") { withTempTable { table => checkAnswer( - sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""") + spark.sql(s"""select window(time, "10 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), @@ -270,7 +270,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with two expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( @@ -285,7 +285,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with three expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 68e99d6..fe6ba83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -48,7 +48,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b3", FloatType) .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(struct)) } @@ -70,7 +70,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b5b", StringType)) .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index ae9fb80..d8e241c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -31,14 +31,14 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { - import sqlContext.implicits._ + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,17 +72,17 @@ object DatasetBenchmark { benchmark } - def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { - import sqlContext.implicits._ + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back filter", numRows) val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -130,13 +130,13 @@ object DatasetBenchmark { override def outputEncoder: Encoder[Long] = Encoders.scalaLong } - def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { - import sqlContext.implicits._ + def aggregate(spark: SparkSession, numRows: Long): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("aggregate", numRows) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD sum") { iter => rdd.aggregate(0L)(_ + _.l, _ + _) } @@ -157,15 +157,17 @@ object DatasetBenchmark { } def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) + val spark = SparkSession.builder + .master("local[*]") + .appName("Dataset benchmark") + .getOrCreate() val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(sqlContext, numRows, numChains) - val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) - val benchmark3 = aggregate(sqlContext, numRows) + val benchmark = backToBackMap(spark, numRows, numChains) + val benchmark2 = 
backToBackFilter(spark, numRows, numChains) + val benchmark3 = aggregate(spark, numRows) /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 942cc09..8c0906b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -39,7 +39,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { 2, 3, 4) // Drop the cache. cached.unpersist() - assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + assert(spark.cacheManager.lookupCachedData(cached).isEmpty, "The Dataset should not be cached.") } test("persist and then rebind right encoder when join 2 datasets") { @@ -56,9 +56,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(joined, 2) ds1.unpersist() - assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds1).isEmpty, + "The Dataset ds1 should not be cached.") ds2.unpersist() - assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds2).isEmpty, + "The Dataset ds2 should not be cached.") } test("persist and then groupBy columns asKey, map") { @@ -73,8 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(agged.filter(_._1 == "b")) ds.unpersist() - assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds).isEmpty, "The Dataset ds should not be cached.") agged.unpersist() - assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + assert(spark.cacheManager.lookupCachedData(agged).isEmpty, + "The Dataset agged should not be cached.") } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3cb4e52..3c8c862 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -46,12 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("range") { - assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 
2).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) } test("SPARK-12404: Datatype Helper Serializability") { @@ -472,7 +472,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("SPARK-14696: implicit encoders for boxed types") { - assert(sqlContext.range(1).map { i => i : java.lang.Long }.head == 0L) + assert(spark.range(1).map { i => i : java.lang.Long }.head == 0L) } test("SPARK-11894: Incorrect results are returned when using null") { @@ -510,8 +510,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { )) def buildDataset(rows: Row*): Dataset[NestedStruct] = { - val rowRDD = sqlContext.sparkContext.parallelize(rows) - sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + val rowRDD = spark.sparkContext.parallelize(rows) + spark.createDataFrame(rowRDD, schema).as[NestedStruct] } checkDataset( @@ -626,7 +626,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { - val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + val wideDF = spark.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) // Make sure the generated code for this plan can compile and execute. checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) } @@ -654,7 +654,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { dataset.join(actual, dataset("user") === actual("id")).collect() } - test("SPARK-15097: implicits on dataset's sqlContext can be imported") { + test("SPARK-15097: implicits on dataset's spark can be imported") { val dataset = Seq(1, 2, 3).toDS() checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4) } @@ -735,10 +735,10 @@ object JavaData { def apply(a: Int): JavaData = new JavaData(a) } -/** Used to test importing dataset.sqlContext.implicits._ */ +/** Used to test importing dataset.spark.implicits._ */ object DatasetTransform { def addOne(ds: Dataset[Int]): Dataset[Int] = { - import ds.sqlContext.implicits._ + import ds.sparkSession.implicits._ ds.map(_ + 1) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index b1987c6..a41b465 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -51,7 +51,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { test("insert an extraStrategy") { try { - sqlContext.experimental.extraStrategies = TestStrategy :: Nil + spark.experimental.extraStrategies = TestStrategy :: Nil val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( @@ -62,7 +62,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { df.select("a", "b"), Row("so slow", 1)) } finally { - sqlContext.experimental.extraStrategies = Nil + spark.experimental.extraStrategies = Nil } } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8cbad04..da567db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( @@ -112,7 +112,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { // } test("broadcasted hash join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") Seq( ("SELECT * FROM testData join testData2 ON key = a", @@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash outer join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") sql("CACHE TABLE testData2") Seq( @@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -435,7 +435,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted existence join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { @@ -461,17 +461,17 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("cross join with broadcast") { sql("CACHE TABLE testData") - val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + val sizeInByteOfTestData = statisticSizeInByte(spark.table("testData")) // we set the threshold is greater than statistic of the cached table testData withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { - assert(statisticSizeInByte(sqlContext.table("testData2")) > - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData2")) > + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) - assert(statisticSizeInByte(sqlContext.table("testData")) < - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData")) < + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 9f6c86a..c88dfe5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -33,36 +33,36 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) } test("get all tables") { checkAnswer( - sqlContext.tables().filter("tableName = 'listtablessuitetable'"), + spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( sql("SHOW tables").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("getting all tables with a database name has no impact on returned table names") { checkAnswer( - sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"), + spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( sql("show TABLES in default").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { + Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) @@ -81,9 +81,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row(true, "listtablessuitetable") ) checkAnswer( - sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - sqlContext.dropTempTable("tables") + spark.catalog.dropTempTable("tables") } } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala new file mode 100644 index 0000000..1732977 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach +import org.scalatest.Suite + +/** Manages a local `spark` {@link SparkSession} variable, correctly stopping it after each test. */ +trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => + + @transient var spark: SparkSession = _ + + override def beforeAll() { + super.beforeAll() + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + } + + override def afterEach() { + try { + resetSparkContext() + } finally { + super.afterEach() + } + } + + def resetSparkContext(): Unit = { + LocalSparkSession.stop(spark) + spark = null + } + +} + +object LocalSparkSession { + def stop(spark: SparkSession) { + if (spark != null) { + spark.stop() + } + // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ + def withSparkSession[T](sc: SparkSession)(f: SparkSession => T): T = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index df8b3b7..a1a9b66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { - protected def sqlContext: SQLContext + protected def spark: SparkSession // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) + spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } @@ -267,7 +267,7 @@ abstract class QueryTest extends PlanTest { val jsonBackPlan = try { - TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext) } catch { case NonFatal(e) => fail( @@ -282,7 +282,7 @@ abstract class QueryTest extends PlanTest { def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession) + LogicalRDD(l.output, origin.rdd)(spark) case l: LocalRelation => val origin = localRelations.pop() l.copy(data = origin.data)
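For reference, a minimal sketch of how a test suite might use the new LocalSparkSession trait together with the SparkSession.builder pattern that replaces the old new SparkContext / new SQLContext setup in the diffs above. Only LocalSparkSession, its spark variable, and the builder calls come from the patch; the suite name, app name, and test body are hypothetical.

package org.apache.spark.sql

import org.scalatest.FunSuite

// Hypothetical example suite; LocalSparkSession (added above) stops `spark` after each test.
class LocalSparkSessionExampleSuite extends FunSuite with LocalSparkSession {

  test("per-test SparkSession via the builder") {
    // Same builder pattern this patch introduces in DatasetBenchmark.main.
    spark = SparkSession.builder
      .master("local[*]")
      .appName("LocalSparkSession example")  // app name is illustrative
      .getOrCreate()

    // `spark` is a var in the trait, so copy it to a val before importing its implicits.
    val session = spark
    import session.implicits._

    // Mirrors the range assertion rewritten in DatasetSuite above.
    assert(session.range(10).map(_ + 1).reduce(_ + _) == 55)
  }
}

Alternatively, the companion helper shown above, LocalSparkSession.withSparkSession(session) { s => ... }, runs a block against an existing session and guarantees it is stopped, matching the cleanup the trait performs in afterEach().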