This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4cb364e6f615 [SPARK-47680][SQL] Add variant_explode expression 4cb364e6f615 is described below commit 4cb364e6f615512811b3001597d0cf98a7a30b00 Author: Chenhao Li <chenhao...@databricks.com> AuthorDate: Wed Apr 10 22:47:43 2024 +0800 [SPARK-47680][SQL] Add variant_explode expression ### What changes were proposed in this pull request? This PR adds a new `VariantExplode` expression. It separates a variant object/array into multiple rows containing its fields/elements. Its result schema is `struct<pos int, key string, value variant>`. `pos` is the position of the field/element in its parent object/array, and `value` is the field/element value. `key` is the field name when exploding a variant object, or is NULL when exploding a variant array. It ignores any input that is not a variant array/object, including SQL NULL, variant null, and any other variant value (e.g. a scalar such as `1`); such inputs produce no rows. It is exposed as two SQL table-valued functions, `variant_explode` and `variant_explode_outer`. The only difference is that whenever `variant_explode` produces zero output rows for an input row, `variant_explode_outer` will produce one output row containing `{NULL, NULL, NULL}`. Usage examples: ``` > SELECT * FROM variant_explode(parse_json('["hello", "world"]')); 0 NULL "hello" 1 NULL "world" > SELECT * FROM variant_explode(parse_json('{"a": true, "b": 3.14}')); 0 a true 1 b 3.14 ``` ### Why are the changes needed? This expression allows users to process variant arrays and objects more conveniently. ### Does this PR introduce _any_ user-facing change? Yes. Two new SQL table-valued functions, `variant_explode` and `variant_explode_outer`, are added. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45805 from chenhao-db/variant_explode. 
Authored-by: Chenhao Li <chenhao...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/variant/variantExpressions.scala | 83 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/VariantSuite.scala | 26 +++++++ 3 files changed, 112 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 99ae3adde44f..9447ea63b51f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -1096,7 +1096,9 @@ object TableFunctionRegistry { generator[PosExplode]("posexplode"), generator[PosExplode]("posexplode_outer", outer = true), generator[Stack]("stack"), - generator[SQLKeywords]("sql_keywords") + generator[SQLKeywords]("sql_keywords"), + generator[VariantExplode]("variant_explode"), + generator[VariantExplode]("variant_explode_outer", outer = true) ) val builtin: SimpleTableFunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 7d1a3cf00d2b..c5e316dc6c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.variant import scala.util.parsing.combinator.RegexParsers import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -419,6 +420,88 @@ object VariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(true) // scalastyle:on line.size.limit object TryVariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(false) +// scalastyle:off line.size.limit line.contains.tab +@ExpressionDescription( + usage = "_FUNC_(expr) - It separates a variant object/array into multiple rows containing its fields/elements. Its result schema is `struct<pos int, key string, value variant>`. `pos` is the position of the field/element in its parent object/array, and `value` is the field/element value. `key` is the field name when exploding a variant object, or is NULL when exploding a variant array. It ignores any input that is not a variant array/object, including SQL NULL, variant null, and any ot [...] + examples = """ + Examples: + > SELECT * from _FUNC_(parse_json('["hello", "world"]')); + 0 NULL "hello" + 1 NULL "world" + > SELECT * from _FUNC_(parse_json('{"a": true, "b": 3.14}')); + 0 a true + 1 b 3.14 + """, + since = "4.0.0", + group = "variant_funcs") +// scalastyle:on line.size.limit line.contains.tab +case class VariantExplode(child: Expression) extends UnaryExpression with Generator + with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(VariantType) + + override def prettyName: String = "variant_explode" + + override protected def withNewChildInternal(newChild: Expression): VariantExplode = + copy(child = newChild) + + override def eval(input: InternalRow): IterableOnce[InternalRow] = { + val inputVariant = child.eval(input).asInstanceOf[VariantVal] + VariantExplode.variantExplode(inputVariant, inputVariant == null) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childCode = child.genCode(ctx) + val cls = classOf[VariantExplode].getName + val code = code""" + 
${childCode.code} + scala.collection.Seq<InternalRow> ${ev.value} = $cls.variantExplode( + ${childCode.value}, ${childCode.isNull}); + """ + ev.copy(code = code, isNull = FalseLiteral) + } + + override def elementSchema: StructType = { + new StructType() + .add("pos", IntegerType, nullable = false) + .add("key", StringType, nullable = true) + .add("value", VariantType, nullable = false) + } +} + +object VariantExplode { + /** + * The actual implementation of the `VariantExplode` expression. We check `isNull` separately + * rather than `input == null` because the documentation of `ExprCode` says that the value is not + * valid if `isNull` is set to `true`. + */ + def variantExplode(input: VariantVal, isNull: Boolean): scala.collection.Seq[InternalRow] = { + if (isNull) { + return Nil + } + val v = new Variant(input.getValue, input.getMetadata) + v.getType match { + case Type.OBJECT => + val size = v.objectSize() + val result = new Array[InternalRow](size) + for (i <- 0 until size) { + val field = v.getFieldAtIndex(i) + result(i) = InternalRow(i, UTF8String.fromString(field.key), + new VariantVal(field.value.getValue, field.value.getMetadata)) + } + result + case Type.ARRAY => + val size = v.arraySize() + val result = new Array[InternalRow](size) + for (i <- 0 until size) { + val elem = v.getElementAtIndex(i) + result(i) = InternalRow(i, null, new VariantVal(elem.getValue, elem.getMetadata)) + } + result + case _ => Nil + } + } +} + @ExpressionDescription( usage = "_FUNC_(v) - Returns schema in the SQL format of a variant.", examples = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala index 4f82dbc90dc5..d276ec4428b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.unsafe.types.VariantVal import org.apache.spark.util.ArrayImplicits._ 
class VariantSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + test("basic tests") { def verifyResult(df: DataFrame): Unit = { val result = df.collect() @@ -298,4 +300,28 @@ class VariantSuite extends QueryTest with SharedSparkSession { } assert(ex.getErrorClass == "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE") } + + test("variant_explode") { + def check(input: String, expected: Seq[Row]): Unit = { + withView("v") { + Seq(input).toDF("json").createOrReplaceTempView("v") + checkAnswer(sql("select pos, key, to_json(value) from v, " + + "lateral variant_explode(parse_json(json))"), expected) + val expectedOuter = if (expected.isEmpty) Seq(Row(null, null, null)) else expected + checkAnswer(sql("select pos, key, to_json(value) from v, " + + "lateral variant_explode_outer(parse_json(json))"), expectedOuter) + } + } + + Seq("true", "false").foreach { codegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled) { + check(null, Nil) + check("1", Nil) + check("null", Nil) + check("""{"a": [1, 2, 3], "b": true}""", Seq(Row(0, "a", "[1,2,3]"), Row(1, "b", "true"))) + check("""[null, "hello", {}]""", + Seq(Row(0, null, "null"), Row(1, null, "\"hello\""), Row(2, null, "{}"))) + } + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org