Repository: spark
Updated Branches:
  refs/heads/master d681742b2 -> cd47e2337
[SPARK-15814][SQL] Aggregator can return null result

## What changes were proposed in this pull request?

This is similar to the bug fixed in https://github.com/apache/spark/pull/13425: we need to handle a null result object, wrapping the `CreateStruct` in an `If` that performs the null check.

This PR also improves the test framework to compare the objects of a `Dataset[T]` directly, instead of calling `toDF` and comparing rows.

## How was this patch tested?

New test in `DatasetAggregatorSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #13553 from cloud-fan/agg-null.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cd47e233
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cd47e233
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cd47e233

Branch: refs/heads/master
Commit: cd47e233749f42b016264569a214cbf67f45f436
Parents: d681742
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Jun 13 09:58:48 2016 -0700
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Mon Jun 13 09:58:48 2016 -0700

----------------------------------------------------------------------
 .../aggregate/TypedAggregateExpression.scala    |  7 +-
 .../spark/sql/DatasetAggregatorSuite.scala      | 23 ++++-
 .../spark/sql/DatasetPrimitiveSuite.scala       |  6 +-
 .../org/apache/spark/sql/DatasetSuite.scala     | 38 ++++----
 .../scala/org/apache/spark/sql/QueryTest.scala  | 95 +++++++++++++-------
 .../execution/datasources/text/TextSuite.scala  |  4 +-
 .../sql/streaming/FileStreamSinkSuite.scala     |  4 +-
 .../spark/sql/streaming/MemorySinkSuite.scala   |  4 +-
 8 files changed, 117 insertions(+), 64 deletions(-)
----------------------------------------------------------------------
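To make the failure mode concrete before reading the diff: a typed `Aggregator` is allowed to return null from `finish`, and the struct-typed result then has to be null-checked before `CreateStruct` builds the output row. The sketch below mirrors the `AggData`/`NullResultAgg` pair added to `DatasetAggregatorSuite` in this patch; the standalone driver object and the local SparkSession setup are illustrative additions, not part of the commit.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Mirrors the test data type and aggregator added in DatasetAggregatorSuite.
case class AggData(a: Int, b: String)

object NullResultAgg extends Aggregator[AggData, AggData, AggData] {
  override def zero: AggData = AggData(0, "")
  override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, b.b + a.b)
  override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, b1.b + b2.b)
  // Returning null here is legal for a typed Aggregator; the patch wraps
  // CreateStruct in If(IsNull(...), Literal.create(null, dataType), ...) so
  // this value survives serialization instead of corrupting the result row.
  override def finish(reduction: AggData): AggData =
    if (reduction.a % 2 == 0) null else reduction
  override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
  override def outputEncoder: Encoder[AggData] = Encoders.product[AggData]
}

object NullResultAggExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("agg-null").getOrCreate()
    import spark.implicits._

    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
    // With the fix, the group keyed by 2 produces a null result object:
    // (1, AggData(1, "one")) and (2, null)
    ds.groupByKey(_.a).agg(NullResultAgg.toColumn).collect().foreach(println)

    spark.stop()
  }
}
```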

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index ecb56e2..8bdfa48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -127,7 +127,12 @@ case class TypedAggregateExpression(

     dataType match {
       case s: StructType =>
-        ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
+        val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
+        val struct = If(
+          IsNull(objRef),
+          Literal.create(null, dataType),
+          CreateStruct(outputSerializer))
+        ReferenceToExpressions(struct, resultObj :: Nil)
       case _ =>
         assert(outputSerializer.length == 1)
         outputSerializer.head transform {

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index f9b4cd8..f955120 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -115,11 +115,23 @@ object RowAgg extends Aggregator[Row, Int, Int] {
   override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }

+object NullResultAgg extends Aggregator[AggData, AggData, AggData] {
+  override def zero: AggData = AggData(0, "")
+  override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, b.b + a.b)
+  override def finish(reduction: AggData): AggData = {
+    if (reduction.a % 2 == 0) null else reduction
+  }
+  override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, b1.b + b2.b)
+  override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
+  override def outputEncoder: Encoder[AggData] = Encoders.product[AggData]
+}

-class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
+class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

+  private implicit val ordering = Ordering.by((c: AggData) => c.a -> c.b)
+
   test("typed aggregation: TypedAggregator") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

@@ -204,7 +216,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
       (1.5, 2))

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
       ("one", 1), ("two", 1))
   }
@@ -271,4 +283,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       "RowAgg(org.apache.spark.sql.Row)")
     assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
   }
+
+  test("SPARK-15814 Aggregator can return null result") {
+    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+    checkDatasetUnorderly(
+      ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
+      1 -> AggData(1, "one"), 2 -> null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index a634502..6aa3d3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -82,7 +82,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, keys") {
     val ds = Seq(1, 2, 3, 4, 5).toDS()
     val grouped = ds.groupByKey(_ % 2)
-    checkDataset(
+    checkDatasetUnorderly(
       grouped.keys,
       0, 1)
   }
@@ -95,7 +95,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       (name, iter.size)
     }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       ("even", 5), ("odd", 6))
   }
@@ -105,7 +105,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     val grouped = ds.groupByKey(_.length)
     val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       "1", "abc", "3", "xyz", "5", "hello")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4536a73..96d85f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

+  private implicit val ordering = Ordering.by((c: ClassData) => c.a -> c.b)
+
   test("toDS") {
     val data = Seq(("a", 1), ("b", 2), ("c", 3))
     checkDataset(
@@ -95,12 +97,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     }

     assert(ds.repartition(10).rdd.partitions.length == 10)
-    checkDataset(
+    checkDatasetUnorderly(
       ds.repartition(10),
       data: _*)

     assert(ds.coalesce(1).rdd.partitions.length == 1)
-    checkDataset(
+    checkDatasetUnorderly(
       ds.coalesce(1),
       data: _*)
   }
@@ -163,7 +165,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       .map(c => ClassData(c.a, c.b + 1))
       .groupByKey(p => p).count()

-    checkDataset(
+    checkDatasetUnorderly(
       ds,
       (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
   }
@@ -204,7 +206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {

   test("select 2, primitive and class, fields reordered") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkDecoding(
+    checkDataset(
       ds.select(
         expr("_1").as[String],
         expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
@@ -291,7 +293,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, keys") {
     val ds = Seq(("a", 1), ("b", 1)).toDS()
     val grouped = ds.groupByKey(v => (1, v._2))
-    checkDataset(
+    checkDatasetUnorderly(
       grouped.keys,
       (1, 1))
   }
@@ -301,7 +303,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val grouped = ds.groupByKey(v => (v._1, "word"))
     val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
@@ -313,7 +315,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Iterator(g._1, iter.map(_._2).sum.toString)
     }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       "a", "30", "b", "3", "c", "1")
   }
@@ -322,7 +324,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq("abc", "xyz", "hello").toDS()
     val agged = ds.groupByKey(_.length).reduceGroups(_ + _)

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       3 -> "abcxyz", 5 -> "hello")
   }
@@ -340,7 +342,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long]),
       ("a", 30L), ("b", 3L), ("c", 1L))
   }
@@ -348,7 +350,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
       ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
   }
@@ -356,7 +358,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
       ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
   }
@@ -364,7 +366,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: expr, expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(
         sum("_2").as[Long],
         sum($"_2" + 1).as[Long],
@@ -380,7 +382,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
     }

-    checkDataset(
+    checkDatasetUnorderly(
       cogrouped,
       1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
   }
@@ -392,7 +394,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
     }

-    checkDataset(
+    checkDatasetUnorderly(
       cogrouped,
       1 -> "a", 2 -> "bc", 3 -> "d")
   }
@@ -482,8 +484,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(
       ds1.joinWith(ds2, lit(true)),
       ((nullInt, "1"), (nullInt, "1")),
-      ((new java.lang.Integer(22), "2"), (nullInt, "1")),
       ((nullInt, "1"), (new java.lang.Integer(22), "2")),
+      ((new java.lang.Integer(22), "2"), (nullInt, "1")),
       ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
   }

@@ -776,9 +778,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds1 = ds.as("d1")
     val ds2 = ds.as("d2")

-    checkDataset(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
-    checkDataset(ds1.intersect(ds2), 2, 3, 4)
-    checkDataset(ds1.except(ds1))
+    checkDatasetUnorderly(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
+    checkDatasetUnorderly(ds1.intersect(ds2), 2, 3, 4)
+    checkDatasetUnorderly(ds1.except(ds1))
   }

   test("SPARK-15441: Dataset outer join") {

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index acb59d4..742f036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -68,28 +68,62 @@ abstract class QueryTest extends PlanTest {
   /**
    * Evaluates a dataset to make sure that the result of calling collect matches the given
    * expected answer.
-   *  - Special handling is done based on whether the query plan should be expected to return
-   *    the results in sorted order.
-   *  - This function also checks to make sure that the schema for serializing the expected answer
-   *    matches that produced by the dataset (i.e. does manual construction of object match
-   *    the constructed encoder for cases like joins, etc). Note that this means that it will fail
-   *    for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
-   *    which performs a subset of the checks done by this function.
    */
   protected def checkDataset[T](
-      ds: Dataset[T],
+      ds: => Dataset[T],
       expectedAnswer: T*): Unit = {
-    checkAnswer(
-      ds.toDF(),
-      spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq)
+    val result = getResult(ds)

-    checkDecoding(ds, expectedAnswer: _*)
+    if (!compare(result.toSeq, expectedAnswer)) {
+      fail(
+        s"""
+           |Decoded objects do not match expected objects:
+           |expected: $expectedAnswer
+           |actual: ${result.toSeq}
+           |${ds.exprEnc.deserializer.treeString}
+         """.stripMargin)
+    }
   }

-  protected def checkDecoding[T](
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer, after sort.
+   */
+  protected def checkDatasetUnorderly[T : Ordering](
       ds: => Dataset[T],
       expectedAnswer: T*): Unit = {
-    val decoded = try ds.collect().toSet catch {
+    val result = getResult(ds)
+
+    if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) {
+      fail(
+        s"""
+           |Decoded objects do not match expected objects:
+           |expected: $expectedAnswer
+           |actual: ${result.toSeq}
+           |${ds.exprEnc.deserializer.treeString}
+         """.stripMargin)
+    }
+  }
+
+  private def getResult[T](ds: => Dataset[T]): Array[T] = {
+    val analyzedDS = try ds catch {
+      case ae: AnalysisException =>
+        if (ae.plan.isDefined) {
+          fail(
+            s"""
+               |Failed to analyze query: $ae
+               |${ae.plan.get}
+               |
+               |${stackTraceToString(ae)}
+             """.stripMargin)
+        } else {
+          throw ae
+        }
+    }
+    checkJsonFormat(analyzedDS)
+    assertEmptyMissingInput(analyzedDS)
+
+    try ds.collect() catch {
       case e: Exception =>
         fail(
           s"""
@@ -99,24 +133,17 @@ abstract class QueryTest extends PlanTest {
             |${ds.queryExecution}
           """.stripMargin, e)
     }
+  }

-    // Handle the case where the return type is an array
-    val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
-    def normalEquality = decoded == expectedAnswer.toSet
-    def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
-    def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)
-
-    if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
-      val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
-      val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
-
-      val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
-      fail(
-        s"""Decoded objects do not match expected objects:
-           |$comparison
-           |${ds.exprEnc.deserializer.treeString}
-         """.stripMargin)
-    }
+  private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
+    case (null, null) => true
+    case (null, _) => false
+    case (_, null) => false
+    case (a: Array[_], b: Array[_]) =>
+      a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
+    case (a: Iterable[_], b: Iterable[_]) =>
+      a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
+    case (a, b) => a == b
   }

   /**
@@ -143,7 +170,7 @@ abstract class QueryTest extends PlanTest {

     checkJsonFormat(analyzedDF)

-    assertEmptyMissingInput(df)
+    assertEmptyMissingInput(analyzedDF)

     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
@@ -201,10 +228,10 @@ abstract class QueryTest extends PlanTest {
       planWithCaching)
   }

-  private def checkJsonFormat(df: DataFrame): Unit = {
+  private def checkJsonFormat(ds: Dataset[_]): Unit = {
     // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
     // RDD and Data resolution does not break.
-    val logicalPlan = df.queryExecution.analyzed
+    val logicalPlan = ds.queryExecution.analyzed

     // bypass some cases that we can't handle currently.
     logicalPlan.transform {
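The new `checkDatasetUnorderly[T : Ordering]` helper sorts both the collected result and the expected objects before comparing them, which is why `DatasetSuite` and `DatasetAggregatorSuite` above gain a private implicit `Ordering` for their case classes. Below is a minimal usage sketch under the same test harness (`QueryTest` with `SharedSQLContext` from the spark-sql test module); the `Person` data type and `PersonSuite` name are hypothetical, invented for illustration only.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical test data type, declared at top level so an encoder can be derived.
case class Person(name: String, age: Int)

class PersonSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // Supplies the Ordering[Person] required by checkDatasetUnorderly's [T : Ordering] bound.
  private implicit val ordering = Ordering.by((p: Person) => p.name -> p.age)

  test("repartition keeps all rows, regardless of output order") {
    val people = Seq(Person("a", 1), Person("b", 2), Person("c", 3))
    // Row order after repartition(4) is unspecified, so compare as an unordered collection.
    checkDatasetUnorderly(people.toDS().repartition(4), people: _*)
  }
}
```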

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 4ed517c..71d3da9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -132,9 +132,9 @@ class TextSuite extends QueryTest with SharedSQLContext {
       ds1.write.text(s"$path/part=a")
       ds1.write.text(s"$path/part=b")

-      checkDataset(
+      checkAnswer(
         spark.read.format("text").load(path).select($"part"),
-        Row("a"), Row("b"))
+        Row("a") :: Row("b") :: Nil)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 1c73208..bb3063d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -140,7 +140,7 @@ class FileStreamSinkSuite extends StreamTest {
       }

       val outputDf = spark.read.parquet(outputDir).as[Int]
-      checkDataset(outputDf, 1, 2, 3)
+      checkDatasetUnorderly(outputDf, 1, 2, 3)

     } finally {
       if (query != null) {
@@ -191,7 +191,7 @@ class FileStreamSinkSuite extends StreamTest {
     assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value"))

     // Verify the data is correctly read
-    checkDataset(
+    checkDatasetUnorderly(
       outputDf.as[(Int, Int)],
       (1000, 1), (2000, 2), (3000, 3))

http://git-wip-us.apache.org/repos/asf/spark/blob/cd47e233/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
index df76499..9aada0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
@@ -174,13 +174,13 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     input.addData(1, 2, 3)
     query.processAllAvailable()

-    checkDataset(
+    checkDatasetUnorderly(
       spark.table("memStream").as[(Int, Long)],
       (1, 1L), (2, 1L), (3, 1L))

     input.addData(4, 5, 6)
     query.processAllAvailable()

-    checkDataset(
+    checkDatasetUnorderly(
       spark.table("memStream").as[(Int, Long)],
       (1, 1L), (2, 1L), (3, 1L), (4, 1L), (5, 1L), (6, 1L))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org