This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.3 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push: new d39e7bac73d [SPARK-41395][SQL] `InterpretedMutableProjection` should use `setDecimal` to set null values for decimals in an unsafe row d39e7bac73d is described below commit d39e7bac73d6b4f12492cef5e8a31e406b2a5d3a Author: Bruce Robbins <bersprock...@gmail.com> AuthorDate: Fri Dec 9 21:44:45 2022 +0900 [SPARK-41395][SQL] `InterpretedMutableProjection` should use `setDecimal` to set null values for decimals in an unsafe row Change `InterpretedMutableProjection` to use `setDecimal` rather than `setNullAt` to set null values for decimals in unsafe rows. The following returns the wrong answer: ``` set spark.sql.codegen.wholeStage=false; set spark.sql.codegen.factoryMode=NO_CODEGEN; select max(col1), max(col2) from values (cast(null as decimal(27,2)), cast(null as decimal(27,2))), (cast(77.77 as decimal(27,2)), cast(245.00 as decimal(27,2))) as data(col1, col2); +---------+---------+ |max(col1)|max(col2)| +---------+---------+ |null |239.88 | +---------+---------+ ``` This is because `InterpretedMutableProjection` inappropriately uses `InternalRow#setNullAt` on unsafe rows to set null for decimal types with precision > `Decimal.MAX_LONG_DIGITS`. When `setNullAt` is used, the pointer to the decimal's storage area in the variable length region gets zeroed out. Later, when `InterpretedMutableProjection` calls `setDecimal` on that field, `UnsafeRow#setDecimal` picks up the zero pointer and stores decimal data on top of the null-tracking bit set. Later updates to the null-tracking bit set (e.g., calls to `setNotNullAt`) further corrupt the decimal data (turning 245.00 into 239.88, for example). The stomping of the null-tracking bi [...] This bug can manifest for end-users after codegen fallback (say, if an expression's generated code fails to compile). 
[Codegen for mutable projection](https://github.com/apache/spark/blob/89b2ee27d258dec8fe265fa862846e800a374d8e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala#L1729) uses `mutableRow.setDecimal` for null decimal values regardless of precision or the type for `mutableRow`, so this PR does the same. Does this PR introduce any user-facing change? No. How was this patch tested? New unit tests. Closes #38923 from bersprockets/unsafe_decimal_issue. Authored-by: Bruce Robbins <bersprock...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit fec210b36be22f187b51b67970960692f75ac31f) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../expressions/InterpretedMutableProjection.scala | 3 +- .../expressions/MutableProjectionSuite.scala | 62 ++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 91c9457af7d..4e129e96d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DecimalType /** @@ -72,7 +73,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => val writer = InternalRow.getWriter(i, e.dataType) - if (!e.nullable) { + if (!e.nullable || e.dataType.isInstanceOf[DecimalType]) { (v: Any) => writer(mutableRow, v) } else { (v: Any) => { diff 
--git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 0f01bfbb894..e3f11283816 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -65,6 +65,68 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } + def testRows( + bufferSchema: StructType, + buffer: InternalRow, + scalaRows: Seq[Seq[Any]]): Unit = { + val bufferTypes = bufferSchema.map(_.dataType).toArray + val proj = createMutableProjection(bufferTypes) + + scalaRows.foreach { scalaRow => + val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { + case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) + }) + val projRow = proj.target(buffer)(inputRow) + assert(SafeProjection.create(bufferTypes)(projRow) === inputRow) + } + } + + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val buffer = UnsafeProjection.create(bufferSchema) + .apply(new 
GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + testBothCodegenAndInterpreted("variable-length types") { val proj = createMutableProjection(variableLengthTypes) val scalaValues = Seq("abc", BigDecimal(10), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org