Repository: spark
Updated Branches:
  refs/heads/master 221b418b1 -> 5fd54b994


[SPARK-17758][SQL] Last returns wrong result in case of empty partition

## What changes were proposed in this pull request?
The result of the `Last` function can be wrong when the last partition 
processed is empty. It can return `null` instead of the expected value. For 
example, this can happen when we process partitions in the following order:
```
- Partition 1 [Row1, Row2]
- Partition 2 [Row3]
- Partition 3 []
```
In this case the `Last` function currently returns `null` instead of the 
value of `Row3`.

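This PR fixes this by adding a `valueSet` flag to the aggregation buffer of the 
`Last` function. When two buffers are merged, the right-hand value is used only 
if its `valueSet` flag is true; otherwise the left-hand value is kept.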

## How was this patch tested?
We only used end-to-end tests for `DeclarativeAggregate` functions. I have added 
an evaluator for these functions so we can test them in catalyst. I have also added 
a `LastTestSuite` to test the `Last` aggregate function.

Author: Herman van Hovell <hvanhov...@databricks.com>

Closes #15348 from hvanhovell/SPARK-17758.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5fd54b99
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5fd54b99
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5fd54b99

Branch: refs/heads/master
Commit: 5fd54b994e2078dbf0794932b4e0ffa9a9eda0c3
Parents: 221b418
Author: Herman van Hovell <hvanhov...@databricks.com>
Authored: Wed Oct 5 16:05:30 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Wed Oct 5 16:05:30 2016 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Last.scala   |  27 ++---
 .../DeclarativeAggregateEvaluator.scala         |  61 +++++++++++
 .../expressions/aggregate/LastTestSuite.scala   | 109 +++++++++++++++++++
 3 files changed, 184 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5fd54b99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index af88403..8579f72 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -55,34 +55,35 @@ case class Last(child: Expression, ignoreNullsExpr: 
Expression) extends Declarat
 
   private lazy val last = AttributeReference("last", child.dataType)()
 
-  override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
+  private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: 
valueSet :: Nil
 
   override lazy val initialValues: Seq[Literal] = Seq(
-    /* last = */ Literal.create(null, child.dataType)
+    /* last = */ Literal.create(null, child.dataType),
+    /* valueSet = */ Literal.create(false, BooleanType)
   )
 
   override lazy val updateExpressions: Seq[Expression] = {
     if (ignoreNulls) {
       Seq(
-        /* last = */ If(IsNull(child), last, child)
+        /* last = */ If(IsNull(child), last, child),
+        /* valueSet = */ Or(valueSet, IsNotNull(child))
       )
     } else {
       Seq(
-        /* last = */ child
+        /* last = */ child,
+        /* valueSet = */ Literal.create(true, BooleanType)
       )
     }
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
-    if (ignoreNulls) {
-      Seq(
-        /* last = */ If(IsNull(last.right), last.left, last.right)
-      )
-    } else {
-      Seq(
-        /* last = */ last.right
-      )
-    }
+    // Prefer the right hand expression if it has been set.
+    Seq(
+      /* last = */ If(valueSet.right, last.right, last.left),
+      /* valueSet = */ Or(valueSet.right, valueSet.left)
+    )
   }
 
   override lazy val evaluateExpression: AttributeReference = last

http://git-wip-us.apache.org/repos/asf/spark/blob/5fd54b99/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
new file mode 100644
index 0000000..614f24d
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
+
+/**
+ * Evaluator for a [[DeclarativeAggregate]].
+ */
+case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, 
input: Seq[Attribute]) {
+
+  lazy val initializer = 
GenerateSafeProjection.generate(function.initialValues)
+
+  lazy val updater = GenerateSafeProjection.generate(
+    function.updateExpressions,
+    function.aggBufferAttributes ++ input)
+
+  lazy val merger = GenerateSafeProjection.generate(
+    function.mergeExpressions,
+    function.aggBufferAttributes ++ function.inputAggBufferAttributes)
+
+  lazy val evaluator = GenerateSafeProjection.generate(
+    function.evaluateExpression :: Nil,
+    function.aggBufferAttributes)
+
+  def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy()
+
+  def update(values: InternalRow*): InternalRow = {
+    val joiner = new JoinedRow
+    val buffer = values.foldLeft(initialize()) { (buffer, input) =>
+      updater(joiner(buffer, input))
+    }
+    buffer.copy()
+  }
+
+  def merge(buffers: InternalRow*): InternalRow = {
+    val joiner = new JoinedRow
+    val buffer = buffers.foldLeft(initialize()) { (left, right) =>
+      merger(joiner(left, right))
+    }
+    buffer.copy()
+  }
+
+  def eval(buffer: InternalRow): InternalRow = evaluator(buffer).copy()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5fd54b99/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
new file mode 100644
index 0000000..ba36bc0
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
+import org.apache.spark.sql.types.IntegerType
+
+class LastTestSuite extends SparkFunSuite {
+  val input = AttributeReference("input", IntegerType, nullable = true)()
+  val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), 
Seq(input))
+  val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, 
Literal(true)), Seq(input))
+
+  test("empty buffer") {
+    assert(evaluator.initialize() === InternalRow(null, false))
+  }
+
+  test("update") {
+    val result = evaluator.update(
+      InternalRow(1),
+      InternalRow(9),
+      InternalRow(-1))
+    assert(result === InternalRow(-1, true))
+  }
+
+  test("update - ignore nulls") {
+    val result1 = evaluatorIgnoreNulls.update(
+      InternalRow(null),
+      InternalRow(9),
+      InternalRow(null))
+    assert(result1 === InternalRow(9, true))
+
+    val result2 = evaluatorIgnoreNulls.update(
+      InternalRow(null),
+      InternalRow(null))
+    assert(result2 === InternalRow(null, false))
+  }
+
+  test("merge") {
+    // Empty merge
+    val p0 = evaluator.initialize()
+    assert(evaluator.merge(p0) === InternalRow(null, false))
+
+    // Single merge
+    val p1 = evaluator.update(InternalRow(1), InternalRow(-99))
+    assert(evaluator.merge(p1) === p1)
+
+    // Multiple merges.
+    val p2 = evaluator.update(InternalRow(2), InternalRow(10))
+    assert(evaluator.merge(p1, p2) === p2)
+
+    // Empty partitions (p0 is empty)
+    assert(evaluator.merge(p1, p0, p2) === p2)
+    assert(evaluator.merge(p2, p1, p0) === p1)
+  }
+
+  test("merge - ignore nulls") {
+    // Multi merges
+    val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null))
+    val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null))
+    assert(evaluatorIgnoreNulls.merge(p1, p2) === p1)
+  }
+
+  test("eval") {
+    // Null Eval
+    assert(evaluator.eval(InternalRow(null, true)) === InternalRow(null))
+    assert(evaluator.eval(InternalRow(null, false)) === InternalRow(null))
+
+    // Empty Eval
+    val p0 = evaluator.initialize()
+    assert(evaluator.eval(p0) === InternalRow(null))
+
+    // Update - Eval
+    val p1 = evaluator.update(InternalRow(1), InternalRow(-99))
+    assert(evaluator.eval(p1) === InternalRow(-99))
+
+    // Update - Merge - Eval
+    val p2 = evaluator.update(InternalRow(2), InternalRow(10))
+    val m1 = evaluator.merge(p1, p0, p2)
+    assert(evaluator.eval(m1) === InternalRow(10))
+
+    // Update - Merge - Eval (empty partition at the end)
+    val m2 = evaluator.merge(p2, p1, p0)
+    assert(evaluator.eval(m2) === InternalRow(-99))
+  }
+
+  test("eval - ignore nulls") {
+    // Update - Merge - Eval
+    val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null))
+    val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null))
+    val m1 = evaluatorIgnoreNulls.merge(p1, p2)
+    assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to