This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 98ae33b7acb [SPARK-43142] Fix DSL expressions on attributes with special characters 98ae33b7acb is described below commit 98ae33b7acbd932714301c83d71d42ef318dda9b Author: Willi Raschkowski <wraschkow...@palantir.com> AuthorDate: Tue Apr 25 18:23:36 2023 +0800 [SPARK-43142] Fix DSL expressions on attributes with special characters Re-attempting #40794. #40794 tried to more safely create `AttributeReference` objects from multi-part attributes in `ImplicitAttribute`. But that broke things and we had to revert. This PR is limiting the fix to the `UnresolvedAttribute` object returned by `DslAttr.attr`, which is enough to fix the issue here. ### What changes were proposed in this pull request? This PR fixes DSL expressions on attributes with special characters by making `DslAttr.attr` and `DslAttr.expr` return the implicitly wrapped attribute instead of creating a new one. ### Why are the changes needed? SPARK-43142: DSL expressions on attributes with special characters don't work even if the attribute names are quoted: ```scala scala> "`slashed/col`".attr res0: org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute = 'slashed/col scala> "`slashed/col`".attr.asc org.apache.spark.sql.catalyst.parser.ParseException: mismatched input '/' expecting {<EOF>, '.', '-'}(line 1, pos 7) == SQL == slashed/col -------^^^ ``` DSL expressions rely on a call to `expr` to get the child of the new expression [(e.g.)](https://github.com/apache/spark/blob/87a5442f7ed96b11051d8a9333476d080054e5a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala#L149). `expr` here is a call on the implicit class `DslAttr` that wraps the `UnresolvedAttribute` returned by `"...".attr`. 
`DslAttr` and its super class implement `DslAttr.expr` such that a new `UnresolvedAttribute` is created from `UnresolvedAttribute.name` of the wrapped attribute [(here)](https://github.com/apache/spark/blob/87a5442f7ed96b11051d8a9333476d080054e5a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala#L273-L280). But `UnresolvedAttribute.name` drops the quotes and thus the newly created `UnresolvedAttribute` parses an identifier that should be quoted but isn't: ```scala scala> "`col/slash`".attr.name res5: String = col/slash ``` ### Does this PR introduce _any_ user-facing change? DSL expressions on attributes with special characters no longer fail. ### How was this patch tested? I couldn't find a suite testing the implicit classes in the DSL package, but the DSL package seems used widely enough that I'm confident this doesn't break existing behavior. Locally, I was able to reproduce with this test; it was failing before and passes now: ```scala test("chained DSL expressions on attributes with special characters") { $"`slashed/col`".asc } ``` Closes #40902 from rshkv/wr/spark-43142-v2. 
Authored-by: Willi Raschkowski <wraschkow...@palantir.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../apache/spark/sql/catalyst/dsl/package.scala | 73 ++++++++++------------ .../expressions/ExpressionSQLBuilderSuite.scala | 2 +- .../datasources/DataSourceStrategySuite.scala | 12 ++-- .../datasources/v2/DataSourceV2StrategySuite.scala | 12 ++-- 4 files changed, 46 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index ac439203cb7..27d05f3bac7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -271,8 +271,8 @@ package object dsl { override def expr: Expression = Literal(s) def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) } - implicit class DslAttr(attr: UnresolvedAttribute) extends ImplicitAttribute { - def s: String = attr.name + implicit class DslAttr(override val attr: UnresolvedAttribute) extends ImplicitAttribute { + def s: String = attr.sql } abstract class ImplicitAttribute extends ImplicitOperators { @@ -280,90 +280,83 @@ package object dsl { def expr: UnresolvedAttribute = attr def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) + private def attrRef(dataType: DataType): AttributeReference = + AttributeReference(attr.nameParts.last, dataType)(qualifier = attr.nameParts.init) + /** Creates a new AttributeReference of type boolean */ - def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)() + def boolean: AttributeReference = attrRef(BooleanType) /** Creates a new AttributeReference of type byte */ - def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)() + def byte: AttributeReference = attrRef(ByteType) /** Creates a new AttributeReference of type short */ - def short: AttributeReference = 
AttributeReference(s, ShortType, nullable = true)() + def short: AttributeReference = attrRef(ShortType) /** Creates a new AttributeReference of type int */ - def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)() + def int: AttributeReference = attrRef(IntegerType) /** Creates a new AttributeReference of type long */ - def long: AttributeReference = AttributeReference(s, LongType, nullable = true)() + def long: AttributeReference = attrRef(LongType) /** Creates a new AttributeReference of type float */ - def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)() + def float: AttributeReference = attrRef(FloatType) /** Creates a new AttributeReference of type double */ - def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)() + def double: AttributeReference = attrRef(DoubleType) /** Creates a new AttributeReference of type string */ - def string: AttributeReference = AttributeReference(s, StringType, nullable = true)() + def string: AttributeReference = attrRef(StringType) /** Creates a new AttributeReference of type date */ - def date: AttributeReference = AttributeReference(s, DateType, nullable = true)() + def date: AttributeReference = attrRef(DateType) /** Creates a new AttributeReference of type decimal */ - def decimal: AttributeReference = - AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)() + def decimal: AttributeReference = attrRef(DecimalType.SYSTEM_DEFAULT) /** Creates a new AttributeReference of type decimal */ def decimal(precision: Int, scale: Int): AttributeReference = - AttributeReference(s, DecimalType(precision, scale), nullable = true)() + attrRef(DecimalType(precision, scale)) /** Creates a new AttributeReference of type timestamp */ - def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() + def timestamp: AttributeReference = attrRef(TimestampType) /** Creates a new AttributeReference of type timestamp 
without time zone */ - def timestampNTZ: AttributeReference = - AttributeReference(s, TimestampNTZType, nullable = true)() + def timestampNTZ: AttributeReference = attrRef(TimestampNTZType) /** Creates a new AttributeReference of the day-time interval type */ - def dayTimeInterval(startField: Byte, endField: Byte): AttributeReference = { - AttributeReference(s, DayTimeIntervalType(startField, endField), nullable = true)() - } - def dayTimeInterval(): AttributeReference = { - AttributeReference(s, DayTimeIntervalType(), nullable = true)() - } + def dayTimeInterval(startField: Byte, endField: Byte): AttributeReference = + attrRef(DayTimeIntervalType(startField, endField)) + + def dayTimeInterval(): AttributeReference = attrRef(DayTimeIntervalType()) /** Creates a new AttributeReference of the year-month interval type */ - def yearMonthInterval(startField: Byte, endField: Byte): AttributeReference = { - AttributeReference(s, YearMonthIntervalType(startField, endField), nullable = true)() - } - def yearMonthInterval(): AttributeReference = { - AttributeReference(s, YearMonthIntervalType(), nullable = true)() - } + def yearMonthInterval(startField: Byte, endField: Byte): AttributeReference = + attrRef(YearMonthIntervalType(startField, endField)) + + def yearMonthInterval(): AttributeReference = attrRef(YearMonthIntervalType()) /** Creates a new AttributeReference of type binary */ - def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)() + def binary: AttributeReference = attrRef(BinaryType) /** Creates a new AttributeReference of type array */ - def array(dataType: DataType): AttributeReference = - AttributeReference(s, ArrayType(dataType), nullable = true)() + def array(dataType: DataType): AttributeReference = attrRef(ArrayType(dataType)) - def array(arrayType: ArrayType): AttributeReference = - AttributeReference(s, arrayType)() + def array(arrayType: ArrayType): AttributeReference = attrRef(arrayType) /** Creates a new 
AttributeReference of type map */ def map(keyType: DataType, valueType: DataType): AttributeReference = map(MapType(keyType, valueType)) - def map(mapType: MapType): AttributeReference = - AttributeReference(s, mapType, nullable = true)() + def map(mapType: MapType): AttributeReference = attrRef(mapType) /** Creates a new AttributeReference of type struct */ - def struct(structType: StructType): AttributeReference = - AttributeReference(s, structType, nullable = true)() + def struct(structType: StructType): AttributeReference = attrRef(structType) + def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) /** Creates a new AttributeReference of object type */ - def obj(cls: Class[_]): AttributeReference = - AttributeReference(s, ObjectType(cls), nullable = true)() + def obj(cls: Class[_]): AttributeReference = attrRef(ObjectType(cls)) /** Create a function. */ def function(exprs: Expression*): UnresolvedFunction = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala index d450aecb732..e88b0e32e90 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala @@ -95,7 +95,7 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite { test("attributes") { checkSQL($"a".int, "a") - checkSQL(Symbol("foo bar").int, "`foo bar`") + checkSQL(Symbol("`foo bar`").int, "`foo bar`") // Keyword checkSQL($"int".int, "int") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index cf8aea45583..a35fb5f6271 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -27,18 +27,18 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class DataSourceStrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( $"cint".int, - $"c.int".int, + $"`c.int`".int, GetStructField($"a".struct(StructType( StructField("cstr", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None), GetStructField($"a".struct(StructType( StructField("c.int", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 0, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("cstr1", StringType, nullable = true) :: StructField("cstr2", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None), GetStructField(GetStructField($"a".struct(StructType( StructField("cstr1", StringType, nullable = true) :: @@ -56,18 +56,18 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { val attrStrs = Seq( $"cstr".string, - $"c.str".string, + $"`c.str`".string, GetStructField($"a".struct(StructType( StructField("cint", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 1, None), GetStructField($"a".struct(StructType( StructField("c.str", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("cint1", IntegerType, nullable = true) :: StructField("cint2", IntegerType, nullable = true) :: 
StructField("cstr", StringType, nullable = true) :: Nil)), 2, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("c.str", StringType, nullable = true) :: Nil)), 0, None), GetStructField(GetStructField($"a".struct(StructType( StructField("cint1", IntegerType, nullable = true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala index 8d6ffa30a72..3c4f5814375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -29,18 +29,18 @@ import org.apache.spark.unsafe.types.UTF8String class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( $"cint".int, - $"c.int".int, + $"`c.int`".int, GetStructField($"a".struct(StructType( StructField("cstr", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None), GetStructField($"a".struct(StructType( StructField("c.int", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 0, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("cstr1", StringType, nullable = true) :: StructField("cstr2", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None), GetStructField(GetStructField($"a".struct(StructType( StructField("cstr1", StringType, nullable = true) :: @@ -58,18 +58,18 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { val attrStrs = Seq( 
$"cstr".string, - $"c.str".string, + $"`c.str`".string, GetStructField($"a".struct(StructType( StructField("cint", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 1, None), GetStructField($"a".struct(StructType( StructField("c.str", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("cint1", IntegerType, nullable = true) :: StructField("cint2", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 2, None), - GetStructField($"a.b".struct(StructType( + GetStructField($"`a.b`".struct(StructType( StructField("c.str", StringType, nullable = true) :: Nil)), 0, None), GetStructField(GetStructField($"a".struct(StructType( StructField("cint1", IntegerType, nullable = true) :: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org