This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 134a13928965 [SPARK-47681][SQL] Add schema_of_variant expression 134a13928965 is described below commit 134a13928965e9818393511eadd504b9f1679766 Author: Chenhao Li <chenhao...@databricks.com> AuthorDate: Tue Apr 9 00:04:02 2024 +0800 [SPARK-47681][SQL] Add schema_of_variant expression ### What changes were proposed in this pull request? This PR adds a new `SchemaOfVariant` expression. It returns the schema of a variant value in SQL format. Usage examples: ``` > SELECT schema_of_variant(parse_json('null')); VOID > SELECT schema_of_variant(parse_json('[{"b":true,"a":0}]')); ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>> ``` ### Why are the changes needed? This expression helps users explore the content of variant values. ### Does this PR introduce _any_ user-facing change? Yes. A new SQL expression is added. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45806 from chenhao-db/variant_schema. 
Authored-by: Chenhao Li <chenhao...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/encoders/EncoderUtils.scala | 7 +- .../expressions/variant/variantExpressions.scala | 86 ++++++++++++++++++++++ .../spark/sql/catalyst/json/JsonInferSchema.scala | 18 +++-- .../sql-functions/sql-expression-schema.md | 1 + .../apache/spark/sql/VariantEndToEndSuite.scala | 31 ++++++++ 6 files changed, 134 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ecba8b263c41..bbc063c32103 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -822,6 +822,7 @@ object FunctionRegistry { expression[ParseJson]("parse_json"), expressionBuilder("variant_get", VariantGetExpressionBuilder), expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder), + expression[SchemaOfVariant]("schema_of_variant"), // cast expression[Cast]("cast"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala index 45598b6a66f2..20f86a32c1a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, C import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import 
org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, YearMonthIntervalType} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} /** * Helper class for Generating [[ExpressionEncoder]]s. @@ -122,7 +122,8 @@ object EncoderUtils { TimestampType -> classOf[PhysicalLongType.InternalType], TimestampNTZType -> classOf[PhysicalLongType.InternalType], BinaryType -> classOf[PhysicalBinaryType.InternalType], - CalendarIntervalType -> classOf[CalendarInterval] + CalendarIntervalType -> classOf[CalendarInterval], + VariantType -> classOf[VariantVal] ) val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 4681326136c7..2f2b5923fed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant import scala.util.parsing.combinator.RegexParsers +import org.apache.spark.SparkRuntimeException import 
org.apache.spark.sql.catalyst.analysis.ExpressionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -26,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.json.JsonInferSchema import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -403,3 +405,87 @@ object VariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(true) ) // scalastyle:on line.size.limit object TryVariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(false) + +@ExpressionDescription( + usage = "_FUNC_(v) - Returns schema in the SQL format of a variant.", + examples = """ + Examples: + > SELECT _FUNC_(parse_json('null')); + VOID + > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]')); + ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>> + """, + since = "4.0.0", + group = "variant_funcs" +) +case class SchemaOfVariant(child: Expression) + extends UnaryExpression + with RuntimeReplaceable + with ExpectsInputTypes { + override lazy val replacement: Expression = StaticInvoke( + SchemaOfVariant.getClass, + StringType, + "schemaOfVariant", + Seq(child), + inputTypes, + returnNullable = false) + + override def inputTypes: Seq[AbstractDataType] = Seq(VariantType) + + override def dataType: DataType = StringType + + override def prettyName: String = "schema_of_variant" + + override protected def withNewChildInternal(newChild: Expression): SchemaOfVariant = + copy(child = newChild) +} + +object SchemaOfVariant { + /** The actual implementation of the `SchemaOfVariant` expression. 
*/ + def schemaOfVariant(input: VariantVal): UTF8String = { + val v = new Variant(input.getValue, input.getMetadata) + UTF8String.fromString(schemaOf(v).sql) + } + + /** + * Return the schema of a variant. Struct fields are guaranteed to be sorted alphabetically. + */ + def schemaOf(v: Variant): DataType = v.getType match { + case Type.OBJECT => + val size = v.objectSize() + val fields = new Array[StructField](size) + for (i <- 0 until size) { + val field = v.getFieldAtIndex(i) + fields(i) = StructField(field.key, schemaOf(field.value)) + } + // According to the variant spec, object fields must be sorted alphabetically. So we don't + // have to sort, but just need to validate they are sorted. + for (i <- 1 until size) { + if (fields(i - 1).name >= fields(i).name) { + throw new SparkRuntimeException("MALFORMED_VARIANT", Map.empty) + } + } + StructType(fields) + case Type.ARRAY => + var elementType: DataType = NullType + for (i <- 0 until v.arraySize()) { + elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i))) + } + ArrayType(elementType) + case Type.NULL => NullType + case Type.BOOLEAN => BooleanType + case Type.LONG => LongType + case Type.STRING => StringType + case Type.DOUBLE => DoubleType + case Type.DECIMAL => + val d = v.getDecimal + DecimalType(d.precision(), d.scale()) + } + + /** + * Returns the tightest common type for two given data types. Input struct fields are assumed to + * be sorted alphabetically. 
+ */ + def mergeSchema(t1: DataType, t2: DataType): DataType = + JsonInferSchema.compatibleType(t1, t2, VariantType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 12c1be7c0de7..7ee522226e3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -360,8 +360,10 @@ object JsonInferSchema { /** * Returns the most general data type for two given data types. + * When the two types are incompatible, return `defaultDataType` as a fallback result. */ - def compatibleType(t1: DataType, t2: DataType): DataType = { + def compatibleType( + t1: DataType, t2: DataType, defaultDataType: DataType = StringType): DataType = { TypeCoercion.findTightestCommonType(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { @@ -399,7 +401,8 @@ object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + val dataType = compatibleType( + fields1(f1Idx).dataType, fields2(f2Idx).dataType, defaultDataType) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -422,21 +425,22 @@ object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + ArrayType( + compatibleType(elementType1, elementType2, defaultDataType), + containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonType`. 
Both cases below will be executed only when the given // `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2) + compatibleType(DecimalType.forType(t1), t2, defaultDataType) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2)) + compatibleType(t1, DecimalType.forType(t2), defaultDataType) case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) => TimestampType - // strings and every string is a Json object. - case (_, _) => StringType + case (_, _) => defaultDataType } } } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index c461e3ec09fb..05491034e6c7 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -437,6 +437,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> | | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> | +| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> | | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> | | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | 
SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> | | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index cf12001fa71b..d8b1dca21ca6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -81,4 +81,35 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { val expected = new VariantVal(v.getValue, v.getMetadata) checkAnswer(variantDF, Seq(Row(expected))) } + + test("schema_of_variant") { + def check(json: String, expected: String): Unit = { + val df = Seq(json).toDF("j").selectExpr("schema_of_variant(parse_json(j))") + checkAnswer(df, Seq(Row(expected))) + } + + check("null", "VOID") + check("1", "BIGINT") + check("1.0", "DECIMAL(1,0)") + check("1E0", "DOUBLE") + check("true", "BOOLEAN") + check("\"2000-01-01\"", "STRING") + check("""{"a":0}""", "STRUCT<a: BIGINT>") + check("""{"b": {"c": "c"}, "a":["a"]}""", "STRUCT<a: ARRAY<STRING>, b: STRUCT<c: STRING>>") + check("[]", "ARRAY<VOID>") + check("[false]", "ARRAY<BOOLEAN>") + check("[null, 1, 1.0]", "ARRAY<DECIMAL(20,0)>") + check("[null, 1, 1.1]", "ARRAY<DECIMAL(21,1)>") + check("[123456.789, 123.456789]", "ARRAY<DECIMAL(12,6)>") + check("[1, 11111111111111111111111111111111111111]", "ARRAY<DECIMAL(38,0)>") + check("[1.1, 11111111111111111111111111111111111111]", "ARRAY<DOUBLE>") + check("[1, \"1\"]", "ARRAY<VARIANT>") + check("[{}, true]", "ARRAY<VARIANT>") + check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: VOID, b: BIGINT, c: STRING>>") + check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: STRING, b: BIGINT>>") + 
check( + """[{"a": 1, "b": null}, {"b": true, "a": 1E0}]""", + "ARRAY<STRUCT<a: DOUBLE, b: BOOLEAN>>" + ) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org