This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e9821e  [SPARK-33443][SQL] LEAD/LAG should support [ IGNORE NULLS | RESPECT NULLS ]
3e9821e is described below

commit 3e9821edfd636d2bc8be8f9cc5fc87be48bebc79
Author: gengjiaan <gengji...@360.cn>
AuthorDate: Thu Dec 24 08:13:48 2020 +0000

    [SPARK-33443][SQL] LEAD/LAG should support [ IGNORE NULLS | RESPECT NULLS ]
    
    ### What changes were proposed in this pull request?
    Mainstream databases support `[ IGNORE NULLS | RESPECT NULLS ]` for
    `LEAD`/`LAG`/`NTH_VALUE`/`FIRST_VALUE`/`LAST_VALUE`, but the current implementation
    of `LEAD`/`LAG` does not support this syntax; see the usage sketch after the
    reference links below.
    
    **Oracle**
    https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/LEAD.html#GUID-0A0481F1-E98F-4535-A739-FCCA8D1B5B77
    
    **Presto**
    https://prestodb.io/docs/current/functions/window.html
    
    **Redshift**
    https://docs.aws.amazon.com/redshift/latest/dg/r_WF_LEAD.html
    
    **DB2**
    https://www.ibm.com/support/knowledgecenter/SSGU8G_14.1.0/com.ibm.sqls.doc/ids_sqs_1513.htm
    
    **Teradata**
    https://docs.teradata.com/r/756LNiPSFdY~4JcCCcR5Cw/GjCT6l7trjkIEjt~7Dhx4w
    
    **Snowflake**
    https://docs.snowflake.com/en/sql-reference/functions/lead.html
    https://docs.snowflake.com/en/sql-reference/functions/lag.html
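    
    For example, the new behavior can be exercised through the DataFrame API overloads added
    in this patch (an illustrative sketch: `df`, `value`, and `order` are placeholder names,
    and `spark.implicits._` is assumed to be in scope for the `$` syntax):
    
    ```scala
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{lag, lead}
    
    val window = Window.orderBy($"order")
    df.select(
      // The new 4-argument overloads: the trailing Boolean requests IGNORE NULLS
      // semantics; false keeps the default RESPECT NULLS behavior.
      lead($"value", 2, null, true).over(window),
      lag($"value", 1, null, true).over(window))
    ```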
    
    ### Why are the changes needed?
    Supporting `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG` is very useful,
    and brings Spark in line with the mainstream databases listed above.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `LEAD`/`LAG` now support `[ IGNORE NULLS | RESPECT NULLS ]`, and the Scala
    `lead`/`lag` functions gain overloads with an `ignoreNulls` parameter.
    
    ### How was this patch tested?
    Jenkins tests, including a new "lead/lag with ignoreNulls" test in
    `DataFrameWindowFunctionsSuite`.
    
    Closes #30387 from beliefer/SPARK-33443.
    
    Lead-authored-by: gengjiaan <gengji...@360.cn>
    Co-authored-by: beliefer <belie...@163.com>
    Co-authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/expressions/windowExpressions.scala   |  16 ++-
 .../sql/execution/window/WindowExecBase.scala      |  29 +++--
 .../sql/execution/window/WindowFunctionFrame.scala | 132 ++++++++++++++++++---
 .../scala/org/apache/spark/sql/functions.scala     |  40 ++++++-
 .../spark/sql/DataFrameWindowFunctionsSuite.scala  |  55 +++++++++
 5 files changed, 238 insertions(+), 34 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 43ecbd6..b167499 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -387,8 +387,6 @@ abstract class FrameLessOffsetWindowFunction
 
   override def nullable: Boolean = default == null || default.nullable || input.nullable
 
-  override val ignoreNulls = false
-
   override lazy val frame: WindowFrame = fakeFrame
 
   override def checkInputDataTypes(): TypeCheckResult = {
@@ -443,9 +441,13 @@ abstract class FrameLessOffsetWindowFunction
   since = "2.0.0",
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
-case class Lead(input: Expression, offset: Expression, default: Expression)
+case class Lead(
+    input: Expression, offset: Expression, default: Expression, ignoreNulls: Boolean)
     extends FrameLessOffsetWindowFunction {
 
+  def this(input: Expression, offset: Expression, default: Expression) =
+    this(input, offset, default, false)
+
   def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
 
   def this(input: Expression) = this(input, Literal(1))
@@ -485,10 +487,14 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
   since = "2.0.0",
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
-case class Lag(input: Expression, inputOffset: Expression, default: Expression)
+case class Lag(
+    input: Expression, inputOffset: Expression, default: Expression, ignoreNulls: Boolean)
     extends FrameLessOffsetWindowFunction {
 
-  def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
+  def this(input: Expression, inputOffset: Expression, default: Expression) =
+    this(input, inputOffset, default, false)
+
+  def this(input: Expression, inputOffset: Expression) = this(input, inputOffset, Literal(null))
 
   def this(input: Expression) = this(input, Literal(1))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
index 9832e5c..5d999cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -119,13 +119,21 @@ trait WindowExecBase extends UnaryExecNode {
    * [[WindowExpression]]s and factory function for the [[WindowFrameFunction]].
    */
   protected lazy val windowFrameExpressionFactoryPairs = {
-    type FrameKey = (String, FrameType, Expression, Expression)
+    type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression])
     type ExpressionBuffer = mutable.Buffer[Expression]
     val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
 
     // Add a function and its function to the map for a given frame.
     def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
-      val key = (tpe, fr.frameType, fr.lower, fr.upper)
+      val key = fn match {
+        // This branch is used for Lead/Lag to support ignoring nulls.
+        // All window frames move in rows. If there are multiple Leads or Lags acting on a row
+        // and operating on different input expressions, they should not be moved uniformly
+        // by row. Therefore, we put these functions in different window frames.
+        case f: FrameLessOffsetWindowFunction if f.ignoreNulls =>
+          (tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized))
+        case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil)
+      }
       val (es, fns) = framedFunctions.getOrElseUpdate(
         key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
       es += e
@@ -183,7 +191,7 @@ trait WindowExecBase extends UnaryExecNode {
         // Create the factory to produce WindowFunctionFrame.
         val factory = key match {
           // Frameless offset Frame
-          case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _) =>
+          case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) =>
             target: InternalRow =>
               new FrameLessOffsetWindowFunctionFrame(
                 target,
@@ -193,8 +201,9 @@ trait WindowExecBase extends UnaryExecNode {
                 child.output,
                 (expressions, schema) =>
                   MutableProjection.create(expressions, schema),
-                offset)
-          case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _) =>
+                offset,
+                expr.nonEmpty)
+          case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, _) =>
             target: InternalRow => {
               new UnboundedOffsetWindowFunctionFrame(
                 target,
@@ -206,7 +215,7 @@ trait WindowExecBase extends UnaryExecNode {
                   MutableProjection.create(expressions, schema),
                 offset)
             }
-          case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _) =>
+          case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, _) =>
             target: InternalRow => {
               new UnboundedPrecedingOffsetWindowFunctionFrame(
                 target,
@@ -220,13 +229,13 @@ trait WindowExecBase extends UnaryExecNode {
             }
 
           // Entire Partition Frame.
-          case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
+          case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) =>
             target: InternalRow => {
               new UnboundedWindowFunctionFrame(target, processor)
             }
 
           // Growing Frame.
-          case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
+          case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) =>
             target: InternalRow => {
               new UnboundedPrecedingWindowFunctionFrame(
                 target,
@@ -235,7 +244,7 @@ trait WindowExecBase extends UnaryExecNode {
             }
 
           // Shrinking Frame.
-          case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
+          case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) =>
             target: InternalRow => {
               new UnboundedFollowingWindowFunctionFrame(
                 target,
@@ -244,7 +253,7 @@ trait WindowExecBase extends UnaryExecNode {
             }
 
           // Moving Frame.
-          case ("AGGREGATE", frameType, lower, upper) =>
+          case ("AGGREGATE", frameType, lower, upper, _) =>
             target: InternalRow => {
               new SlidingWindowFunctionFrame(
                 target,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 2a4b957..0408deb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -97,13 +97,15 @@ abstract class OffsetWindowFunctionFrameBase(
   /** Index of the input row currently used for output. */
   protected var inputIndex = 0
 
+  /** Input attributes (with nullability set to true) used to bind the projections below. */
+  protected val inputAttrs = inputSchema.map(_.withNullability(true))
+
   /**
    * Create the projection used when the offset row exists.
    * Please note that this project always respect null input values (like PostgreSQL).
    */
   protected val projection = {
     // Collect the expressions and bind them.
-    val inputAttrs = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
       expressions.toSeq.map(_.input), inputAttrs)
 
@@ -114,7 +116,6 @@ abstract class OffsetWindowFunctionFrameBase(
   /** Create the projection used when the offset row DOES NOT exists. */
   protected val fillDefaultValue = {
     // Collect the expressions and bind them.
-    val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
       if (e.default == null || e.default.foldable && e.default.eval() == null) {
         // The default value is null.
@@ -147,31 +148,132 @@ class FrameLessOffsetWindowFunctionFrame(
     expressions: Array[OffsetWindowFunction],
     inputSchema: Seq[Attribute],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
-    offset: Int)
+    offset: Int,
+    ignoreNulls: Boolean = false)
   extends OffsetWindowFunctionFrameBase(
     target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
 
+  /** Holds the next UnsafeRow whose input value for the function is not null. */
+  private var nextSelectedRow = EmptyRow
+
+  // The number of non-null rows skipped so far while locating the next UnsafeRow
+  // whose input value for the function is not null.
+  private var skippedNonNullCount = 0
+
+  /** Create the projection to determine whether input is null. */
+  private val project = UnsafeProjection.create(Seq(IsNull(expressions.head.input)), inputSchema)
+
+  /** Check whether the input of the first function is null for the given row. */
+  private def nullCheck(row: InternalRow): Boolean = project(row).getBoolean(0)
+
+  /** Find the offset row whose input is not null. */
+  private def findNextRowWithNonNullInput(): Unit = {
+    while (skippedNonNullCount < offset && inputIndex < input.length) {
+      val r = WindowFunctionFrame.getNextOrNull(inputIterator)
+      if (!nullCheck(r)) {
+        nextSelectedRow = r
+        skippedNonNullCount += 1
+      }
+      inputIndex += 1
+    }
+  }
+
   override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
     inputIterator = input.generateIterator()
     // drain the first few rows if offset is larger than zero
     inputIndex = 0
-    while (inputIndex < offset) {
-      if (inputIterator.hasNext) inputIterator.next()
-      inputIndex += 1
+    if (ignoreNulls) {
+      findNextRowWithNonNullInput()
+    } else {
+      while (inputIndex < offset) {
+        if (inputIterator.hasNext) inputIterator.next()
+        inputIndex += 1
+      }
+      inputIndex = offset
     }
-    inputIndex = offset
+  }
+
+  private val doWrite = if (ignoreNulls && offset > 0) {
+    // For illustration, here is one example: the input data contains nine rows,
+    // and the input values of each row are: null, x, null, null, y, null, z, v, null.
+    // We use lead(input, 2) with IGNORE NULLS and the process is as follows:
+    // 1. current row -> null, next selected row -> y, output: y;
+    // 2. current row -> x, next selected row -> z, output: z;
+    // 3. current row -> null, next selected row -> z, output: z;
+    // 4. current row -> null, next selected row -> z, output: z;
+    // 5. current row -> y, next selected row -> v, output: v;
+    // 6. current row -> null, next selected row -> v, output: v;
+    // 7. current row -> z, next selected row -> empty, output: null;
+    // ... next selected row is empty, all following return null.
+    (current: InternalRow) =>
+      if (nextSelectedRow == EmptyRow) {
+        // Use default values since the offset row whose input value is not null does not exist.
+        fillDefaultValue(current)
+      } else {
+        if (nullCheck(current)) {
+          projection(nextSelectedRow)
+        } else {
+          skippedNonNullCount -= 1
+          findNextRowWithNonNullInput()
+          if (skippedNonNullCount == offset) {
+            projection(nextSelectedRow)
+          } else {
+            // Use default values since the offset row whose input value is not null does not exist.
+            fillDefaultValue(current)
+            nextSelectedRow = EmptyRow
+          }
+        }
+      }
+  } else if (ignoreNulls && offset < 0) {
+    // For illustration, here is one example: the input data contains nine rows,
+    // and the input values of each row are: null, x, null, null, y, null, z, v, null.
+    // We use lag(input, 1) with IGNORE NULLS and the process is as follows:
+    // 1. current row -> null, next selected row -> empty, output: null;
+    // 2. current row -> x, next selected row -> empty, output: null;
+    // 3. current row -> null, next selected row -> x, output: x;
+    // 4. current row -> null, next selected row -> x, output: x;
+    // 5. current row -> y, next selected row -> x, output: x;
+    // 6. current row -> null, next selected row -> y, output: y;
+    // 7. current row -> z, next selected row -> y, output: y;
+    // 8. current row -> v, next selected row -> z, output: z;
+    // 9. current row -> null, next selected row -> v, output: v;
+    val absOffset = Math.abs(offset)
+    (current: InternalRow) =>
+      if (skippedNonNullCount == absOffset) {
+        nextSelectedRow = EmptyRow
+        skippedNonNullCount -= 1
+        while (nextSelectedRow == EmptyRow && inputIndex < input.length) {
+          val r = WindowFunctionFrame.getNextOrNull(inputIterator)
+          if (!nullCheck(r)) {
+            nextSelectedRow = r
+          }
+          inputIndex += 1
+        }
+      }
+      if (nextSelectedRow == EmptyRow) {
+        // Use default values since the offset row whose input value is not null does not exist.
+        fillDefaultValue(current)
+      } else {
+        projection(nextSelectedRow)
+      }
+      if (!nullCheck(current)) {
+        skippedNonNullCount += 1
+      }
+  } else {
+    (current: InternalRow) =>
+      if (inputIndex >= 0 && inputIndex < input.length) {
+        val r = WindowFunctionFrame.getNextOrNull(inputIterator)
+        projection(r)
+      } else {
+        // Use default values since the offset row does not exist.
+        fillDefaultValue(current)
+      }
+      inputIndex += 1
   }
 
   override def write(index: Int, current: InternalRow): Unit = {
-    if (inputIndex >= 0 && inputIndex < input.length) {
-      val r = WindowFunctionFrame.getNextOrNull(inputIterator)
-      projection(r)
-    } else {
-      // Use default values since the offset row does not exist.
-      fillDefaultValue(current)
-    }
-    inputIndex += 1
+    doWrite(current)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4defcb8..764e088 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -937,8 +937,24 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
-    Lag(e.expr, Literal(offset), Literal(defaultValue))
+  def lag(e: Column, offset: Int, defaultValue: Any): Column = {
+    lag(e, offset, defaultValue, false)
+  }
+
+  /**
+   * Window function: returns the value that is `offset` rows before the current row, and
+   * `defaultValue` if there are fewer than `offset` rows before the current row. `ignoreNulls`
+   * determines whether null values are included in or eliminated from the calculation.
+   * For example, an `offset` of one will return the previous row at any given point in the
+   * window partition.
+   *
+   * This is equivalent to the LAG function in SQL.
+   *
+   * @group window_funcs
+   * @since 3.2.0
+   */
+  def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
+    Lag(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
   }
 
   /**
@@ -989,8 +1005,24 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
-    Lead(e.expr, Literal(offset), Literal(defaultValue))
+  def lead(e: Column, offset: Int, defaultValue: Any): Column = {
+    lead(e, offset, defaultValue, false)
+  }
+
+  /**
+   * Window function: returns the value that is `offset` rows after the current row, and
+   * `defaultValue` if there are fewer than `offset` rows after the current row. `ignoreNulls`
+   * determines whether null values are included in or eliminated from the calculation.
+   * The default value of `ignoreNulls` is false. For example, an `offset` of one will return
+   * the next row at any given point in the window partition.
+   *
+   * This is equivalent to the LEAD function in SQL.
+   *
+   * @group window_funcs
+   * @since 3.2.0
+   */
+  def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
+    Lead(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 207b296..3568ad3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -700,6 +700,61 @@ class DataFrameWindowFunctionsSuite extends QueryTest
         Row("b", 3, null, null, null)))
   }
 
+  test("lead/lag with ignoreNulls") {
+    val nullStr: String = null
+    val df = Seq(
+      ("a", 0, nullStr),
+      ("a", 1, "x"),
+      ("b", 2, nullStr),
+      ("c", 3, nullStr),
+      ("a", 4, "y"),
+      ("b", 5, nullStr),
+      ("a", 6, "z"),
+      ("a", 7, "v"),
+      ("a", 8, nullStr)).
+      toDF("key", "order", "value")
+    val window = Window.orderBy($"order")
+    checkAnswer(
+      df.select(
+        $"key",
+        $"order",
+        $"value",
+        lead($"value", 1).over(window),
+        lead($"value", 2).over(window),
+        lead($"value", 0, null, true).over(window),
+        lead($"value", 1, null, true).over(window),
+        lead($"value", 2, null, true).over(window),
+        lead($"value", 3, null, true).over(window),
+        lead(concat($"value", $"key"), 1, null, true).over(window),
+        lag($"value", 1).over(window),
+        lag($"value", 2).over(window),
+        lag($"value", 0, null, true).over(window),
+        lag($"value", 1, null, true).over(window),
+        lag($"value", 2, null, true).over(window),
+        lag($"value", 3, null, true).over(window),
+        lag(concat($"value", $"key"), 1, null, true).over(window))
+        .orderBy($"order"),
+      Seq(
+        Row("a", 0, null, "x", null, null, "x", "y", "z", "xa",
+          null, null, null, null, null, null, null),
+        Row("a", 1, "x", null, null, "x", "y", "z", "v", "ya",
+          null, null, "x", null, null, null, null),
+        Row("b", 2, null, null, "y", null, "y", "z", "v", "ya",
+          "x", null, null, "x", null, null, "xa"),
+        Row("c", 3, null, "y", null, null, "y", "z", "v", "ya",
+          null, "x", null, "x", null, null, "xa"),
+        Row("a", 4, "y", null, "z", "y", "z", "v", null, "za",
+          null, null, "y", "x", null, null, "xa"),
+        Row("b", 5, null, "z", "v", null, "z", "v", null, "za",
+          "y", null, null, "y", "x", null, "ya"),
+        Row("a", 6, "z", "v", null, "z", "v", null, null, "va",
+          null, "y", "z", "y", "x", null, "ya"),
+        Row("a", 7, "v", null, null, "v", null, null, null, null,
+          "z", null, "v", "z", "y", "x", "za"),
+        Row("a", 8, null, null, null, null, null, null, null, null,
+          "v", "z", null, "v", "z", "y", "va")))
+  }
+
   test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
     val src = Seq((0, 3, 5)).toDF("a", "b", "c")
       .withColumn("Data", struct("a", "b"))
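
For intuition, the IGNORE NULLS lookup implemented by FrameLessOffsetWindowFunctionFrame above
can be modeled over a plain in-memory sequence (a simplified sketch only: the real frame scans
the partition once with a cursor instead of rescanning per row, and `leadIgnoreNulls` is a
hypothetical helper, not part of the patch):

```scala
// Simplified model of LEAD(input, offset) with IGNORE NULLS, matching the
// nine-row walkthrough in the FrameLessOffsetWindowFunctionFrame comments.
def leadIgnoreNulls[T](values: Seq[Option[T]], offset: Int): Seq[Option[T]] = {
  require(offset > 0, "this sketch covers only the LEAD (positive offset) case")
  values.indices.map { i =>
    // Look past the current row, drop the nulls, and take the offset-th
    // remaining value; None when fewer than `offset` non-null rows follow.
    values.drop(i + 1).flatten.drop(offset - 1).headOption
  }
}

// Input values per row: null, x, null, null, y, null, z, v, null.
val in = Seq(None, Some("x"), None, None, Some("y"), None, Some("z"), Some("v"), None)
leadIgnoreNulls(in, 2)
// -> y, z, z, z, v, v, null, null, null (as in the lead(input, 2) walkthrough)
```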

