This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3e9821e  [SPARK-33443][SQL] LEAD/LAG should support [ IGNORE NULLS | RESPECT NULLS ]
3e9821e is described below

commit 3e9821edfd636d2bc8be8f9cc5fc87be48bebc79
Author: gengjiaan <gengji...@360.cn>
AuthorDate: Thu Dec 24 08:13:48 2020 +0000

    [SPARK-33443][SQL] LEAD/LAG should support [ IGNORE NULLS | RESPECT NULLS ]

    ### What changes were proposed in this pull request?
    Mainstream databases support `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG`/`NTH_VALUE`/`FIRST_VALUE`/`LAST_VALUE`, but the current implementation of `LEAD`/`LAG` does not support this syntax.

    **Oracle**
    https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/LEAD.html#GUID-0A0481F1-E98F-4535-A739-FCCA8D1B5B77

    **Presto**
    https://prestodb.io/docs/current/functions/window.html

    **Redshift**
    https://docs.aws.amazon.com/redshift/latest/dg/r_WF_LEAD.html

    **DB2**
    https://www.ibm.com/support/knowledgecenter/SSGU8G_14.1.0/com.ibm.sqls.doc/ids_sqs_1513.htm

    **Teradata**
    https://docs.teradata.com/r/756LNiPSFdY~4JcCCcR5Cw/GjCT6l7trjkIEjt~7Dhx4w

    **Snowflake**
    https://docs.snowflake.com/en/sql-reference/functions/lead.html
    https://docs.snowflake.com/en/sql-reference/functions/lag.html

    ### Why are the changes needed?
    Supporting `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG` is very useful.

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Jenkins test.

    Closes #30387 from beliefer/SPARK-33443.

    Lead-authored-by: gengjiaan <gengji...@360.cn>
    Co-authored-by: beliefer <belie...@163.com>
    Co-authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/expressions/windowExpressions.scala   |  16 ++-
 .../sql/execution/window/WindowExecBase.scala      |  29 +++--
 .../sql/execution/window/WindowFunctionFrame.scala | 132 ++++++++++++++++++---
 .../scala/org/apache/spark/sql/functions.scala     |  40 ++++++-
 .../spark/sql/DataFrameWindowFunctionsSuite.scala  |  55 +++++++++
 5 files changed, 238 insertions(+), 34 deletions(-)
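For reference, the new behavior is exposed through the Scala DataFrame API added in functions.scala below. A minimal sketch of the lead side (data and column names are illustrative; it assumes a SparkSession named `spark`):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.lead
    import spark.implicits._

    val nullStr: String = null
    val df = Seq((0, nullStr), (1, "x"), (2, nullStr), (3, "y")).toDF("order", "value")
    val w = Window.orderBy($"order")

    df.select(
      $"order", $"value",
      // Default behavior (RESPECT NULLS): the literal next row, null or not.
      lead($"value", 1).over(w).as("lead_respect"),
      // New overload with ignoreNulls = true: the next non-null value.
      lead($"value", 1, null, true).over(w).as("lead_ignore")
    ).show()
    // At order = 1, lead_respect is null (row 2 holds null),
    // while lead_ignore skips ahead to "y".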
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 43ecbd6..b167499 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -387,8 +387,6 @@ abstract class FrameLessOffsetWindowFunction
 
   override def nullable: Boolean = default == null || default.nullable || input.nullable
 
-  override val ignoreNulls = false
-
   override lazy val frame: WindowFrame = fakeFrame
 
   override def checkInputDataTypes(): TypeCheckResult = {
@@ -443,9 +441,13 @@ abstract class FrameLessOffsetWindowFunction
   since = "2.0.0",
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
-case class Lead(input: Expression, offset: Expression, default: Expression)
+case class Lead(
+    input: Expression, offset: Expression, default: Expression, ignoreNulls: Boolean)
   extends FrameLessOffsetWindowFunction {
 
+  def this(input: Expression, offset: Expression, default: Expression) =
+    this(input, offset, default, false)
+
   def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
 
   def this(input: Expression) = this(input, Literal(1))
@@ -485,10 +487,14 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
   since = "2.0.0",
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
-case class Lag(input: Expression, inputOffset: Expression, default: Expression)
+case class Lag(
+    input: Expression, inputOffset: Expression, default: Expression, ignoreNulls: Boolean)
   extends FrameLessOffsetWindowFunction {
 
-  def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
+  def this(input: Expression, inputOffset: Expression, default: Expression) =
+    this(input, inputOffset, default, false)
+
+  def this(input: Expression, inputOffset: Expression) = this(input, inputOffset, Literal(null))
 
   def this(input: Expression) = this(input, Literal(1))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
index 9832e5c..5d999cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -119,13 +119,21 @@ trait WindowExecBase extends UnaryExecNode {
    * [[WindowExpression]]s and factory function for the [[WindowFrameFunction]].
    */
   protected lazy val windowFrameExpressionFactoryPairs = {
-    type FrameKey = (String, FrameType, Expression, Expression)
+    type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression])
     type ExpressionBuffer = mutable.Buffer[Expression]
     val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
 
     // Add a function and its function to the map for a given frame.
     def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
-      val key = (tpe, fr.frameType, fr.lower, fr.upper)
+      val key = fn match {
+        // This branch is used for Lead/Lag to support ignoring null.
+        // All window frames move in rows. If there are multiple Leads or Lags acting on a row
+        // and operating on different input expressions, they should not be moved uniformly
+        // by row. Therefore, we put these functions in different window frames.
+        case f: FrameLessOffsetWindowFunction if f.ignoreNulls =>
+          (tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized))
+        case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil)
+      }
       val (es, fns) = framedFunctions.getOrElseUpdate(
         key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
       es += e
@@ -183,7 +191,7 @@ trait WindowExecBase extends UnaryExecNode {
       // Create the factory to produce WindowFunctionFrame.
       val factory = key match {
         // Frameless offset Frame
-        case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _) =>
+        case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) =>
           target: InternalRow =>
             new FrameLessOffsetWindowFunctionFrame(
               target,
@@ -193,8 +201,9 @@ trait WindowExecBase extends UnaryExecNode {
               child.output,
               (expressions, schema) =>
                 MutableProjection.create(expressions, schema),
-              offset)
-        case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _) =>
+              offset,
+              expr.nonEmpty)
+        case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, _) =>
           target: InternalRow => {
             new UnboundedOffsetWindowFunctionFrame(
               target,
@@ -206,7 +215,7 @@ trait WindowExecBase extends UnaryExecNode {
               ordinal,
               // OFFSET frame functions are guaranteed be OffsetWindowFunctions.
               functions.map(_.asInstanceOf[OffsetWindowFunction]),
               child.output,
               (expressions, schema) =>
                 MutableProjection.create(expressions, schema),
               offset)
           }
-        case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _) =>
+        case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, _) =>
           target: InternalRow => {
             new UnboundedPrecedingOffsetWindowFunctionFrame(
               target,
@@ -220,13 +229,13 @@ trait WindowExecBase extends UnaryExecNode {
           }
 
         // Entire Partition Frame.
-        case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
+        case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) =>
           target: InternalRow => {
             new UnboundedWindowFunctionFrame(target, processor)
           }
 
         // Growing Frame.
-        case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
+        case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) =>
           target: InternalRow => {
             new UnboundedPrecedingWindowFunctionFrame(
               target,
@@ -235,7 +244,7 @@ trait WindowExecBase extends UnaryExecNode {
           }
 
         // Shrinking Frame.
-        case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
+        case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) =>
           target: InternalRow => {
             new UnboundedFollowingWindowFunctionFrame(
               target,
@@ -244,7 +253,7 @@ trait WindowExecBase extends UnaryExecNode {
           }
 
         // Moving Frame.
-        case ("AGGREGATE", frameType, lower, upper) =>
+        case ("AGGREGATE", frameType, lower, upper, _) =>
           target: InternalRow => {
             new SlidingWindowFunctionFrame(
               target,
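To make the FrameKey change above concrete, here is a simplified standalone model (plain Scala, not Spark code; the case class mimics the extended key and the strings stand in for canonicalized input expressions):

    // Functions that share a key are evaluated by one shared frame. IGNORE NULLS
    // offset functions fold their inputs into the key, so lead(a) and lead(b) get
    // separate frames and can each advance through the input at their own pace.
    case class FrameKey(tpe: String, lower: Int, upper: Int, children: Seq[String])

    val fns = Seq(
      "lead(a) IGNORE NULLS" -> FrameKey("FRAME_LESS_OFFSET", 1, 1, Seq("a")),
      "lead(b) IGNORE NULLS" -> FrameKey("FRAME_LESS_OFFSET", 1, 1, Seq("b")),
      "lead(a)" -> FrameKey("FRAME_LESS_OFFSET", 1, 1, Nil),
      "lead(b)" -> FrameKey("FRAME_LESS_OFFSET", 1, 1, Nil))

    // Prints one singleton group per IGNORE NULLS call and a single shared group
    // for the two RESPECT NULLS calls.
    fns.groupBy(_._2).foreach { case (key, fs) =>
      println(s"$key -> ${fs.map(_._1).mkString(", ")}")
    }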
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 2a4b957..0408deb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -97,13 +97,15 @@ abstract class OffsetWindowFunctionFrameBase(
   /** Index of the input row currently used for output. */
   protected var inputIndex = 0
 
+  /** Attributes of the input row currently used for output. */
+  protected val inputAttrs = inputSchema.map(_.withNullability(true))
+
   /**
    * Create the projection used when the offset row exists.
    * Please note that this project always respect null input values (like PostgreSQL).
    */
   protected val projection = {
     // Collect the expressions and bind them.
-    val inputAttrs = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
       expressions.toSeq.map(_.input), inputAttrs)
 
@@ -114,7 +116,6 @@ abstract class OffsetWindowFunctionFrameBase(
   /** Create the projection used when the offset row DOES NOT exists. */
   protected val fillDefaultValue = {
     // Collect the expressions and bind them.
-    val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
       if (e.default == null || e.default.foldable && e.default.eval() == null) {
         // The default value is null.
@@ -147,31 +148,132 @@ class FrameLessOffsetWindowFunctionFrame(
     expressions: Array[OffsetWindowFunction],
     inputSchema: Seq[Attribute],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
-    offset: Int)
+    offset: Int,
+    ignoreNulls: Boolean = false)
   extends OffsetWindowFunctionFrameBase(
     target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
 
+  /** Holder the UnsafeRow where the input operator by function is not null. */
+  private var nextSelectedRow = EmptyRow
+
+  // The number of rows skipped to get the next UnsafeRow where the input operator by function
+  // is not null.
+  private var skippedNonNullCount = 0
+
+  /** Create the projection to determine whether input is null. */
+  private val project = UnsafeProjection.create(Seq(IsNull(expressions.head.input)), inputSchema)
+
+  /** Check if the output value of the first index is null. */
+  private def nullCheck(row: InternalRow): Boolean = project(row).getBoolean(0)
+
+  /** find the offset row whose input is not null */
+  private def findNextRowWithNonNullInput(): Unit = {
+    while (skippedNonNullCount < offset && inputIndex < input.length) {
+      val r = WindowFunctionFrame.getNextOrNull(inputIterator)
+      if (!nullCheck(r)) {
+        nextSelectedRow = r
+        skippedNonNullCount += 1
+      }
+      inputIndex += 1
+    }
+  }
+
   override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
     inputIterator = input.generateIterator()
     // drain the first few rows if offset is larger than zero
     inputIndex = 0
-    while (inputIndex < offset) {
-      if (inputIterator.hasNext) inputIterator.next()
-      inputIndex += 1
+    if (ignoreNulls) {
+      findNextRowWithNonNullInput()
+    } else {
+      while (inputIndex < offset) {
+        if (inputIterator.hasNext) inputIterator.next()
+        inputIndex += 1
+      }
+      inputIndex = offset
     }
-    inputIndex = offset
+  }
+
+  private val doWrite = if (ignoreNulls && offset > 0) {
+    // For illustration, here is one example: the input data contains nine rows,
+    // and the input values of each row are: null, x, null, null, y, null, z, v, null.
+    // We use lead(input, 2) with IGNORE NULLS and the process is as follows:
+    // 1. current row -> null, next selected row -> y, output: y;
+    // 2. current row -> x, next selected row -> z, output: z;
+    // 3. current row -> null, next selected row -> z, output: z;
+    // 4. current row -> null, next selected row -> z, output: z;
+    // 5. current row -> y, next selected row -> v, output: v;
+    // 6. current row -> null, next selected row -> v, output: v;
+    // 7. current row -> z, next selected row -> empty, output: null;
+    // ... next selected row is empty, all following return null.
+    (current: InternalRow) =>
+      if (nextSelectedRow == EmptyRow) {
+        // Use default values since the offset row whose input value is not null does not exist.
+        fillDefaultValue(current)
+      } else {
+        if (nullCheck(current)) {
+          projection(nextSelectedRow)
+        } else {
+          skippedNonNullCount -= 1
+          findNextRowWithNonNullInput()
+          if (skippedNonNullCount == offset) {
+            projection(nextSelectedRow)
+          } else {
+            // Use default values since the offset row whose input value is not null does not exist.
+            fillDefaultValue(current)
+            nextSelectedRow = EmptyRow
+          }
+        }
+      }
+  } else if (ignoreNulls && offset < 0) {
+    // For illustration, here is one example: the input data contains nine rows,
+    // and the input values of each row are: null, x, null, null, y, null, z, v, null.
+    // We use lag(input, 1) with IGNORE NULLS and the process is as follows:
+    // 1. current row -> null, next selected row -> empty, output: null;
+    // 2. current row -> x, next selected row -> empty, output: null;
+    // 3. current row -> null, next selected row -> x, output: x;
+    // 4. current row -> null, next selected row -> x, output: x;
+    // 5. current row -> y, next selected row -> x, output: x;
+    // 6. current row -> null, next selected row -> y, output: y;
+    // 7. current row -> z, next selected row -> y, output: y;
+    // 8. current row -> v, next selected row -> z, output: z;
+    // 9. current row -> null, next selected row -> v, output: v;
+    val absOffset = Math.abs(offset)
+    (current: InternalRow) =>
+      if (skippedNonNullCount == absOffset) {
+        nextSelectedRow = EmptyRow
+        skippedNonNullCount -= 1
+        while (nextSelectedRow == EmptyRow && inputIndex < input.length) {
+          val r = WindowFunctionFrame.getNextOrNull(inputIterator)
+          if (!nullCheck(r)) {
+            nextSelectedRow = r
+          }
+          inputIndex += 1
+        }
+      }
+      if (nextSelectedRow == EmptyRow) {
+        // Use default values since the offset row whose input value is not null does not exist.
+        fillDefaultValue(current)
+      } else {
+        projection(nextSelectedRow)
+      }
+      if (!nullCheck(current)) {
+        skippedNonNullCount += 1
+      }
+  } else {
+    (current: InternalRow) =>
+      if (inputIndex >= 0 && inputIndex < input.length) {
+        val r = WindowFunctionFrame.getNextOrNull(inputIterator)
+        projection(r)
+      } else {
+        // Use default values since the offset row does not exist.
+        fillDefaultValue(current)
+      }
+      inputIndex += 1
+  }
 
   override def write(index: Int, current: InternalRow): Unit = {
-    if (inputIndex >= 0 && inputIndex < input.length) {
-      val r = WindowFunctionFrame.getNextOrNull(inputIterator)
-      projection(r)
-    } else {
-      // Use default values since the offset row does not exist.
-      fillDefaultValue(current)
-    }
-    inputIndex += 1
+    doWrite(current)
   }
 }
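The worked examples in the comments above can be reproduced with a small reference model (plain Scala over an in-memory Seq, assuming offset n >= 1; the real frame computes the same answers in a single forward pass over UnsafeRows):

    // LEAD(input, n) IGNORE NULLS: for each row, take the n-th non-null value
    // strictly after the current row, or null (None) when no such value exists.
    def leadIgnoreNulls[A](values: Seq[Option[A]], n: Int): Seq[Option[A]] =
      values.indices.map(i => values.drop(i + 1).flatten.drop(n - 1).headOption)

    // The example from the comments: null, x, null, null, y, null, z, v, null with n = 2.
    val in = Seq(None, Some("x"), None, None, Some("y"), None, Some("z"), Some("v"), None)
    println(leadIgnoreNulls(in, 2))
    // Vector(Some(y), Some(z), Some(z), Some(z), Some(v), Some(v), None, None, None)

LAG with IGNORE NULLS is the mirror image: it scans backwards for the n-th non-null value before the current row.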
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4defcb8..764e088 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -937,8 +937,24 @@ object functions {
   * @group window_funcs
   * @since 1.4.0
   */
-  def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
-    Lag(e.expr, Literal(offset), Literal(defaultValue))
+  def lag(e: Column, offset: Int, defaultValue: Any): Column = {
+    lag(e, offset, defaultValue, false)
+  }
+
+  /**
+   * Window function: returns the value that is `offset` rows before the current row, and
+   * `defaultValue` if there is less than `offset` rows before the current row. `ignoreNulls`
+   * determines whether null values of row are included in or eliminated from the calculation.
+   * For example, an `offset` of one will return the previous row at any given point in the
+   * window partition.
+   *
+   * This is equivalent to the LAG function in SQL.
+   *
+   * @group window_funcs
+   * @since 3.2.0
+   */
+  def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
+    Lag(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
   }
 
   /**
@@ -989,8 +1005,24 @@ object functions {
   * @group window_funcs
   * @since 1.4.0
   */
-  def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
-    Lead(e.expr, Literal(offset), Literal(defaultValue))
+  def lead(e: Column, offset: Int, defaultValue: Any): Column = {
+    lead(e, offset, defaultValue, false)
+  }
+
+  /**
+   * Window function: returns the value that is `offset` rows after the current row, and
+   * `defaultValue` if there is less than `offset` rows after the current row. `ignoreNulls`
+   * determines whether null values of row are included in or eliminated from the calculation.
+   * The default value of `ignoreNulls` is false. For example, an `offset` of one will return
+   * the next row at any given point in the window partition.
+   *
+   * This is equivalent to the LEAD function in SQL.
+   *
+   * @group window_funcs
+   * @since 3.2.0
+   */
+  def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
+    Lead(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
   }
 
   /**
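A usage sketch of the new lag overload (a hypothetical session; assumes `import spark.implicits._` and a SparkSession named `spark`):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.lag

    val nullStr: String = null
    val df = Seq((0, nullStr), (1, "x"), (2, nullStr), (3, "y")).toDF("order", "value")
    val w = Window.orderBy($"order")

    df.select(
      $"order", $"value",
      lag($"value", 1).over(w).as("lag_respect"),           // previous row, nulls included
      lag($"value", 1, null, true).over(w).as("lag_ignore") // previous non-null value
    ).show()
    // lag_ignore at order = 3 is "x": the nulls at order 0 and order 2 are skipped.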
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 207b296..3568ad3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -700,6 +700,61 @@ class DataFrameWindowFunctionsSuite extends QueryTest
       Row("b", 3, null, null, null)))
   }
 
+  test("lead/lag with ignoreNulls") {
+    val nullStr: String = null
+    val df = Seq(
+      ("a", 0, nullStr),
+      ("a", 1, "x"),
+      ("b", 2, nullStr),
+      ("c", 3, nullStr),
+      ("a", 4, "y"),
+      ("b", 5, nullStr),
+      ("a", 6, "z"),
+      ("a", 7, "v"),
+      ("a", 8, nullStr)).
+      toDF("key", "order", "value")
+    val window = Window.orderBy($"order")
+    checkAnswer(
+      df.select(
+        $"key",
+        $"order",
+        $"value",
+        lead($"value", 1).over(window),
+        lead($"value", 2).over(window),
+        lead($"value", 0, null, true).over(window),
+        lead($"value", 1, null, true).over(window),
+        lead($"value", 2, null, true).over(window),
+        lead($"value", 3, null, true).over(window),
+        lead(concat($"value", $"key"), 1, null, true).over(window),
+        lag($"value", 1).over(window),
+        lag($"value", 2).over(window),
+        lag($"value", 0, null, true).over(window),
+        lag($"value", 1, null, true).over(window),
+        lag($"value", 2, null, true).over(window),
+        lag($"value", 3, null, true).over(window),
+        lag(concat($"value", $"key"), 1, null, true).over(window))
+        .orderBy($"order"),
+      Seq(
+        Row("a", 0, null, "x", null, null, "x", "y", "z", "xa",
+          null, null, null, null, null, null, null),
+        Row("a", 1, "x", null, null, "x", "y", "z", "v", "ya",
+          null, null, "x", null, null, null, null),
+        Row("b", 2, null, null, "y", null, "y", "z", "v", "ya",
+          "x", null, null, "x", null, null, "xa"),
+        Row("c", 3, null, "y", null, null, "y", "z", "v", "ya",
+          null, "x", null, "x", null, null, "xa"),
+        Row("a", 4, "y", null, "z", "y", "z", "v", null, "za",
+          null, null, "y", "x", null, null, "xa"),
+        Row("b", 5, null, "z", "v", null, "z", "v", null, "za",
+          "y", null, null, "y", "x", null, "ya"),
+        Row("a", 6, "z", "v", null, "z", "v", null, null, "va",
+          null, "y", "z", "y", "x", null, "ya"),
+        Row("a", 7, "v", null, null, "v", null, null, null, null,
+          "z", null, "v", "z", "y", "x", "za"),
+        Row("a", 8, null, null, null, null, null, null, null, null,
+          "v", "z", null, "v", "z", "y", "va")))
+  }
+
   test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
     val src = Seq((0, 3, 5)).toDF("a", "b", "c")
       .withColumn("Data", struct("a", "b"))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org