This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new a7480e647fe [SPARK-38823][SQL] Make `NewInstance` non-foldable to fix aggregation buffer corruption issue
a7480e647fe is described below

commit a7480e647fe1ed930c0cd2ad1679b3685a675d02
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Fri Apr 15 08:37:43 2022 +0900

[SPARK-38823][SQL] Make `NewInstance` non-foldable to fix aggregation buffer corruption issue

### What changes were proposed in this pull request?

Make `NewInstance` non-foldable.

### Why are the changes needed?

When handling Java beans as input, Spark creates `NewInstance` with no arguments. On master and 3.3, `NewInstance` with no arguments is considered foldable. As a result, the `ConstantFolding` rule converts `NewInstance` into a `Literal` holding an instance of the user's specified Java bean. The instance becomes a singleton that gets reused for each input record (although its fields get updated by `InitializeJavaBean`).

Because the instance gets reused, multiple buffers in `AggregationIterator` can end up referring to the same Java bean instance. Take, for example, the test I added in this PR, or the `spark-shell` example I added to SPARK-38823 as a comment. The input is:

```
new Item("a", 1),
new Item("b", 3),
new Item("c", 2),
new Item("a", 7)
```

As `ObjectAggregationIterator` reads the input, the buffers get set up as follows (note that the first field of each Item should match its key):

```
- Read Item("a", 1)
- Buffers are now:
    Key "a" -> Item("a", 1)
- Read Item("b", 3)
- Buffers are now:
    Key "a" -> Item("b", 3)
    Key "b" -> Item("b", 3)
```

The buffer for key "a" now contains `Item("b", 3)`. That's because both buffers hold a reference to the same Item instance, and that instance's fields were updated when `Item("b", 3)` was read.

This PR makes `NewInstance` non-foldable, so it will not get optimized away, thus ensuring a new instance of the Java bean for each input record.
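For illustration, here is a minimal Scala sketch of the failure mode. It is an assumption-laden stand-in, not the exact repro attached to SPARK-38823: the `Item` bean, the object name, and the session setup are hypothetical, and it presumes a Spark 3.3 build without this patch:

```
// Sketch only: assumes a Spark 3.3 build WITHOUT this fix and a hypothetical
// top-level Item bean. With NewInstance foldable, the bean deserializer can
// reuse one mutated Item across records, corrupting aggregation buffers.
import scala.beans.BeanProperty

import org.apache.spark.sql.{Encoders, SparkSession}

class Item(@BeanProperty var k: String, @BeanProperty var v: Int) extends Serializable {
  def this() = this(null, 0) // no-arg constructor required by Encoders.bean
  override def toString: String = s"Item($k,$v)"
}

object Spark38823Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    val items = java.util.Arrays.asList(
      new Item("a", 1), new Item("b", 3), new Item("c", 2), new Item("a", 7))

    val ds = spark.createDataFrame(items, classOf[Item])
      .as(Encoders.bean(classOf[Item]))
      .coalesce(1) // one partition, so one ObjectAggregationIterator sees every row

    // Expected sums: a -> 8, b -> 3, c -> 2. Without the fix, buffers can alias
    // the same Item instance, so results for earlier keys come out wrong.
    ds.groupByKey(_.getK)(Encoders.STRING)
      .reduceGroups((x, y) => new Item(x.getK, x.getV + y.getV))
      .collect()
      .foreach(println)

    spark.stop()
  }
}
```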
### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit test.

Closes #36183 from bersprockets/newinstance_issue.

Authored-by: Bruce Robbins <bersprock...@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
(cherry picked from commit cc7cb7a803d5de03c526480c8968bbb2c3e82484)
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/catalyst/expressions/objects/objects.scala |  3 +
 .../catalyst/optimizer/ConstantFoldingSuite.scala  | 15 +---
 .../spark/sql/JavaBeanDeserializationSuite.java    | 93 ++++++++++++++++++++++
 3 files changed, 99 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 2c879beeed6..fe982b23829 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -516,6 +516,9 @@ case class NewInstance(
 
   override def nullable: Boolean = needNullCheck
 
+  // Non-foldable to prevent the optimizer from replacing NewInstance with a singleton instance
+  // of the specified class.
+  override def foldable: Boolean = false
   override def children: Seq[Expression] = arguments
 
   final override val nodePatterns: Seq[TreePattern] = Seq(NEW_INSTANCE)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index a2ee2a2fb68..b06e001e412 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -21,11 +21,10 @@ import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Unresol
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance, StaticInvoke}
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.ByteArray
 
@@ -318,14 +317,7 @@ class ConstantFoldingSuite extends PlanTest {
         Literal.create("a", StringType),
         "substring",
         StringType,
-        Seq(Literal(0), Literal(1))).as("c2"),
-      NewInstance(
-        cls = classOf[GenericArrayData],
-        arguments = Literal.fromObject(List(1, 2, 3)) :: Nil,
-        inputTypes = Nil,
-        propagateNull = false,
-        dataType = ArrayType(IntegerType),
-        outerPointer = None).as("c3"))
+        Seq(Literal(0), Literal(1))).as("c2"))
 
     val optimized = Optimize.execute(originalQuery.analyze)
 
@@ -333,8 +325,7 @@ class ConstantFoldingSuite extends PlanTest {
     testRelation
       .select(
         Literal("WWSpark".getBytes()).as("c1"),
-        Literal.create("a", StringType).as("c2"),
-        Literal.create(new GenericArrayData(List(1, 2, 3)), ArrayType(IntegerType)).as("c3"))
+        Literal.create("a", StringType).as("c2"))
       .analyze
 
     comparePlans(optimized, correctAnswer)
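For context on the test removals above, here is a simplified sketch of what the `ConstantFolding` rule does. This is illustrative only, not the actual Catalyst implementation:

```
// Illustrative sketch, not the actual Catalyst rule: constant folding replaces
// every deterministic, foldable subexpression with a Literal holding its value,
// computed once at optimization time. Before this change, a no-argument
// NewInstance was foldable, so it collapsed into a Literal wrapping a single
// shared bean instance.
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

def foldConstants(e: Expression): Expression =
  if (e.foldable && e.deterministic) {
    Literal.create(e.eval(null), e.dataType) // evaluate once, reuse the result
  } else {
    e.mapChildren(foldConstants)
  }
```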
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index af0a22b0360..06a5c50b30c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -26,6 +26,8 @@ import java.util.*;
 
 import org.apache.commons.lang3.builder.ToStringBuilder;
 import org.apache.commons.lang3.builder.ToStringStyle;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.api.java.function.ReduceFunction;
 import org.junit.*;
 
 import org.apache.spark.sql.*;
@@ -37,6 +39,7 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
 
 import org.apache.spark.sql.test.TestSparkSession;
+import scala.Tuple2;
 
 public class JavaBeanDeserializationSuite implements Serializable {
 
@@ -562,6 +565,96 @@ public class JavaBeanDeserializationSuite implements Serializable {
     }
   }
 
+  @Test
+  public void testSPARK38823NoBeanReuse() {
+    List<Item> items = Arrays.asList(
+      new Item("a", 1),
+      new Item("b", 3),
+      new Item("c", 2),
+      new Item("a", 7));
+
+    Encoder<Item> encoder = Encoders.bean(Item.class);
+
+    Dataset<Item> ds = spark.createDataFrame(items, Item.class)
+      .as(encoder)
+      .coalesce(1);
+
+    MapFunction<Item, String> mf = new MapFunction<Item, String>() {
+      @Override
+      public String call(Item item) throws Exception {
+        return item.getK();
+      }
+    };
+
+    ReduceFunction<Item> rf = new ReduceFunction<Item>() {
+      @Override
+      public Item call(Item item1, Item item2) throws Exception {
+        Assert.assertNotSame(item1, item2);
+        return item1.addValue(item2.getV());
+      }
+    };
+
+    Dataset<Tuple2<String, Item>> finalDs = ds
+      .groupByKey(mf, Encoders.STRING())
+      .reduceGroups(rf);
+
+    List<Tuple2<String, Item>> expectedRecords = Arrays.asList(
+      new Tuple2("a", new Item("a", 8)),
+      new Tuple2("b", new Item("b", 3)),
+      new Tuple2("c", new Item("c", 2)));
+
+    List<Tuple2<String, Item>> result = finalDs.collectAsList();
+
+    Assert.assertEquals(expectedRecords, result);
+  }
+
+  public static class Item implements Serializable {
+    private String k;
+    private int v;
+
+    public String getK() {
+      return k;
+    }
+
+    public int getV() {
+      return v;
+    }
+
+    public void setK(String k) {
+      this.k = k;
+    }
+
+    public void setV(int v) {
+      this.v = v;
+    }
+
+    public Item() { }
+
+    public Item(String k, int v) {
+      this.k = k;
+      this.v = v;
+    }
+
+    public Item addValue(int inc) {
+      return new Item(k, v + inc);
+    }
+
+    public String toString() {
+      return "Item(" + k + "," + v + ")";
+    }
+
+    public boolean equals(Object o) {
+      if (!(o instanceof Item)) {
+        return false;
+      }
+      Item other = (Item) o;
+      if (other.getK().equals(k) && other.getV() == v) {
+        return true;
+      }
+      return false;
+    }
+  }
+
   public static final class LocalDateInstantRecord {
     private String localDateField;
     private String instantField;

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org