This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 8ef4023  [SPARK-34794][SQL] Fix lambda variable name issues in nested DataFrame functions
8ef4023 is described below

commit 8ef4023683dee537a40d376d93c329a802a929bd
Author: dsolow <dso...@sayari.com>
AuthorDate: Wed May 5 12:46:13 2021 +0900

    [SPARK-34794][SQL] Fix lambda variable name issues in nested DataFrame functions
    
    ### What changes were proposed in this pull request?
    
    To fix lambda variable name issues in nested DataFrame functions, this PR
    modifies the code to use a global counter for the `LambdaVariable` names
    created by higher-order functions.
    
    This is a rework of #31887. Closes #31887.
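
    In short, each lambda parameter now gets a globally unique suffix instead
    of a fixed name. The following is a condensed view of the change (the full
    patch is below):

    ```
    import java.util.concurrent.atomic.AtomicInteger

    object UnresolvedNamedLambdaVariable {
      // Global counter to ensure lambda variable names are unique
      private val nextVarNameId = new AtomicInteger(0)

      // e.g. freshVarName("x") returns "x_0", then "x_1" on the next call
      def freshVarName(name: String): String = {
        s"${name}_${nextVarNameId.getAndIncrement()}"
      }
    }
    ```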
    
    ### Why are the changes needed?
    
    This moves away from the current hard-coded variable names, which break on
    nested function calls. There is currently a bug where nested transforms in
    particular fail: the inner variable shadows the outer variable.
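
    The collision comes from the hard-coded names in `functions.scala`. Before
    this patch, every single-argument lambda bound a variable literally named
    `x`, so nesting two `transform` calls created two distinct variables with
    the same name, and the inner one shadowed the outer during resolution:

    ```
    // Pre-patch helper: every nested call binds a variable named "x"
    private def createLambda(f: Column => Column) = {
      val x = UnresolvedNamedLambdaVariable(Seq("x"))
      val function = f(Column(x)).expr
      LambdaFunction(function, Seq(x))
    }
    ```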
    
    For this query:
    ```
    val df = Seq(
        (Seq(1,2,3), Seq("a", "b", "c"))
    ).toDF("numbers", "letters")
    
    df.select(
        f.flatten(
            f.transform(
                $"numbers",
                (number: Column) => { f.transform(
                    $"letters",
                    (letter: Column) => { f.struct(
                        number.as("number"),
                        letter.as("letter")
                    ) }
                ) }
            )
        ).as("zipped")
    ).show(10, false)
    ```
    This is the current (incorrect) output:
    ```
    +------------------------------------------------------------------------+
    |zipped                                                                  |
    +------------------------------------------------------------------------+
    |[{a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}]|
    +------------------------------------------------------------------------+
    ```
    And this is the correct output after the fix:
    ```
    +------------------------------------------------------------------------+
    |zipped                                                                  |
    +------------------------------------------------------------------------+
    |[{1, a}, {1, b}, {1, c}, {2, a}, {2, b}, {2, c}, {3, a}, {3, b}, {3, c}]|
    +------------------------------------------------------------------------+
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new test in `DataFrameFunctionsSuite`.
    
    Closes #32424 from maropu/pr31887.
    
    Lead-authored-by: dsolow <dso...@sayari.com>
    Co-authored-by: Takeshi Yamamuro <yamam...@apache.org>
    Co-authored-by: dmsolow <dso...@sayarianalytics.com>
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
    (cherry picked from commit f550e03b96638de93381734c4eada2ace02d9a4f)
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
---
 .../expressions/higherOrderFunctions.scala         | 12 ++++++++++-
 .../scala/org/apache/spark/sql/functions.scala     | 12 +++++------
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 23 ++++++++++++++++++++++
 3 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index e5cf8c0..a530ce5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.util.Comparator
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
 
 import scala.collection.mutable
 
@@ -52,6 +52,16 @@ case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
   override def sql: String = name
 }
 
+object UnresolvedNamedLambdaVariable {
+
+  // Counter to ensure lambda variable names are unique
+  private val nextVarNameId = new AtomicInteger(0)
+
+  def freshVarName(name: String): String = {
+    s"${name}_${nextVarNameId.getAndIncrement()}"
+  }
+}
+
 /**
  * A named lambda variable.
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bb77c7e..f6d6200 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3489,22 +3489,22 @@ object functions {
   }
 
   private def createLambda(f: Column => Column) = {
-    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
     val function = f(Column(x)).expr
     LambdaFunction(function, Seq(x))
   }
 
   private def createLambda(f: (Column, Column) => Column) = {
-    val x = UnresolvedNamedLambdaVariable(Seq("x"))
-    val y = UnresolvedNamedLambdaVariable(Seq("y"))
+    val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
+    val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
     val function = f(Column(x), Column(y)).expr
     LambdaFunction(function, Seq(x, y))
   }
 
   private def createLambda(f: (Column, Column, Column) => Column) = {
-    val x = UnresolvedNamedLambdaVariable(Seq("x"))
-    val y = UnresolvedNamedLambdaVariable(Seq("y"))
-    val z = UnresolvedNamedLambdaVariable(Seq("z"))
+    val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
+    val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
+    val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z")))
     val function = f(Column(x), Column(y), Column(z)).expr
     LambdaFunction(function, Seq(x, y, z))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ac98d3f..1a468a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -3621,6 +3621,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       df.select(map(map_entries($"m"), lit(1))),
       Row(Map(Seq(Row(1, "a")) -> 1)))
   }
+
+  test("SPARK-34794: lambda variable name issues in nested functions") {
+    val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters")
+
+    checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) =>
+      transform($"letters", (letter: Column) =>
+        struct(number, letter))))),
+      Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b"))))
+    )
+    checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) =>
+      transform($"letters", (letter: Column, j: Column) =>
+        struct(number + j, concat(letter, i)))))),
+      Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1"))))
+    )
+
+    val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2")
+
+    checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) =>
+      map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) =>
+        ov1 + iv1 + ov2 + iv2))),
+      Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10))))
+    )
+  }
 }
 
 object DataFrameFunctionsSuite {
