This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new cb869328ea7 [SPARK-41804][SQL] Choose correct element size in `InterpretedUnsafeProjection` for array of UDTs cb869328ea7 is described below commit cb869328ea7fcf95a4178a0db19a6fa821ce3f15 Author: Bruce Robbins <bersprock...@gmail.com> AuthorDate: Tue Jan 3 10:22:35 2023 +0900 [SPARK-41804][SQL] Choose correct element size in `InterpretedUnsafeProjection` for array of UDTs ### What changes were proposed in this pull request? Change `InterpretedUnsafeProjection#getElementSize` to choose the appropriate element size for the underlying SQL type of a UDT, rather than simply using the default size of the underlying SQL type. ### Why are the changes needed? Consider this query: ``` // create a file of vector data import org.apache.spark.ml.linalg.{DenseVector, Vector} case class TestRow(varr: Array[Vector]) val values = Array(0.1d, 0.2d, 0.3d) val dv = new DenseVector(values).asInstanceOf[Vector] val ds = Seq(TestRow(Array(dv, dv))).toDS ds.coalesce(1).write.mode("overwrite").format("parquet").save("vector_data") // this works spark.read.format("parquet").load("vector_data").collect sql("set spark.sql.codegen.wholeStage=false") sql("set spark.sql.codegen.factoryMode=NO_CODEGEN") // this will get an error spark.read.format("parquet").load("vector_data").collect ``` The failures vary, including * `VectorUDT` attempting to deserialize to a `SparseVector` (rather than a `DenseVector`) * negative array size (for one of the nested arrays) * JVM crash (SIGBUS error). This is because `InterpretedUnsafeProjection` initializes the outer-most array writer with an element size of 17 (the size of the UDT's underlying struct), rather than an element size of 8, which would be appropriate for an array of structs. When the outer-most array is later accessed, `UnsafeArrayData` assumes an element size of 8, so it picks up a garbage offset/size tuple for the second element. 
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit test. Closes #39349 from bersprockets/udt_issue. Authored-by: Bruce Robbins <bersprock...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../expressions/InterpretedUnsafeProjection.scala | 2 ++ .../catalyst/expressions/UnsafeRowConverterSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index d87c0c006cf..9108a045c09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -294,6 +294,8 @@ object InterpretedUnsafeProjection { private def getElementSize(dataType: DataType): Int = dataType match { case NullType | StringType | BinaryType | CalendarIntervalType | _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8 + case udt: UserDefinedType[_] => + getElementSize(udt.sqlType) case _ => dataType.defaultSize } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 83dc8127828..cbab8894cb5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -687,4 +687,20 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val fields5 = Array[DataType](udt) assert(convertBackToInternalRow(udtRow, fields5) === udtRow) } + + testBothCodegenAndInterpreted("SPARK-41804: Array of UDTs") { + val udt = 
new ExampleBaseTypeUDT + val objs = Seq( + udt.serialize(new ExampleSubClass(1)), + udt.serialize(new ExampleSubClass(2))) + val arr = new GenericArrayData(objs) + val row = new GenericInternalRow(Array[Any](arr)) + val unsafeProj = UnsafeProjection.create(Array[DataType](ArrayType(udt))) + val unsafeRow = unsafeProj.apply(row) + val unsafeOuterArray = unsafeRow.getArray(0) + // get second element from unsafe array + val unsafeStruct = unsafeOuterArray.getStruct(1, 1) + val result = unsafeStruct.getInt(0) + assert(result == 2) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org