This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 3fd38d4c07f6 [SPARK-47803][FOLLOWUP] Check nulls when casting nested type to variant 3fd38d4c07f6 is described below commit 3fd38d4c07f6c998ec8bb234796f83a6aecfc0d2 Author: Chenhao Li <chenhao...@databricks.com> AuthorDate: Thu May 9 22:45:10 2024 +0800 [SPARK-47803][FOLLOWUP] Check nulls when casting nested type to variant ### What changes were proposed in this pull request? It adds null checks when accessing a nested element when casting a nested type to variant. It is necessary because the `get` API doesn't guarantee to return null when the slot is null. For example, `ColumnarArray.get` may return the default value of a primitive type if the slot is null. ### Why are the changes needed? It is a bug fix that is necessary for the cast-to-variant expression to work correctly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Two new unit tests. One directly uses `ColumnarArray` as the input of the cast. The other creates a real-world situation where `ColumnarArray` is the input of the cast (scan). Both of them would fail without the code change in this PR. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46486 from chenhao-db/fix_cast_nested_to_variant. 
Authored-by: Chenhao Li <chenhao...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../variant/VariantExpressionEvalUtils.scala | 9 ++++-- .../apache/spark/sql/VariantEndToEndSuite.scala | 33 ++++++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index eb235eb854e0..f7f7097173bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -103,7 +103,8 @@ object VariantExpressionEvalUtils { val offsets = new java.util.ArrayList[java.lang.Integer](data.numElements()) for (i <- 0 until data.numElements()) { offsets.add(builder.getWritePos - start) - buildVariant(builder, data.get(i, elementType), elementType) + val element = if (data.isNullAt(i)) null else data.get(i, elementType) + buildVariant(builder, element, elementType) } builder.finishWritingArray(start, offsets) case MapType(StringType, valueType, _) => @@ -116,7 +117,8 @@ object VariantExpressionEvalUtils { val key = keys.getUTF8String(i).toString val id = builder.addKey(key) fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start)) - buildVariant(builder, values.get(i, valueType), valueType) + val value = if (values.isNullAt(i)) null else values.get(i, valueType) + buildVariant(builder, value, valueType) } builder.finishWritingObject(start, fields) case StructType(structFields) => @@ -127,7 +129,8 @@ object VariantExpressionEvalUtils { val key = structFields(i).name val id = builder.addKey(key) fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start)) - buildVariant(builder, data.get(i, 
structFields(i).dataType), structFields(i).dataType) + val value = if (data.isNullAt(i)) null else data.get(i, structFields(i).dataType) + buildVariant(builder, value, structFields(i).dataType) } builder.finishWritingObject(start, fields) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index 3964bf3aedec..53be9d50d351 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,11 +16,13 @@ */ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateNamedStruct, JsonToStructs, Literal, StructsToJson} +import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, CreateNamedStruct, JsonToStructs, Literal, StructsToJson} import org.apache.spark.sql.catalyst.expressions.variant.ParseJson import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.VariantType +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.types.variant.VariantBuilder import org.apache.spark.unsafe.types.VariantVal @@ -250,4 +252,31 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { Seq.fill(3)(Row("STRUCT<a: ARRAY<STRING>>")) ++ Seq(Row("STRUCT<a: ARRAY<BIGINT>>"))) } } + + test("cast to variant with ColumnarArray input") { + val dataVector = new OnHeapColumnVector(4, LongType) + dataVector.appendNull() + dataVector.appendLong(123) + dataVector.appendNull() + dataVector.appendLong(456) + val array = new ColumnarArray(dataVector, 0, 4) + val variant = Cast(Literal(array, ArrayType(LongType)), VariantType).eval() + assert(variant.toString == "[null,123,null,456]") + dataVector.close() + } + + 
test("cast to variant with scan input") { + withTempPath { dir => + val path = dir.getAbsolutePath + val input = Seq(Row(Array(1, null), Map("k1" -> null, "k2" -> false), Row(null, "str"))) + val schema = StructType.fromDDL( + "a array<int>, m map<string, boolean>, s struct<f1 string, f2 string>") + spark.createDataFrame(spark.sparkContext.parallelize(input), schema).write.parquet(path) + val df = spark.read.parquet(path).selectExpr( + s"cast(cast(a as variant) as ${schema(0).dataType.sql})", + s"cast(cast(m as variant) as ${schema(1).dataType.sql})", + s"cast(cast(s as variant) as ${schema(2).dataType.sql})") + checkAnswer(df, input) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org