This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new a7480e647fe [SPARK-38823][SQL] Make `NewInstance` non-foldable to fix aggregation buffer corruption issue
a7480e647fe is described below

commit a7480e647fe1ed930c0cd2ad1679b3685a675d02
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Fri Apr 15 08:37:43 2022 +0900

    [SPARK-38823][SQL] Make `NewInstance` non-foldable to fix aggregation buffer corruption issue
    
    ### What changes were proposed in this pull request?
    
    Make `NewInstance` non-foldable.
    
    ### Why are the changes needed?
    
    When handling Java beans as input, Spark creates `NewInstance` with no arguments. On master and 3.3, `NewInstance` with no arguments is considered foldable. As a result, the `ConstantFolding` rule converts `NewInstance` into a `Literal` holding an instance of the user's specified Java bean. The instance becomes a singleton that gets reused for each input record (although its fields get updated by `InitializeJavaBean`).
    
    Because the instance gets reused, sometimes multiple buffers in `AggregationIterator` are actually referring to the same Java bean instance.
    
    Take, for example, the test I added in this PR, or the `spark-shell` example I added to SPARK-38823 as a comment.
    
    The input is:
    ```
        new Item("a", 1),
        new Item("b", 3),
        new Item("c", 2),
        new Item("a", 7)
    ```
    As `ObjectAggregationIterator` reads the input, the buffers get set up as follows (note that the first field of Item should be the same as the key):
    ```
    - Read Item("a", 1)
    
    - Buffers are now:
      Key "a" -> Item("a", 1)
    
    - Read Item("b", 3)
    
    - Buffers are now:
      Key "a" -> Item("b", 3)
      Key "b" -> Item("b", 3)
    ```
    
    The buffer for key "a" now contains `Item("b", 3)`. That's because both buffers contain a reference to the same Item instance, and that Item instance's fields were updated when `Item("b", 3)` was read.
    
    This PR makes `NewInstance` non-foldable, so it will not get optimized away, thus ensuring a new instance of the Java bean for each input record.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit test.
    
    Closes #36183 from bersprockets/newinstance_issue.
    
    Authored-by: Bruce Robbins <bersprock...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit cc7cb7a803d5de03c526480c8968bbb2c3e82484)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/catalyst/expressions/objects/objects.scala |  3 +
 .../catalyst/optimizer/ConstantFoldingSuite.scala  | 15 +---
 .../spark/sql/JavaBeanDeserializationSuite.java    | 93 ++++++++++++++++++++++
 3 files changed, 99 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 2c879beeed6..fe982b23829 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -516,6 +516,9 @@ case class NewInstance(
 
   override def nullable: Boolean = needNullCheck
 
+  // Non-foldable to prevent the optimizer from replacing NewInstance with a singleton instance
+  // of the specified class.
+  override def foldable: Boolean = false
   override def children: Seq[Expression] = arguments
 
   final override val nodePatterns: Seq[TreePattern] = Seq(NEW_INSTANCE)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index a2ee2a2fb68..b06e001e412 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -21,11 +21,10 @@ import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Unresol
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance, StaticInvoke}
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.ByteArray
 
@@ -318,14 +317,7 @@ class ConstantFoldingSuite extends PlanTest {
             Literal.create("a", StringType),
             "substring",
             StringType,
-            Seq(Literal(0), Literal(1))).as("c2"),
-          NewInstance(
-            cls = classOf[GenericArrayData],
-            arguments = Literal.fromObject(List(1, 2, 3)) :: Nil,
-            inputTypes = Nil,
-            propagateNull = false,
-            dataType = ArrayType(IntegerType),
-            outerPointer = None).as("c3"))
+            Seq(Literal(0), Literal(1))).as("c2"))
 
     val optimized = Optimize.execute(originalQuery.analyze)
 
@@ -333,8 +325,7 @@ class ConstantFoldingSuite extends PlanTest {
       testRelation
         .select(
           Literal("WWSpark".getBytes()).as("c1"),
-          Literal.create("a", StringType).as("c2"),
-          Literal.create(new GenericArrayData(List(1, 2, 3)), ArrayType(IntegerType)).as("c3"))
+          Literal.create("a", StringType).as("c2"))
         .analyze
 
     comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index af0a22b0360..06a5c50b30c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -26,6 +26,8 @@ import java.util.*;
 
 import org.apache.commons.lang3.builder.ToStringBuilder;
 import org.apache.commons.lang3.builder.ToStringStyle;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.api.java.function.ReduceFunction;
 import org.junit.*;
 
 import org.apache.spark.sql.*;
@@ -37,6 +39,7 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
 
 import org.apache.spark.sql.test.TestSparkSession;
+import scala.Tuple2;
 
 public class JavaBeanDeserializationSuite implements Serializable {
 
@@ -562,6 +565,96 @@ public class JavaBeanDeserializationSuite implements Serializable {
     }
   }
 
+  @Test
+  public void testSPARK38823NoBeanReuse() {
+    List<Item> items = Arrays.asList(
+            new Item("a", 1),
+            new Item("b", 3),
+            new Item("c", 2),
+            new Item("a", 7));
+
+    Encoder<Item> encoder = Encoders.bean(Item.class);
+
+    Dataset<Item> ds = spark.createDataFrame(items, Item.class)
+            .as(encoder)
+            .coalesce(1);
+
+    MapFunction<Item, String> mf = new MapFunction<Item, String>() {
+      @Override
+      public String call(Item item) throws Exception {
+        return item.getK();
+      }
+    };
+
+    ReduceFunction<Item> rf = new ReduceFunction<Item>() {
+      @Override
+      public Item call(Item item1, Item item2) throws Exception {
+        Assert.assertNotSame(item1, item2);
+        return item1.addValue(item2.getV());
+      }
+    };
+
+    Dataset<Tuple2<String, Item>> finalDs = ds
+            .groupByKey(mf, Encoders.STRING())
+            .reduceGroups(rf);
+
+    List<Tuple2<String, Item>> expectedRecords = Arrays.asList(
+            new Tuple2("a", new Item("a", 8)),
+            new Tuple2("b", new Item("b", 3)),
+            new Tuple2("c", new Item("c", 2)));
+
+    List<Tuple2<String, Item>> result = finalDs.collectAsList();
+
+    Assert.assertEquals(expectedRecords, result);
+  }
+
+  public static class Item implements Serializable {
+    private String k;
+    private int v;
+
+    public String getK() {
+      return k;
+    }
+
+    public int getV() {
+      return v;
+    }
+
+    public void setK(String k) {
+      this.k = k;
+    }
+
+    public void setV(int v) {
+      this.v = v;
+    }
+
+    public Item() { }
+
+    public Item(String k, int v) {
+      this.k = k;
+      this.v = v;
+    }
+
+    public Item addValue(int inc) {
+      return new Item(k, v + inc);
+    }
+
+    public String toString() {
+      return "Item(" + k + "," + v + ")";
+    }
+
+    public boolean equals(Object o) {
+      if (!(o instanceof Item)) {
+        return false;
+      }
+      Item other = (Item) o;
+      if (other.getK().equals(k) && other.getV() == v) {
+        return true;
+      }
+      return false;
+    }
+  }
+
   public static final class LocalDateInstantRecord {
     private String localDateField;
     private String instantField;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to