Repository: spark
Updated Branches:
  refs/heads/master 27fe6bacc -> 1d542785b
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ee85626..47cc74d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -64,13 +64,13 @@ public class JavaDataFrameSuite {
   @Test
   public void testExecution() {
-    DataFrame df = context.table("testData").filter("key = 1");
-    Assert.assertEquals(1, df.select("key").collect()[0].get(0));
+    Dataset<Row> df = context.table("testData").filter("key = 1");
+    Assert.assertEquals(1, df.select("key").collectRows()[0].get(0));
   }

   @Test
   public void testCollectAndTake() {
-    DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+    Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
     Assert.assertEquals(3, df.select("key").collectAsList().size());
     Assert.assertEquals(2, df.select("key").takeAsList(2).size());
   }
@@ -80,7 +80,7 @@ public class JavaDataFrameSuite {
    */
   @Test
   public void testVarargMethods() {
-    DataFrame df = context.table("testData");
+    Dataset<Row> df = context.table("testData");

     df.toDF("key1", "value1");
@@ -109,7 +109,7 @@ public class JavaDataFrameSuite {
     df.select(coalesce(col("key")));

     // Varargs with mathfunctions
-    DataFrame df2 = context.table("testData2");
+    Dataset<Row> df2 = context.table("testData2");
     df2.select(exp("a"), exp("b"));
     df2.select(exp(log("a")));
     df2.select(pow("a", "a"), pow("b", 2.0));
@@ -123,7 +123,7 @@ public class JavaDataFrameSuite {
   @Ignore
   public void testShow() {
     // This test case is intended ignored, but to make sure it compiles correctly
-    DataFrame df = context.table("testData");
+    Dataset<Row> df = context.table("testData");
     df.show();
     df.show(1000);
   }
@@ -151,7 +151,7 @@ public class JavaDataFrameSuite {
     }
   }

-  void validateDataFrameWithBeans(Bean bean, DataFrame df) {
+  void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
     StructType schema = df.schema();
     Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
       schema.apply("a"));
@@ -191,7 +191,7 @@ public class JavaDataFrameSuite {
   public void testCreateDataFrameFromLocalJavaBeans() {
     Bean bean = new Bean();
     List<Bean> data = Arrays.asList(bean);
-    DataFrame df = context.createDataFrame(data, Bean.class);
+    Dataset<Row> df = context.createDataFrame(data, Bean.class);
     validateDataFrameWithBeans(bean, df);
   }
@@ -199,7 +199,7 @@ public class JavaDataFrameSuite {
   public void testCreateDataFrameFromJavaBeans() {
     Bean bean = new Bean();
     JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
-    DataFrame df = context.createDataFrame(rdd, Bean.class);
+    Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
     validateDataFrameWithBeans(bean, df);
   }
@@ -207,8 +207,8 @@ public class JavaDataFrameSuite {
   public void testCreateDataFromFromList() {
     StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
     List<Row> rows = Arrays.asList(RowFactory.create(0));
-    DataFrame df = context.createDataFrame(rows, schema);
-    Row[] result = df.collect();
+    Dataset<Row> df = context.createDataFrame(rows, schema);
+    Row[] result = df.collectRows();
     Assert.assertEquals(1, result.length);
   }
@@ -235,13 +235,13 @@ public class JavaDataFrameSuite {
   @Test
   public void testCrosstab() {
-    DataFrame df = context.table("testData2");
-    DataFrame crosstab = df.stat().crosstab("a", "b");
+    Dataset<Row> df = context.table("testData2");
+    Dataset<Row> crosstab = df.stat().crosstab("a", "b");
     String[] columnNames = crosstab.schema().fieldNames();
     Assert.assertEquals("a_b", columnNames[0]);
     Assert.assertEquals("2", columnNames[1]);
     Assert.assertEquals("1", columnNames[2]);
-    Row[] rows = crosstab.collect();
+    Row[] rows = crosstab.collectRows();
     Arrays.sort(rows, crosstabRowComparator);
     Integer count = 1;
     for (Row row : rows) {
@@ -254,31 +254,31 @@ public class JavaDataFrameSuite {
   @Test
   public void testFrequentItems() {
-    DataFrame df = context.table("testData2");
+    Dataset<Row> df = context.table("testData2");
     String[] cols = {"a"};
-    DataFrame results = df.stat().freqItems(cols, 0.2);
-    Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
+    Dataset<Row> results = df.stat().freqItems(cols, 0.2);
+    Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1));
   }

   @Test
   public void testCorrelation() {
-    DataFrame df = context.table("testData2");
+    Dataset<Row> df = context.table("testData2");
     Double pearsonCorr = df.stat().corr("a", "b", "pearson");
     Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
   }

   @Test
   public void testCovariance() {
-    DataFrame df = context.table("testData2");
+    Dataset<Row> df = context.table("testData2");
     Double result = df.stat().cov("a", "b");
     Assert.assertTrue(Math.abs(result) < 1.0e-6);
   }

   @Test
   public void testSampleBy() {
-    DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
-    DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
-    Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+    Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+    Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+    Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows();
     Assert.assertEquals(0, actual[0].getLong(0));
     Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8);
     Assert.assertEquals(1, actual[1].getLong(0));
@@ -287,10 +287,10 @@ public class JavaDataFrameSuite {
   @Test
   public void pivot() {
-    DataFrame df = context.table("courseSales");
+    Dataset<Row> df = context.table("courseSales");
     Row[] actual = df.groupBy("year")
       .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
-      .agg(sum("earnings")).orderBy("year").collect();
+      .agg(sum("earnings")).orderBy("year").collectRows();

     Assert.assertEquals(2012, actual[0].getInt(0));
     Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
@@ -303,11 +303,11 @@ public class JavaDataFrameSuite {
   @Test
   public void testGenericLoad() {
-    DataFrame df1 = context.read().format("text").load(
+    Dataset<Row> df1 = context.read().format("text").load(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
     Assert.assertEquals(4L, df1.count());

-    DataFrame df2 = context.read().format("text").load(
+    Dataset<Row> df2 = context.read().format("text").load(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
       Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
     Assert.assertEquals(5L, df2.count());
@@ -315,11 +315,11 @@ public class JavaDataFrameSuite {
   @Test
   public void testTextLoad() {
-    DataFrame df1 = context.read().text(
+    Dataset<Row> df1 = context.read().text(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
     Assert.assertEquals(4L, df1.count());

-    DataFrame df2 = context.read().text(
+    Dataset<Row> df2 = context.read().text(
       Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
       Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
     Assert.assertEquals(5L, df2.count());
@@ -327,7 +327,7 @@ public class JavaDataFrameSuite {
   @Test
   public void testCountMinSketch() {
-    DataFrame df = context.range(1000);
+    Dataset<Row> df = context.range(1000);

     CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
     Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -352,7 +352,7 @@ public class JavaDataFrameSuite {
   @Test
   public void testBloomFilter() {
-    DataFrame df = context.range(1000);
+    Dataset<Row> df = context.range(1000);

     BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
     Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
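For context: the change repeated throughout this file swaps the removed Java-facing DataFrame type for Dataset<Row>, with collectRows() standing in for the old Row[]-returning collect(). A minimal Scala sketch of the same idea on this branch (the `context` handle and `testData` table are assumptions borrowed from the suite, not part of this commit):

    import org.apache.spark.sql.{Dataset, Row, SQLContext}

    // Hedged sketch: untyped query results are now typed as Dataset[Row];
    // on the Scala side collect() still returns Array[Row].
    def firstKey(context: SQLContext): Any = {
      val df: Dataset[Row] = context.table("testData").filter("key = 1")
      df.select("key").collect()(0).get(0)
    }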
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index b054b10..79b6e61 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
+    GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
@@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable {
     List<Integer> data2 = Arrays.asList(2, 6, 10);
     Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
-    GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
+    GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
@@ -250,7 +250,7 @@ public class JavaDatasetSuite implements Serializable {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped =
-      ds.groupBy(length(col("value"))).keyAs(Encoders.INT());
+      ds.groupByKey(length(col("value"))).keyAs(Encoders.INT());

     Dataset<String> mapped = grouped.mapGroups(
       new MapGroupsFunction<Integer, String, String>() {
@@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable {
       Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
     Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);

-    GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+    GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
       new MapFunction<Tuple2<String, Integer>, String>() {
         @Override
         public String call(Tuple2<String, Integer> value) throws Exception {
@@ -828,7 +828,7 @@ public class JavaDatasetSuite implements Serializable {
         })
       });

-      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));

       SmallBean smallBean = new SmallBean();
@@ -845,7 +845,7 @@ public class JavaDatasetSuite implements Serializable {
     {
       Row row = new GenericRow(new Object[] { null });

-      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));

       NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -862,7 +862,7 @@ public class JavaDatasetSuite implements Serializable {
         })
       });

-      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
       Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));

       ds.collect();
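The recurring rename above moves function- and column-based grouping from groupBy to groupByKey, leaving groupBy to the untyped relational API. A small Scala sketch of the renamed call (assuming only a SQLContext with its implicits in scope):

    import org.apache.spark.sql.SQLContext

    // Hedged sketch: key-extractor grouping is now spelled groupByKey.
    def lengthCounts(sqlContext: SQLContext): Array[(Int, Long)] = {
      import sqlContext.implicits._
      val ds = Seq("a", "foo", "bar").toDS()
      ds.groupByKey(_.length)  // was: ds.groupBy(_.length)
        .count()               // GroupedDataset[Int, String] -> Dataset[(Int, Long)]
        .collect()
    }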
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e241f2..0f9e453 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -42,9 +42,9 @@ public class JavaSaveLoadSuite {
   String originalDefaultSource;
   File path;
-  DataFrame df;
+  Dataset<Row> df;

-  private static void checkAnswer(DataFrame actual, List<Row> expected) {
+  private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
     String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
     if (errorMessage != null) {
       Assert.fail(errorMessage);
@@ -85,7 +85,7 @@ public class JavaSaveLoadSuite {
     Map<String, String> options = new HashMap<>();
     options.put("path", path.toString());
     df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
-    DataFrame loadedDF = sqlContext.read().format("json").options(options).load();
+    Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
     checkAnswer(loadedDF, df.collectAsList());
   }
@@ -98,7 +98,7 @@ public class JavaSaveLoadSuite {
     List<StructField> fields = new ArrayList<>();
     fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
     StructType schema = DataTypes.createStructType(fields);
-    DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+    Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
     checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
   }
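Only the declared result type of the reader changes in this suite; the save/load flow is untouched. A Scala sketch of the round trip (the path argument is a placeholder):

    import org.apache.spark.sql.{Dataset, Row, SQLContext, SaveMode}

    // Hedged sketch: write JSON out, read it back as Dataset[Row].
    def roundTrip(sqlContext: SQLContext, df: Dataset[Row], path: String): Dataset[Row] = {
      df.write.mode(SaveMode.ErrorIfExists).format("json").save(path)
      sqlContext.read.format("json").load(path)
    }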
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 26775c3..f4a5107 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -38,23 +38,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

   test("analysis error should be eagerly reported") {
-    // Eager analysis.
-    withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") {
-      intercept[Exception] { testData.select('nonExistentName) }
-      intercept[Exception] {
-        testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
-      }
-      intercept[Exception] {
-        testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
-      }
-      intercept[Exception] {
-        testData.groupBy($"abcd").agg(Map("key" -> "sum"))
-      }
+    intercept[Exception] { testData.select('nonExistentName) }
+    intercept[Exception] {
+      testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
     }
-
-    // No more eager analysis once the flag is turned off
-    withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") {
-      testData.select('nonExistentName)
+    intercept[Exception] {
+      testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
+    }
+    intercept[Exception] {
+      testData.groupBy($"abcd").agg(Map("key" -> "sum"))
     }
   }
@@ -72,7 +64,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(1, 1) :: Nil)
   }

-  test("invalid plan toString, debug mode") {
+  ignore("invalid plan toString, debug mode") {
     // Turn on debug mode so we can see invalid query plans.
     import org.apache.spark.sql.execution.debug._
@@ -941,7 +933,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))

     // error case: insert into an OneRowRelation
-    new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row")
+    DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row")
     val e3 = intercept[AnalysisException] {
       insertion.write.insertInto("one_row")
     }
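The first hunk reflects that the DATAFRAME_EAGER_ANALYSIS flag is gone: analysis now always runs eagerly, so an unresolvable column fails at construction time rather than at the first action. A sketch of the observable behavior (table and column names assumed from the suite):

    import org.apache.spark.sql.{AnalysisException, SQLContext}

    // Hedged sketch: the bad column name raises AnalysisException here,
    // before any action such as collect() is invoked.
    def eagerFailure(sqlContext: SQLContext): Unit = {
      try {
        sqlContext.table("testData").select("nonExistentName")
      } catch {
        case ae: AnalysisException => println(s"caught eagerly: ${ae.getMessage}")
      }
    }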
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 3258f37..8477016 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -119,16 +119,16 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: TypedAggregator") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    checkAnswer(
-      ds.groupBy(_._1).agg(sum(_._2)),
+    checkDataset(
+      ds.groupByKey(_._1).agg(sum(_._2)),
       ("a", 30), ("b", 3), ("c", 1))
   }

   test("typed aggregation: TypedAggregator, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkAnswer(
-      ds.groupBy(_._1).agg(
+    checkDataset(
+      ds.groupByKey(_._1).agg(
         sum(_._2),
         expr("sum(_2)").as[Long],
         count("*")),
@@ -138,8 +138,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: complex case") {
     val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-    checkAnswer(
-      ds.groupBy(_._1).agg(
+    checkDataset(
+      ds.groupByKey(_._1).agg(
         expr("avg(_2)").as[Double],
         TypedAverage.toColumn),
       ("a", 2.0, 2.0), ("b", 3.0, 3.0))
@@ -148,8 +148,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: complex result type") {
     val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-    checkAnswer(
-      ds.groupBy(_._1).agg(
+    checkDataset(
+      ds.groupByKey(_._1).agg(
         expr("avg(_2)").as[Double],
         ComplexResultAgg.toColumn),
       ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
@@ -158,10 +158,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: in project list") {
     val ds = Seq(1, 3, 2, 5).toDS()

-    checkAnswer(
+    checkDataset(
       ds.select(sum((i: Int) => i)),
       11)
-    checkAnswer(
+    checkDataset(
       ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
       11 -> 22)
   }
@@ -169,7 +169,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: class input") {
     val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()

-    checkAnswer(
+    checkDataset(
       ds.select(ClassInputAgg.toColumn),
       3)
   }
@@ -177,33 +177,33 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: class input with reordering") {
     val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData]

-    checkAnswer(
+    checkDataset(
       ds.select(ClassInputAgg.toColumn),
       1)

-    checkAnswer(
+    checkDataset(
       ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn),
       (1.0, 1))

-    checkAnswer(
-      ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
+    checkDataset(
+      ds.groupByKey(_.b).agg(ClassInputAgg.toColumn),
       ("one", 1))
   }

   test("typed aggregation: complex input") {
     val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()

-    checkAnswer(
+    checkDataset(
       ds.select(ComplexBufferAgg.toColumn),
       2
     )

-    checkAnswer(
+    checkDataset(
       ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
       (1.5, 2))

-    checkAnswer(
-      ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn),
+    checkDataset(
+      ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
       ("one", 1), ("two", 1))
   }
 }
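For orientation, the groupByKey/checkDataset pairs above exercise typed aggregation; in user code the same pattern looks roughly like this sketch (only sum from org.apache.spark.sql.functions is assumed):

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.sum

    // Hedged sketch: a typed aggregate over groupByKey, producing
    // Dataset[(String, Long)] rather than an untyped DataFrame.
    def totals(sqlContext: SQLContext): Array[(String, Long)] = {
      import sqlContext.implicits._
      val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
      ds.groupByKey(_._1).agg(sum("_2").as[Long]).collect()
    }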
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 848f1af..2e5179a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -34,7 +34,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     // Make sure, the Dataset is indeed cached.
     assertCached(cached)
     // Check result.
-    checkAnswer(
+    checkDataset(
       cached,
       2, 3, 4)
     // Drop the cache.
@@ -52,7 +52,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     assertCached(ds2)

     val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
-    checkAnswer(joined, ("2", 2))
+    checkDataset(joined, ("2", 2))
     assertCached(joined, 2)

     ds1.unpersist()
@@ -63,11 +63,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
   test("persist and then groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1").keyAs[String]
+    val grouped = ds.groupByKey($"_1").keyAs[String]
     val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
     agged.persist()

-    checkAnswer(
+    checkDataset(
       agged.filter(_._1 == "b"),
       ("b", 3))
     assertCached(agged.filter(_._1 == "b"))
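The cache tests above combine persistence with the typed API; a compact sketch of the pattern:

    import org.apache.spark.sql.SQLContext

    // Hedged sketch: a typed Dataset can be persisted like any query,
    // and later typed operations reuse the cached data.
    def cachedSums(sqlContext: SQLContext): Array[(String, Int)] = {
      import sqlContext.implicits._
      val agged = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
        .groupByKey(_._1)
        .mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
      agged.persist()
      agged.filter(_._1 == "b").collect()
    }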
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 243d13b..6e9840e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -28,14 +28,14 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("toDS") {
     val data = Seq(1, 2, 3, 4, 5, 6)
-    checkAnswer(
+    checkDataset(
       data.toDS(),
       data: _*)
   }

   test("as case class / collect") {
     val ds = Seq(1, 2, 3).toDS().as[IntClass]
-    checkAnswer(
+    checkDataset(
       ds,
       IntClass(1), IntClass(2), IntClass(3))
@@ -44,14 +44,14 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("map") {
     val ds = Seq(1, 2, 3).toDS()
-    checkAnswer(
+    checkDataset(
       ds.map(_ + 1),
       2, 3, 4)
   }

   test("filter") {
     val ds = Seq(1, 2, 3, 4).toDS()
-    checkAnswer(
+    checkDataset(
       ds.filter(_ % 2 == 0),
       2, 4)
   }
@@ -77,54 +77,54 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, keys") {
     val ds = Seq(1, 2, 3, 4, 5).toDS()
-    val grouped = ds.groupBy(_ % 2)
-    checkAnswer(
+    val grouped = ds.groupByKey(_ % 2)
+    checkDataset(
       grouped.keys,
       0, 1)
   }

   test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
-    val grouped = ds.groupBy(_ % 2)
+    val grouped = ds.groupByKey(_ % 2)
     val agged = grouped.mapGroups { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
       (name, iter.size)
     }

-    checkAnswer(
+    checkDataset(
       agged,
       ("even", 5), ("odd", 6))
   }

   test("groupBy function, flatMap") {
     val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
-    val grouped = ds.groupBy(_.length)
+    val grouped = ds.groupByKey(_.length)
    val agged = grouped.flatMapGroups { case (g, iter) =>
      Iterator(g.toString, iter.mkString)
    }

-    checkAnswer(
+    checkDataset(
       agged,
       "1", "abc", "3", "xyz", "5", "hello")
   }

   test("Arrays and Lists") {
-    checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
-    checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
-    checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
-    checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
-    checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
-    checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
-    checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
-    checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
-    checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
-
-    checkAnswer(Seq(Array(1)).toDS(), Array(1))
-    checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
-    checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
-    checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
-    checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
-    checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
-    checkAnswer(Seq(Array(true)).toDS(), Array(true))
-    checkAnswer(Seq(Array("test")).toDS(), Array("test"))
-    checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
+    checkDataset(Seq(Seq(1)).toDS(), Seq(1))
+    checkDataset(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
+    checkDataset(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
+    checkDataset(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
+    checkDataset(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
+    checkDataset(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
+    checkDataset(Seq(Seq(true)).toDS(), Seq(true))
+    checkDataset(Seq(Seq("test")).toDS(), Seq("test"))
+    checkDataset(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
+
+    checkDataset(Seq(Array(1)).toDS(), Array(1))
+    checkDataset(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
+    checkDataset(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
+    checkDataset(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
+    checkDataset(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
+    checkDataset(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
+    checkDataset(Seq(Array(true)).toDS(), Array(true))
+    checkDataset(Seq(Array("test")).toDS(), Array("test"))
+    checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
   }
 }
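These renames run against the built-in encoders for primitives and collections; the underlying behavior can be sketched as:

    import org.apache.spark.sql.SQLContext

    // Hedged sketch: primitives, Seqs and Arrays round-trip through the
    // built-in encoders via toDS()/collect().
    def encoderRoundTrips(sqlContext: SQLContext): Unit = {
      import sqlContext.implicits._
      assert(Seq(1, 2, 3).toDS().collect().toSeq == Seq(1, 2, 3))
      assert(Seq(Seq(1)).toDS().collect().head == Seq(1))
      assert(Seq(Array("test")).toDS().collect().head.toSeq == Seq("test"))
    }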
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 79e1021..9f32c8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,14 +34,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("toDS") {
     val data = Seq(("a", 1), ("b", 2), ("c", 3))
-    checkAnswer(
+    checkDataset(
       data.toDS(),
       data: _*)
   }

   test("toDS with RDD") {
     val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS()
-    checkAnswer(
+    checkDataset(
       ds.mapPartitions(_ => Iterator(1)),
       1, 1, 1)
   }
@@ -71,26 +71,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = data.toDS()

     assert(ds.repartition(10).rdd.partitions.length == 10)
-    checkAnswer(
+    checkDataset(
       ds.repartition(10),
       data: _*)

     assert(ds.coalesce(1).rdd.partitions.length == 1)
-    checkAnswer(
+    checkDataset(
       ds.coalesce(1),
       data: _*)
   }

   test("as tuple") {
     val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
-    checkAnswer(
+    checkDataset(
       data.as[(String, Int)],
       ("a", 1), ("b", 2))
   }

   test("as case class / collect") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
-    checkAnswer(
+    checkDataset(
       ds,
       ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
     assert(ds.collect().head == ClassData("a", 1))
@@ -108,7 +108,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("map") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkAnswer(
+    checkDataset(
       ds.map(v => (v._1, v._2 + 1)),
       ("a", 2), ("b", 3), ("c", 4))
   }
@@ -116,7 +116,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("map with type change with the exact matched number of attributes") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()

-    checkAnswer(
+    checkDataset(
       ds.map(identity[(String, Int)])
         .as[OtherTuple]
         .map(identity[OtherTuple]),
@@ -126,7 +126,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("map with type change with less attributes") {
     val ds = Seq(("a", 1, 3), ("b", 2, 4), ("c", 3, 5)).toDS()

-    checkAnswer(
+    checkDataset(
       ds.as[OtherTuple]
         .map(identity[OtherTuple]),
       OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3))
@@ -137,23 +137,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     // when we implement better pipelining and local execution mode.
     val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
       .map(c => ClassData(c.a, c.b + 1))
-      .groupBy(p => p).count()
+      .groupByKey(p => p).count()

-    checkAnswer(
+    checkDataset(
       ds,
       (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
   }

   test("select") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkAnswer(
+    checkDataset(
       ds.select(expr("_2 + 1").as[Int]),
       2, 3, 4)
   }

   test("select 2") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkAnswer(
+    checkDataset(
       ds.select(
         expr("_1").as[String],
         expr("_2").as[Int]) : Dataset[(String, Int)],
@@ -162,7 +162,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("select 2, primitive and tuple") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkAnswer(
+    checkDataset(
       ds.select(
         expr("_1").as[String],
         expr("struct(_2, _2)").as[(Int, Int)]),
@@ -171,7 +171,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("select 2, primitive and class") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkAnswer(
+    checkDataset(
       ds.select(
         expr("_1").as[String],
         expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
@@ -189,7 +189,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("filter") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkAnswer(
+    checkDataset(
       ds.filter(_._1 == "b"),
       ("b", 2))
   }
@@ -217,7 +217,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds1 = Seq(1, 2, 3).toDS().as("a")
     val ds2 = Seq(1, 2).toDS().as("b")

-    checkAnswer(
+    checkDataset(
       ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"),
       (1, 1), (2, 2))
   }
@@ -230,7 +230,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds2 = Seq(("a", new Integer(1)),
       ("b", new Integer(2))).toDS()

-    checkAnswer(
+    checkDataset(
       ds1.joinWith(ds2, $"_1" === $"a", "outer"),
       (ClassNullableData("a", 1), ("a", new Integer(1))),
       (ClassNullableData("c", 3), (nullString, nullInteger)),
@@ -241,7 +241,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds1 = Seq(1, 1, 2).toDS()
     val ds2 = Seq(("a", 1), ("b", 2)).toDS()

-    checkAnswer(
+    checkDataset(
       ds1.joinWith(ds2, $"value" === $"_2"),
       (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
   }
@@ -260,7 +260,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
     val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")

-    checkAnswer(
+    checkDataset(
       ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
       ((("a", 1), ("a", 1)), ("a", 1)),
       ((("b", 2), ("b", 2)), ("b", 2)))
@@ -268,48 +268,48 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, keys") {
     val ds = Seq(("a", 1), ("b", 1)).toDS()
-    val grouped = ds.groupBy(v => (1, v._2))
-    checkAnswer(
+    val grouped = ds.groupByKey(v => (1, v._2))
+    checkDataset(
       grouped.keys,
       (1, 1))
   }

   test("groupBy function, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy(v => (v._1, "word"))
+    val grouped = ds.groupByKey(v => (v._1, "word"))
     val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) }

-    checkAnswer(
+    checkDataset(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }

   test("groupBy function, flatMap") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy(v => (v._1, "word"))
+    val grouped = ds.groupByKey(v => (v._1, "word"))
     val agged = grouped.flatMapGroups { case (g, iter) =>
       Iterator(g._1, iter.map(_._2).sum.toString)
     }

-    checkAnswer(
+    checkDataset(
       agged,
       "a", "30", "b", "3", "c", "1")
   }

   test("groupBy function, reduce") {
     val ds = Seq("abc", "xyz", "hello").toDS()
-    val agged = ds.groupBy(_.length).reduce(_ + _)
+    val agged = ds.groupByKey(_.length).reduce(_ + _)

-    checkAnswer(
+    checkDataset(
       agged,
       3 -> "abcxyz", 5 -> "hello")
   }

   test("groupBy single field class, count") {
     val ds = Seq("abc", "xyz", "hello").toDS()
-    val count = ds.groupBy(s => Tuple1(s.length)).count()
+    val count = ds.groupByKey(s => Tuple1(s.length)).count()

-    checkAnswer(
+    checkDataset(
       count,
       (Tuple1(3), 2L), (Tuple1(5), 1L)
     )
@@ -317,49 +317,49 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1")
+    val grouped = ds.groupByKey($"_1")
     val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }

-    checkAnswer(
+    checkDataset(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }

   test("groupBy columns, count") {
     val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
-    val count = ds.groupBy($"_1").count()
+    val count = ds.groupByKey($"_1").count()

-    checkAnswer(
+    checkDataset(
       count,
       (Row("a"), 2L), (Row("b"), 1L))
   }

   test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1").keyAs[String]
+    val grouped = ds.groupByKey($"_1").keyAs[String]
     val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }

-    checkAnswer(
+    checkDataset(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }

   test("groupBy columns asKey tuple, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)]
+    val grouped = ds.groupByKey($"_1", lit(1)).keyAs[(String, Int)]
     val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }

-    checkAnswer(
+    checkDataset(
       agged,
       (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
   }

   test("groupBy columns asKey class, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
+    val grouped = ds.groupByKey($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
     val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }

-    checkAnswer(
+    checkDataset(
       agged,
       (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
   }
@@ -367,32 +367,32 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Long]),
+    checkDataset(
+      ds.groupByKey(_._1).agg(sum("_2").as[Long]),
       ("a", 30L), ("b", 3L), ("c", 1L))
   }

   test("typed aggregation: expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
+    checkDataset(
+      ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
       ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
   }

   test("typed aggregation: expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
+    checkDataset(
+      ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
       ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
   }

   test("typed aggregation: expr, expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkAnswer(
-      ds.groupBy(_._1).agg(
+    checkDataset(
+      ds.groupByKey(_._1).agg(
         sum("_2").as[Long],
         sum($"_2" + 1).as[Long],
         count("*").as[Long],
@@ -403,11 +403,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("cogroup") {
     val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
     val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
-    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+    val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) =>
       Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
     }

-    checkAnswer(
+    checkDataset(
       cogrouped,
       1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
   }
@@ -415,11 +415,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("cogroup with complex data") {
     val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS()
     val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS()
-    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+    val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) =>
       Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
     }

-    checkAnswer(
+    checkDataset(
       cogrouped,
       1 -> "a", 2 -> "bc", 3 -> "d")
   }
@@ -427,7 +427,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()
-    checkAnswer(
+    checkDataset(
       data.sample(withReplacement = true, 0.05, seed = 13),
       5, 10, 52, 73)
   }
@@ -435,7 +435,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("sample without replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()
-    checkAnswer(
+    checkDataset(
       data.sample(withReplacement = false, 0.05, seed = 13),
       3, 17, 27, 58, 62)
   }
@@ -445,13 +445,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds2 = Seq(2, 3).toDS().as("b")

     val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
-    checkAnswer(joined, ("2", 2))
+    checkDataset(joined, ("2", 2))
   }

   test("self join") {
     val ds = Seq("1", "2").toDS().as("a")
     val joined = ds.joinWith(ds, lit(true))
-    checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
+    checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
   }

   test("toString") {
@@ -477,7 +477,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     implicit val kryoEncoder = Encoders.kryo[KryoData]
     val ds = Seq(KryoData(1), KryoData(2)).toDS()

-    assert(ds.groupBy(p => p).count().collect().toSet ==
+    assert(ds.groupByKey(p => p).count().collect().toSet ==
       Set((KryoData(1), 1L), (KryoData(2), 1L)))
   }
@@ -496,7 +496,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
     val ds = Seq(JavaData(1), JavaData(2)).toDS()

-    assert(ds.groupBy(p => p).count().collect().toSeq ==
+    assert(ds.groupByKey(p => p).count().collect().toSeq ==
       Seq((JavaData(1), 1L), (JavaData(2), 1L)))
   }
@@ -516,7 +516,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
     val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()

-    checkAnswer(
+    checkDataset(
       ds1.joinWith(ds2, lit(true)),
       ((nullInt, "1"), (nullInt, "1")),
       ((new java.lang.Integer(22), "2"), (nullInt, "1")),
@@ -550,7 +550,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
     }

-    checkAnswer(
+    checkDataset(
       buildDataset(Row(Row("hello", 1))),
       NestedStruct(ClassData("hello", 1))
     )
@@ -567,11 +567,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("SPARK-12478: top level null field") {
     val ds0 = Seq(NestedStruct(null)).toDS()
-    checkAnswer(ds0, NestedStruct(null))
+    checkDataset(ds0, NestedStruct(null))
     checkAnswer(ds0.toDF(), Row(null))

     val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
-    checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+    checkDataset(ds1, DeepNestedStruct(NestedStruct(null)))
     checkAnswer(ds1.toDF(), Row(Row(null)))
   }
@@ -579,26 +579,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val outer = new OuterClass
     OuterScopes.addOuterScope(outer)
     val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
-    checkAnswer(ds.map(_.a), "1", "2")
+    checkDataset(ds.map(_.a), "1", "2")
   }

   test("grouping key and grouped value has field with same name") {
     val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
-    val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups {
+    val agged = ds.groupByKey(d => ClassNullableData(d.a, null)).mapGroups {
       case (key, values) => key.a + values.map(_.b).sum
     }

-    checkAnswer(agged, "a3")
+    checkDataset(agged, "a3")
   }

   test("cogroup's left and right side has field with same name") {
     val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
     val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()
-    val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
+    val cogrouped = left.groupByKey(_.a).cogroup(right.groupByKey(_.a)) {
       case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
     }

-    checkAnswer(cogrouped, "a13", "b24")
+    checkDataset(cogrouped, "a13", "b24")
   }

   test("give nice error message when the real number of fields doesn't match encoder schema") {
@@ -626,13 +626,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("SPARK-13440: Resolving option fields") {
     val df = Seq(1, 2, 3).toDS()
     val ds = df.as[Option[Int]]
-    checkAnswer(
+    checkDataset(
       ds.filter(_ => true),
       Some(1), Some(2), Some(3))
   }

   test("SPARK-13540 Dataset of nested class defined in Scala object") {
-    checkAnswer(
+    checkDataset(
       Seq(OuterObject.InnerClass("foo")).toDS(),
       OuterObject.InnerClass("foo"))
   }
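Among the renames in this file, cogroup shows the shape of the new typed API most clearly; a sketch:

    import org.apache.spark.sql.SQLContext

    // Hedged sketch: cogroup over two groupByKey-ed Datasets, concatenating
    // the values seen on each side of every key.
    def cogroupDemo(sqlContext: SQLContext): Array[(Int, String)] = {
      import sqlContext.implicits._
      val ds1 = Seq(1 -> "a", 3 -> "abc").toDS()
      val ds2 = Seq(2 -> "q", 3 -> "w").toDS()
      ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, left, right) =>
        Iterator(key -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString))
      }.collect()
    }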
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index c05aa54..855295d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -72,7 +72,7 @@ abstract class QueryTest extends PlanTest {
    * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
    * which performs a subset of the checks done by this function.
    */
-  protected def checkAnswer[T](
+  protected def checkDataset[T](
       ds: Dataset[T],
       expectedAnswer: T*): Unit = {
     checkAnswer(
@@ -123,17 +123,17 @@ abstract class QueryTest extends PlanTest {
   protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
     val analyzedDF = try df catch {
       case ae: AnalysisException =>
-        val currentValue = sqlContext.conf.dataFrameEagerAnalysis
-        sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
-        val partiallyAnalzyedPlan = df.queryExecution.analyzed
-        sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue)
-        fail(
-          s"""
-             |Failed to analyze query: $ae
-             |$partiallyAnalzyedPlan
-             |
-             |${stackTraceToString(ae)}
-             |""".stripMargin)
+        if (ae.plan.isDefined) {
+          fail(
+            s"""
+               |Failed to analyze query: $ae
+               |${ae.plan.get}
+               |
+               |${stackTraceToString(ae)}
+               |""".stripMargin)
+        } else {
+          throw ae
+        }
     }

     checkJsonFormat(analyzedDF)
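The rewritten failure path relies on the plan now carried by AnalysisException itself, instead of re-running analysis under the removed flag. A sketch of that usage (the helper name is illustrative only):

    import org.apache.spark.sql.{AnalysisException, DataFrame}

    // Hedged sketch: report the analyzed-so-far plan attached to the
    // exception when present; otherwise let the exception propagate.
    def analyzedOrError(df: => DataFrame): String =
      try {
        df.toString
      } catch {
        case ae: AnalysisException if ae.plan.isDefined =>
          s"Failed to analyze query: $ae\n${ae.plan.get}"
      }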
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index bb51358..493a5a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -65,9 +65,9 @@ import org.apache.spark.sql.execution.streaming._
 trait StreamTest extends QueryTest with Timeouts {

   implicit class RichSource(s: Source) {
-    def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s))
+    def toDF(): DataFrame = DataFrame(sqlContext, StreamingRelation(s))

-    def toDS[A: Encoder](): Dataset[A] = new Dataset(sqlContext, StreamingRelation(s))
+    def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s))
   }

   /** How long to wait for an active stream to catch up when checking a result. */
@@ -168,10 +168,6 @@ trait StreamTest extends QueryTest with Timeouts {
     }
   }

-  /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */
-  def testStream(stream: Dataset[_])(actions: StreamAction*): Unit =
-    testStream(stream.toDF())(actions: _*)
-
   /**
    * Executes the specified actions on the given streaming DataFrame and provides helpful
    * error messages in the case of failures or incorrect answers.
@@ -179,7 +175,8 @@ trait StreamTest extends QueryTest with Timeouts {
    * Note that if the stream is not explicitly started before an action that requires it to be
    * running then it will be automatically started before performing any other actions.
    */
-  def testStream(stream: DataFrame)(actions: StreamAction*): Unit = {
+  def testStream(_stream: Dataset[_])(actions: StreamAction*): Unit = {
+    val stream = _stream.toDF()
     var pos = 0
     var currentPlan: LogicalPlan = stream.logicalPlan
     var currentStream: StreamExecution = null
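The two testStream overloads collapse into one because any Dataset normalizes to a DataFrame up front; the same one-liner applies in user code:

    import org.apache.spark.sql.{DataFrame, Dataset}

    // Hedged sketch: toDF() flattens a typed Dataset to an untyped
    // DataFrame, which is all the consolidated testStream needs.
    def normalize(stream: Dataset[_]): DataFrame = stream.toDF()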
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index b4bf9ee..63fb4b7 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -38,9 +38,9 @@ public class JavaDataFrameSuite {
   private transient JavaSparkContext sc;
   private transient HiveContext hc;

-  DataFrame df;
+  Dataset<Row> df;

-  private static void checkAnswer(DataFrame actual, List<Row> expected) {
+  private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
     String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
     if (errorMessage != null) {
       Assert.fail(errorMessage);
@@ -82,12 +82,12 @@ public class JavaDataFrameSuite {
   @Test
   public void testUDAF() {
-    DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
+    Dataset<Row> df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
     UserDefinedAggregateFunction udaf = new MyDoubleSum();
     UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
     // Create Columns for the UDAF. For now, callUDF does not take an argument to specific if
     // we want to use distinct aggregation.
-    DataFrame aggregatedDF =
+    Dataset<Row> aggregatedDF =
       df.groupBy()
         .agg(
           udaf.distinct(col("value")),
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index 8c4af1b..5a539ea 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -33,7 +33,7 @@ import org.junit.Test;

 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.QueryTest$;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.hive.test.TestHive$;
@@ -52,9 +52,9 @@ public class JavaMetastoreDataSourcesSuite {
   File path;
   Path hiveManagedPath;
   FileSystem fs;
-  DataFrame df;
+  Dataset<Row> df;

-  private static void checkAnswer(DataFrame actual, List<Row> expected) {
+  private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
     String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
     if (errorMessage != null) {
       Assert.fail(errorMessage);
@@ -111,7 +111,7 @@ public class JavaMetastoreDataSourcesSuite {
       sqlContext.sql("SELECT * FROM javaSavedTable"),
       df.collectAsList());

-    DataFrame loadedDF =
+    Dataset<Row> loadedDF =
       sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options);

     checkAnswer(loadedDF, df.collectAsList());
@@ -137,7 +137,7 @@ public class JavaMetastoreDataSourcesSuite {
     List<StructField> fields = new ArrayList<>();
     fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
     StructType schema = DataTypes.createStructType(fields);
-    DataFrame loadedDF =
+    Dataset<Row> loadedDF =
       sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options);

     checkAnswer(
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
index 4adc5c1..a0a0d13 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
@@ -63,7 +63,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
        """.stripMargin)
     }

-    checkAnswer(sqlContext.sql(generatedSQL), new DataFrame(sqlContext, plan))
+    checkAnswer(sqlContext.sql(generatedSQL), DataFrame(sqlContext, plan))
   }

   protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
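From the metastore suite above: createExternalTable now hands its result back as Dataset<Row> on the Java side. A Scala sketch, assuming the SQLContext overload that takes a Scala Map (an assumption about this vintage of the API, not something this commit changes):

    import org.apache.spark.sql.{Dataset, Row, SQLContext}

    // Hedged sketch: register a metastore-backed external table and get
    // the loaded data back as Dataset[Row].
    def loadExternal(sqlContext: SQLContext, options: Map[String, String]): Dataset[Row] =
      sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options)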
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 45634a4..d5a4295 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -128,6 +128,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
   import testImplicits._

   override def beforeAll(): Unit = {
+    super.beforeAll()
     val data1 = Seq[(Integer, Integer)](
       (1, 10),
       (null, -60),