Repository: spark
Updated Branches:
  refs/heads/master f99e34e8e -> 815f3eece


[SPARK-16633][SPARK-16642][SPARK-16721][SQL] Fixes three issues related to lead 
and lag functions

## What changes were proposed in this pull request?
This PR contains three changes.

First, this PR restores the Spark 1.6 behavior of lead/lag, 
which is as follows:
1. lead/lag respect null input values, which means that if the offset row 
exists and the input value is null, the result will be null instead of the 
default value.
2. If the offset row does not exist, the default value will be used.
3. OffsetWindowFunction's nullable setting also considers the nullability of 
its input (because, per item 1, a null input value is now returned as-is).

Second, this PR fixes the evaluation of lead/lag when the input expression is a 
literal. This fix is a result of the first change. In current master, if a 
literal is used as the input expression of a lead or lag function, the result 
will be this literal even if the offset row does not exist.

Third, this PR makes the ResolveWindowFrame rule not fire while a window 
function is still unresolved.

## How was this patch tested?
New tests in SQLWindowFunctionSuite

Author: Yin Huai <yh...@databricks.com>

Closes #14284 from yhuai/lead-lag.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/815f3eec
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/815f3eec
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/815f3eec

Branch: refs/heads/master
Commit: 815f3eece5f095919a329af8cbd762b9ed71c7a8
Parents: f99e34e
Author: Yin Huai <yh...@databricks.com>
Authored: Mon Jul 25 20:58:07 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Mon Jul 25 20:58:07 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   3 +-
 .../expressions/windowExpressions.scala         |  45 +-
 .../apache/spark/sql/execution/WindowExec.scala |  34 +-
 .../sql/execution/SQLWindowFunctionSuite.scala  | 414 +++++++++++++++++++
 .../hive/execution/SQLWindowFunctionSuite.scala | 370 -----------------
 5 files changed, 467 insertions(+), 399 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/815f3eec/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d1d2c59..61162cc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1787,7 +1787,8 @@ class Analyzer(
         s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
           if wf.frame != UnspecifiedFrame =>
           WindowExpression(wf, s.copy(frameSpecification = wf.frame))
-        case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, 
UnspecifiedFrame)) =>
+        case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, 
UnspecifiedFrame))
+          if e.resolved =>
           val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, 
acceptWindowFrame = true)
           we.copy(windowSpec = s.copy(frameSpecification = frame))
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/815f3eec/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index e35192c..6806591 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -321,8 +321,7 @@ abstract class OffsetWindowFunction
   val input: Expression
 
   /**
-   * Default result value for the function when the input expression returns 
NULL. The default will
-   * evaluated against the current row instead of the offset row.
+   * Default result value for the function when the 'offset'th row does not 
exist.
    */
   val default: Expression
 
@@ -348,7 +347,7 @@ abstract class OffsetWindowFunction
    */
   override def foldable: Boolean = false
 
-  override def nullable: Boolean = default == null || default.nullable
+  override def nullable: Boolean = default == null || default.nullable || 
input.nullable
 
   override lazy val frame = {
     // This will be triggered by the Analyzer.
@@ -373,20 +372,22 @@ abstract class OffsetWindowFunction
 }
 
 /**
- * The Lead function returns the value of 'x' at 'offset' rows after the 
current row in the window.
- * Offsets start at 0, which is the current row. The offset must be constant 
integer value. The
- * default offset is 1. When the value of 'x' is null at the offset, or when 
the offset is larger
- * than the window, the default expression is evaluated.
- *
- * This documentation has been based upon similar documentation for the Hive 
and Presto projects.
+ * The Lead function returns the value of 'x' at the 'offset'th row after the 
current row in
+ * the window. Offsets start at 0, which is the current row. The offset must 
be a constant
+ * integer value. The default offset is 1. When the value of 'x' is null at 
the 'offset'th row,
+ * null is returned. If there is no such offset row, the default expression is 
evaluated.
  *
  * @param input expression to evaluate 'offset' rows after the current row.
  * @param offset rows to jump ahead in the partition.
- * @param default to use when the input value is null or when the offset is 
larger than the window.
+ * @param default to use when the offset is larger than the window. The 
default value is null.
  */
 @ExpressionDescription(usage =
-  """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 
'offset' rows
-     after the current row in the window""")
+  """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at the 
'offset'th row
+     after the current row in the window.
+     The default value of 'offset' is 1 and the default value of 'default' is 
null.
+     If the value of 'x' at the 'offset'th row is null, null is returned.
+     If there is no such offset row (e.g. when the offset is 1, the last row 
of the window
+     does not have any subsequent row), 'default' is returned.""")
 case class Lead(input: Expression, offset: Expression, default: Expression)
     extends OffsetWindowFunction {
 
@@ -400,20 +401,22 @@ case class Lead(input: Expression, offset: Expression, 
default: Expression)
 }
 
 /**
- * The Lag function returns the value of 'x' at 'offset' rows before the 
current row in the window.
- * Offsets start at 0, which is the current row. The offset must be constant 
integer value. The
- * default offset is 1. When the value of 'x' is null at the offset, or when 
the offset is smaller
- * than the window, the default expression is evaluated.
- *
- * This documentation has been based upon similar documentation for the Hive 
and Presto projects.
+ * The Lag function returns the value of 'x' at the 'offset'th row before the 
current row in
+ * the window. Offsets start at 0, which is the current row. The offset must 
be a constant
+ * integer value. The default offset is 1. When the value of 'x' is null at 
the 'offset'th row,
+ * null is returned. If there is no such offset row, the default expression is 
evaluated.
  *
  * @param input expression to evaluate 'offset' rows before the current row.
  * @param offset rows to jump back in the partition.
- * @param default to use when the input value is null or when the offset is 
smaller than the window.
+ * @param default to use when the offset row does not exist.
  */
 @ExpressionDescription(usage =
-  """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' 
rows
-     before the current row in the window""")
+  """_FUNC_(input, offset, default) - LAG returns the value of 'x' at the 
'offset'th row
+     before the current row in the window.
+     The default value of 'offset' is 1 and the default value of 'default' is 
null.
+     If the value of 'x' at the 'offset'th row is null, null is returned.
+     If there is no such offset row (e.g. when the offset is 1, the first row 
of the window
+     does not have any previous row), 'default' is returned.""")
 case class Lag(input: Expression, offset: Expression, default: Expression)
     extends OffsetWindowFunction {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/815f3eec/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 93f007f..7149603 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -582,25 +582,43 @@ private[execution] final class OffsetWindowFunctionFrame(
   /** Row used to combine the offset and the current row. */
   private[this] val join = new JoinedRow
 
-  /** Create the projection. */
+  /**
+   * Create the projection used when the offset row exists.
+   * Please note that this projection always respects null input values (like 
PostgreSQL).
+   */
   private[this] val projection = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
-    val numInputAttributes = inputAttrs.size
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
       case e: OffsetWindowFunction =>
         val input = BindReferences.bindReference(e.input, inputAttrs)
+        input
+      case e =>
+        BindReferences.bindReference(e, inputAttrs)
+    }
+
+    // Create the projection.
+    newMutableProjection(boundExpressions, Nil).target(target)
+  }
+
+  /** Create the projection used when the offset row DOES NOT exist. */
+  private[this] val fillDefaultValue = {
+    // Collect the expressions and bind them.
+    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val numInputAttributes = inputAttrs.size
+    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
+      case e: OffsetWindowFunction =>
         if (e.default == null || e.default.foldable && e.default.eval() == 
null) {
-          // Without default value.
-          input
+          // The default value is null.
+          Literal.create(null, e.dataType)
         } else {
-          // With default value.
+          // The default value is an expression.
           val default = BindReferences.bindReference(e.default, 
inputAttrs).transform {
             // Shift the input reference to its default version.
             case BoundReference(o, dataType, nullable) =>
               BoundReference(o + numInputAttributes, dataType, nullable)
           }
-          org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default 
:: Nil)
+          default
         }
       case e =>
         BindReferences.bindReference(e, inputAttrs)
@@ -625,10 +643,12 @@ private[execution] final class OffsetWindowFunctionFrame(
     if (inputIndex >= 0 && inputIndex < input.size) {
       val r = input.next()
       join(r, current)
+      projection(join)
     } else {
       join(emptyRow, current)
+      // Use default values since the offset row does not exist.
+      fillDefaultValue(join)
     }
-    projection(join)
     inputIndex += 1
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/815f3eec/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
new file mode 100644
index 0000000..d3cfa95
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.test.SharedSQLContext
+
+case class WindowData(month: Int, area: String, product: Int)
+
+
+/**
+ * Test suite for SQL window functions.
+ */
+class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("window function: udaf with aggregate expression") {
+    val data = Seq(
+      WindowData(1, "a", 5),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 10)
+    )
+    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product), sum(sum(product)) over (partition by 
area)
+          |from windowData group by month, area
+        """.stripMargin),
+      Seq(
+        ("a", 5, 11),
+        ("a", 6, 11),
+        ("b", 7, 15),
+        ("b", 8, 15),
+        ("c", 9, 19),
+        ("c", 10, 19)
+      ).map(i => Row(i._1, i._2, i._3)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product) - 1, sum(sum(product)) over (partition by 
area)
+          |from windowData group by month, area
+        """.stripMargin),
+      Seq(
+        ("a", 4, 11),
+        ("a", 5, 11),
+        ("b", 6, 15),
+        ("b", 7, 15),
+        ("c", 8, 19),
+        ("c", 9, 19)
+      ).map(i => Row(i._1, i._2, i._3)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product), sum(product) / sum(sum(product)) over 
(partition by area)
+          |from windowData group by month, area
+        """.stripMargin),
+      Seq(
+        ("a", 5, 5d/11),
+        ("a", 6, 6d/11),
+        ("b", 7, 7d/15),
+        ("b", 8, 8d/15),
+        ("c", 10, 10d/19),
+        ("c", 9, 9d/19)
+      ).map(i => Row(i._1, i._2, i._3)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product), sum(product) / sum(sum(product) - 1) 
over (partition by area)
+          |from windowData group by month, area
+        """.stripMargin),
+      Seq(
+        ("a", 5, 5d/9),
+        ("a", 6, 6d/9),
+        ("b", 7, 7d/13),
+        ("b", 8, 8d/13),
+        ("c", 10, 10d/17),
+        ("c", 9, 9d/17)
+      ).map(i => Row(i._1, i._2, i._3)))
+  }
+
+  test("window function: refer column in inner select block") {
+    val data = Seq(
+      WindowData(1, "a", 5),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 10)
+    )
+    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
+
+    checkAnswer(
+      sql(
+        """
+          |select area, rank() over (partition by area order by tmp.month) + 
tmp.tmp1 as c1
+          |from (select month, area, product, 1 as tmp1 from windowData) tmp
+        """.stripMargin),
+      Seq(
+        ("a", 2),
+        ("a", 3),
+        ("b", 2),
+        ("b", 3),
+        ("c", 2),
+        ("c", 3)
+      ).map(i => Row(i._1, i._2)))
+  }
+
+  test("window function: partition and order expressions") {
+    val data = Seq(
+      WindowData(1, "a", 5),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 10)
+    )
+    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
+
+    checkAnswer(
+      sql(
+        """
+          |select month, area, product, sum(product + 1) over (partition by 1 
order by 2)
+          |from windowData
+        """.stripMargin),
+      Seq(
+        (1, "a", 5, 51),
+        (2, "a", 6, 51),
+        (3, "b", 7, 51),
+        (4, "b", 8, 51),
+        (5, "c", 9, 51),
+        (6, "c", 10, 51)
+      ).map(i => Row(i._1, i._2, i._3, i._4)))
+
+    checkAnswer(
+      sql(
+        """
+          |select month, area, product, sum(product)
+          |over (partition by month % 2 order by 10 - product)
+          |from windowData
+        """.stripMargin),
+      Seq(
+        (1, "a", 5, 21),
+        (2, "a", 6, 24),
+        (3, "b", 7, 16),
+        (4, "b", 8, 18),
+        (5, "c", 9, 9),
+        (6, "c", 10, 10)
+      ).map(i => Row(i._1, i._2, i._3, i._4)))
+  }
+
+  test("window function: distinct should not be silently ignored") {
+    val data = Seq(
+      WindowData(1, "a", 5),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 10)
+    )
+    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
+
+    val e = intercept[AnalysisException] {
+      sql(
+        """
+          |select month, area, product, sum(distinct product + 1) over 
(partition by 1 order by 2)
+          |from windowData
+        """.stripMargin)
+    }
+    assert(e.getMessage.contains("Distinct window functions are not 
supported"))
+  }
+
+  test("window function: expressions in arguments of a window functions") {
+    val data = Seq(
+      WindowData(1, "a", 5),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 10)
+    )
+    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
+
+    checkAnswer(
+      sql(
+        """
+          |select month, area, month % 2,
+          |lag(product, 1 + 1, product) over (partition by month % 2 order by 
area)
+          |from windowData
+        """.stripMargin),
+      Seq(
+        (1, "a", 1, 5),
+        (2, "a", 0, 6),
+        (3, "b", 1, 7),
+        (4, "b", 0, 8),
+        (5, "c", 1, 5),
+        (6, "c", 0, 6)
+      ).map(i => Row(i._1, i._2, i._3, i._4)))
+  }
+
+
+  test("window function: Sorting columns are not in Project") {
+    val data = Seq(
+      WindowData(1, "d", 10),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 11)
+    )
+    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
+
+    checkAnswer(
+      sql("select month, product, sum(product + 1) over() from windowData 
order by area"),
+      Seq(
+        (2, 6, 57),
+        (3, 7, 57),
+        (4, 8, 57),
+        (5, 9, 57),
+        (6, 11, 57),
+        (1, 10, 57)
+      ).map(i => Row(i._1, i._2, i._3)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, rank() over (partition by area order by tmp.month) + 
tmp.tmp1 as c1
+          |from (select month, area, product as p, 1 as tmp1 from windowData) 
tmp order by p
+        """.stripMargin),
+      Seq(
+        ("a", 2),
+        ("b", 2),
+        ("b", 3),
+        ("c", 2),
+        ("d", 2),
+        ("c", 3)
+      ).map(i => Row(i._1, i._2)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, rank() over (partition by area order by month) as c1
+          |from windowData group by product, area, month order by product, area
+        """.stripMargin),
+      Seq(
+        ("a", 1),
+        ("b", 1),
+        ("b", 2),
+        ("c", 1),
+        ("d", 1),
+        ("c", 2)
+      ).map(i => Row(i._1, i._2)))
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product) / sum(sum(product)) over (partition by 
area) as c1
+          |from windowData group by area, month order by month, c1
+        """.stripMargin),
+      Seq(
+        ("d", 1.0),
+        ("a", 1.0),
+        ("b", 0.4666666666666667),
+        ("b", 0.5333333333333333),
+        ("c", 0.45),
+        ("c", 0.55)
+      ).map(i => Row(i._1, i._2)))
+  }
+
+  // todo: fix this test case by reimplementing the function 
ResolveAggregateFunctions
+  ignore("window function: Pushing aggregate Expressions in Sort to 
Aggregate") {
+    val data = Seq(
+      WindowData(1, "d", 10),
+      WindowData(2, "a", 6),
+      WindowData(3, "b", 7),
+      WindowData(4, "b", 8),
+      WindowData(5, "c", 9),
+      WindowData(6, "c", 11)
+    )
+    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
+
+    checkAnswer(
+      sql(
+        """
+          |select area, sum(product) over () as c from windowData
+          |where product > 3 group by area, product
+          |having avg(month) > 0 order by avg(month), product
+        """.stripMargin),
+      Seq(
+        ("a", 51),
+        ("b", 51),
+        ("b", 51),
+        ("c", 51),
+        ("c", 51),
+        ("d", 51)
+      ).map(i => Row(i._1, i._2)))
+  }
+
+  test("window function: multiple window expressions in a single expression") {
+    val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 
2)).toDF("x", "y")
+    nums.createOrReplaceTempView("nums")
+
+    val expected =
+      Row(1, 1, 1, 55, 1, 57) ::
+        Row(0, 2, 3, 55, 2, 60) ::
+        Row(1, 3, 6, 55, 4, 65) ::
+        Row(0, 4, 10, 55, 6, 71) ::
+        Row(1, 5, 15, 55, 9, 79) ::
+        Row(0, 6, 21, 55, 12, 88) ::
+        Row(1, 7, 28, 55, 16, 99) ::
+        Row(0, 8, 36, 55, 20, 111) ::
+        Row(1, 9, 45, 55, 25, 125) ::
+        Row(0, 10, 55, 55, 30, 140) :: Nil
+
+    val actual = sql(
+      """
+        |SELECT
+        |  y,
+        |  x,
+        |  sum(x) OVER w1 AS running_sum,
+        |  sum(x) OVER w2 AS total_sum,
+        |  sum(x) OVER w3 AS running_sum_per_y,
+        |  ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as 
combined2
+        |FROM nums
+        |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT 
RoW),
+        |       w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND 
UNBOuNDED FoLLOWING),
+        |       w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED 
PRECEDING AND CURRENT ROW)
+      """.stripMargin)
+
+    checkAnswer(actual, expected)
+
+    spark.catalog.dropTempView("nums")
+  }
+
+  test("SPARK-7595: Window will cause resolve failed with self join") {
+    checkAnswer(sql(
+      """
+        |with
+        | v0 as (select 0 as key, 1 as value),
+        | v1 as (select key, count(value) over (partition by key) cnt_val from 
v0),
+        | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key 
= v1_lag.key)
+        | select key, cnt_val from v2 order by key limit 1
+      """.stripMargin), Row(0, 1))
+  }
+
+  test("SPARK-16633: lead/lag should return the default value if the offset 
row does not exist") {
+    checkAnswer(sql(
+      """
+        |SELECT
+        |  lag(123, 100, 321) OVER (ORDER BY id) as lag,
+        |  lead(123, 100, 321) OVER (ORDER BY id) as lead
+        |FROM (SELECT 1 as id) tmp
+      """.stripMargin),
+      Row(321, 321))
+
+    checkAnswer(sql(
+      """
+        |SELECT
+        |  lag(123, 100, a) OVER (ORDER BY id) as lag,
+        |  lead(123, 100, a) OVER (ORDER BY id) as lead
+        |FROM (SELECT 1 as id, 2 as a) tmp
+      """.stripMargin),
+      Row(2, 2))
+  }
+
+  test("lead/lag should respect null values") {
+    checkAnswer(sql(
+      """
+        |SELECT
+        |  b,
+        |  lag(a, 1, 321) OVER (ORDER BY b) as lag,
+        |  lead(a, 1, 321) OVER (ORDER BY b) as lead
+        |FROM (SELECT cast(null as int) as a, 1 as b
+        |      UNION ALL
+        |      select cast(null as int) as id, 2 as b) tmp
+      """.stripMargin),
+      Row(1, 321, null) :: Row(2, null, 321) :: Nil)
+
+    checkAnswer(sql(
+      """
+        |SELECT
+        |  b,
+        |  lag(a, 1, c) OVER (ORDER BY b) as lag,
+        |  lead(a, 1, c) OVER (ORDER BY b) as lead
+        |FROM (SELECT cast(null as int) as a, 1 as b, 3 as c
+        |      UNION ALL
+        |      select cast(null as int) as id, 2 as b, 4 as c) tmp
+      """.stripMargin),
+      Row(1, 3, null) :: Row(2, null, 4) :: Nil)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/815f3eec/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
deleted file mode 100644
index 77e97df..0000000
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
+++ /dev/null
@@ -1,370 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.execution
-
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
-
-
-case class WindowData(month: Int, area: String, product: Int)
-
-
-/**
- * Test suite for SQL window functions.
- */
-class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with 
TestHiveSingleton {
-  import spark.implicits._
-
-  test("window function: udaf with aggregate expression") {
-    val data = Seq(
-      WindowData(1, "a", 5),
-      WindowData(2, "a", 6),
-      WindowData(3, "b", 7),
-      WindowData(4, "b", 8),
-      WindowData(5, "c", 9),
-      WindowData(6, "c", 10)
-    )
-    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
-
-    checkAnswer(
-      sql(
-        """
-          |select area, sum(product), sum(sum(product)) over (partition by 
area)
-          |from windowData group by month, area
-        """.stripMargin),
-      Seq(
-        ("a", 5, 11),
-        ("a", 6, 11),
-        ("b", 7, 15),
-        ("b", 8, 15),
-        ("c", 9, 19),
-        ("c", 10, 19)
-      ).map(i => Row(i._1, i._2, i._3)))
-
-    checkAnswer(
-      sql(
-        """
-          |select area, sum(product) - 1, sum(sum(product)) over (partition by 
area)
-          |from windowData group by month, area
-        """.stripMargin),
-      Seq(
-        ("a", 4, 11),
-        ("a", 5, 11),
-        ("b", 6, 15),
-        ("b", 7, 15),
-        ("c", 8, 19),
-        ("c", 9, 19)
-      ).map(i => Row(i._1, i._2, i._3)))
-
-    checkAnswer(
-      sql(
-        """
-          |select area, sum(product), sum(product) / sum(sum(product)) over 
(partition by area)
-          |from windowData group by month, area
-        """.stripMargin),
-      Seq(
-        ("a", 5, 5d/11),
-        ("a", 6, 6d/11),
-        ("b", 7, 7d/15),
-        ("b", 8, 8d/15),
-        ("c", 10, 10d/19),
-        ("c", 9, 9d/19)
-      ).map(i => Row(i._1, i._2, i._3)))
-
-    checkAnswer(
-      sql(
-        """
-          |select area, sum(product), sum(product) / sum(sum(product) - 1) 
over (partition by area)
-          |from windowData group by month, area
-        """.stripMargin),
-      Seq(
-        ("a", 5, 5d/9),
-        ("a", 6, 6d/9),
-        ("b", 7, 7d/13),
-        ("b", 8, 8d/13),
-        ("c", 10, 10d/17),
-        ("c", 9, 9d/17)
-      ).map(i => Row(i._1, i._2, i._3)))
-  }
-
-  test("window function: refer column in inner select block") {
-    val data = Seq(
-      WindowData(1, "a", 5),
-      WindowData(2, "a", 6),
-      WindowData(3, "b", 7),
-      WindowData(4, "b", 8),
-      WindowData(5, "c", 9),
-      WindowData(6, "c", 10)
-    )
-    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
-
-    checkAnswer(
-      sql(
-        """
-          |select area, rank() over (partition by area order by tmp.month) + 
tmp.tmp1 as c1
-          |from (select month, area, product, 1 as tmp1 from windowData) tmp
-        """.stripMargin),
-      Seq(
-        ("a", 2),
-        ("a", 3),
-        ("b", 2),
-        ("b", 3),
-        ("c", 2),
-        ("c", 3)
-      ).map(i => Row(i._1, i._2)))
-  }
-
-  test("window function: partition and order expressions") {
-    val data = Seq(
-      WindowData(1, "a", 5),
-      WindowData(2, "a", 6),
-      WindowData(3, "b", 7),
-      WindowData(4, "b", 8),
-      WindowData(5, "c", 9),
-      WindowData(6, "c", 10)
-    )
-    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
-
-    checkAnswer(
-      sql(
-        """
-          |select month, area, product, sum(product + 1) over (partition by 1 
order by 2)
-          |from windowData
-        """.stripMargin),
-      Seq(
-        (1, "a", 5, 51),
-        (2, "a", 6, 51),
-        (3, "b", 7, 51),
-        (4, "b", 8, 51),
-        (5, "c", 9, 51),
-        (6, "c", 10, 51)
-      ).map(i => Row(i._1, i._2, i._3, i._4)))
-
-    checkAnswer(
-      sql(
-        """
-          |select month, area, product, sum(product)
-          |over (partition by month % 2 order by 10 - product)
-          |from windowData
-        """.stripMargin),
-      Seq(
-        (1, "a", 5, 21),
-        (2, "a", 6, 24),
-        (3, "b", 7, 16),
-        (4, "b", 8, 18),
-        (5, "c", 9, 9),
-        (6, "c", 10, 10)
-      ).map(i => Row(i._1, i._2, i._3, i._4)))
-  }
-
-  test("window function: distinct should not be silently ignored") {
-    val data = Seq(
-      WindowData(1, "a", 5),
-      WindowData(2, "a", 6),
-      WindowData(3, "b", 7),
-      WindowData(4, "b", 8),
-      WindowData(5, "c", 9),
-      WindowData(6, "c", 10)
-    )
-    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
-
-    val e = intercept[AnalysisException] {
-      sql(
-        """
-          |select month, area, product, sum(distinct product + 1) over 
(partition by 1 order by 2)
-          |from windowData
-        """.stripMargin)
-    }
-    assert(e.getMessage.contains("Distinct window functions are not 
supported"))
-  }
-
-  test("window function: expressions in arguments of a window functions") {
-    val data = Seq(
-      WindowData(1, "a", 5),
-      WindowData(2, "a", 6),
-      WindowData(3, "b", 7),
-      WindowData(4, "b", 8),
-      WindowData(5, "c", 9),
-      WindowData(6, "c", 10)
-    )
-    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
-
-    checkAnswer(
-      sql(
-        """
-          |select month, area, month % 2,
-          |lag(product, 1 + 1, product) over (partition by month % 2 order by area)
-          |from windowData
-        """.stripMargin),
-      Seq(
-        (1, "a", 1, 5),
-        (2, "a", 0, 6),
-        (3, "b", 1, 7),
-        (4, "b", 0, 8),
-        (5, "c", 1, 5),
-        (6, "c", 0, 6)
-      ).map(i => Row(i._1, i._2, i._3, i._4)))
-  }
-
-
-  test("window function: Sorting columns are not in Project") {
-    val data = Seq(
-      WindowData(1, "d", 10),
-      WindowData(2, "a", 6),
-      WindowData(3, "b", 7),
-      WindowData(4, "b", 8),
-      WindowData(5, "c", 9),
-      WindowData(6, "c", 11)
-    )
-    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
-
-    checkAnswer(
-      sql("select month, product, sum(product + 1) over() from windowData order by area"),
-      Seq(
-        (2, 6, 57),
-        (3, 7, 57),
-        (4, 8, 57),
-        (5, 9, 57),
-        (6, 11, 57),
-        (1, 10, 57)
-      ).map(i => Row(i._1, i._2, i._3)))
-
-    checkAnswer(
-      sql(
-        """
-          |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1
-          |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p
-        """.stripMargin),
-      Seq(
-        ("a", 2),
-        ("b", 2),
-        ("b", 3),
-        ("c", 2),
-        ("d", 2),
-        ("c", 3)
-      ).map(i => Row(i._1, i._2)))
-
-    checkAnswer(
-      sql(
-        """
-          |select area, rank() over (partition by area order by month) as c1
-          |from windowData group by product, area, month order by product, area
-        """.stripMargin),
-      Seq(
-        ("a", 1),
-        ("b", 1),
-        ("b", 2),
-        ("c", 1),
-        ("d", 1),
-        ("c", 2)
-      ).map(i => Row(i._1, i._2)))
-
-    checkAnswer(
-      sql(
-        """
-          |select area, sum(product) / sum(sum(product)) over (partition by area) as c1
-          |from windowData group by area, month order by month, c1
-        """.stripMargin),
-      Seq(
-        ("d", 1.0),
-        ("a", 1.0),
-        ("b", 0.4666666666666667),
-        ("b", 0.5333333333333333),
-        ("c", 0.45),
-        ("c", 0.55)
-      ).map(i => Row(i._1, i._2)))
-  }
-
-  // todo: fix this test case by reimplementing the function ResolveAggregateFunctions
-  ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") {
-    val data = Seq(
-      WindowData(1, "d", 10),
-      WindowData(2, "a", 6),
-      WindowData(3, "b", 7),
-      WindowData(4, "b", 8),
-      WindowData(5, "c", 9),
-      WindowData(6, "c", 11)
-    )
-    sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData")
-
-    checkAnswer(
-      sql(
-        """
-          |select area, sum(product) over () as c from windowData
-          |where product > 3 group by area, product
-          |having avg(month) > 0 order by avg(month), product
-        """.stripMargin),
-      Seq(
-        ("a", 51),
-        ("b", 51),
-        ("b", 51),
-        ("c", 51),
-        ("c", 51),
-        ("d", 51)
-      ).map(i => Row(i._1, i._2)))
-  }
-
-  test("window function: multiple window expressions in a single expression") {
-    val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
-    nums.createOrReplaceTempView("nums")
-
-    val expected =
-      Row(1, 1, 1, 55, 1, 57) ::
-        Row(0, 2, 3, 55, 2, 60) ::
-        Row(1, 3, 6, 55, 4, 65) ::
-        Row(0, 4, 10, 55, 6, 71) ::
-        Row(1, 5, 15, 55, 9, 79) ::
-        Row(0, 6, 21, 55, 12, 88) ::
-        Row(1, 7, 28, 55, 16, 99) ::
-        Row(0, 8, 36, 55, 20, 111) ::
-        Row(1, 9, 45, 55, 25, 125) ::
-        Row(0, 10, 55, 55, 30, 140) :: Nil
-
-    val actual = sql(
-      """
-        |SELECT
-        |  y,
-        |  x,
-        |  sum(x) OVER w1 AS running_sum,
-        |  sum(x) OVER w2 AS total_sum,
-        |  sum(x) OVER w3 AS running_sum_per_y,
-        |  ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2
-        |FROM nums
-        |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW),
-        |       w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING),
-        |       w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
-      """.stripMargin)
-
-    checkAnswer(actual, expected)
-
-    spark.catalog.dropTempView("nums")
-  }
-
-  test("SPARK-7595: Window will cause resolve failed with self join") {
-    sql("SELECT * FROM src") // Force loading of src table.
-
-    checkAnswer(sql(
-      """
-        |with
-        | v1 as (select key, count(value) over (partition by key) cnt_val from src),
-        | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key)
-        | select * from v2 order by key limit 1
-      """.stripMargin), Row(0, 3))
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to