This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 128d0cd  [SPARK-33641][SQL] Invalidate new char/varchar types in public APIs that produce incorrect results
128d0cd is described below

commit 128d0cdbc579a101d6db6105dac168cc8b3623ff
Author: Kent Yao <yaooq...@hotmail.com>
AuthorDate: Mon Dec 7 13:40:15 2020 +0000

    [SPARK-33641][SQL] Invalidate new char/varchar types in public APIs that produce incorrect results

    ### What changes were proposed in this pull request?

    In this PR, we propose to narrow the use cases of the char/varchar data types; the usages below are either invalid today or will become invalid later.

    ### Why are the changes needed?

    1. udf

    ```scala
    scala> spark.udf.register("abcd", () => "12345", org.apache.spark.sql.types.VarcharType(2))

    scala> spark.sql("select abcd()").show
    scala.MatchError: CharType(2) (of class org.apache.spark.sql.types.VarcharType)
      at org.apache.spark.sql.catalyst.encoders.RowEncoder$.externalDataTypeFor(RowEncoder.scala:215)
      at org.apache.spark.sql.catalyst.encoders.RowEncoder$.externalDataTypeForInput(RowEncoder.scala:212)
      at org.apache.spark.sql.catalyst.expressions.objects.ValidateExternalType.<init>(objects.scala:1741)
      at org.apache.spark.sql.catalyst.encoders.RowEncoder$.$anonfun$serializerFor$3(RowEncoder.scala:175)
      at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:245)
      at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
      at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
      at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
      at scala.collection.TraversableLike.flatMap(TraversableLike.scala:245)
      at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:242)
      at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
      at org.apache.spark.sql.catalyst.encoders.RowEncoder$.serializerFor(RowEncoder.scala:171)
      at org.apache.spark.sql.catalyst.encoders.RowEncoder$.apply(RowEncoder.scala:66)
      at org.apache.spark.sql.Dataset$.$anonfun$ofRows$2(Dataset.scala:99)
      at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:768)
      at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:96)
      at org.apache.spark.sql.SparkSession.$anonfun$sql$1(SparkSession.scala:611)
      at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:768)
      at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:606)
      ... 47 elided
    ```

    2. spark.createDataFrame

    ```
    scala> spark.createDataFrame(spark.read.text("README.md").rdd, new org.apache.spark.sql.types.StructType().add("c", "char(1)")).show
    +--------------------+
    |                   c|
    +--------------------+
    |      # Apache Spark|
    |                    |
    |Spark is a unifie...|
    |high-level APIs i...|
    |supports general ...|
    |rich set of highe...|
    |MLlib for machine...|
    |and Structured St...|
    |                    |
    |<https://spark.ap...|
    |                    |
    |[![Jenkins Build]...|
    |[![AppVeyor Build...|
    |[![PySpark Covera...|
    |                    |
    |                    |
    ```

    3. reader.schema

    ```
    scala> spark.read.schema("a varchar(2)").text("./README.md").show(100)
    +--------------------+
    |                   a|
    +--------------------+
    |      # Apache Spark|
    |                    |
    |Spark is a unifie...|
    |high-level APIs i...|
    |supports general ...|
    ```

    4. etc.

    ### Does this PR introduce _any_ user-facing change?

    No, we intend to avoid a potential breaking change.

    ### How was this patch tested?

    New tests.
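    As a quick sketch of the resulting behavior (a hypothetical spark-shell session against
    this patch; the README.md path is illustrative, and the exception text comes from the new
    CharVarcharUtils.failIfHasCharVarchar check):

    ```scala
    // char/varchar in a reader-supplied schema now fails fast:
    spark.read.schema("a varchar(2)")
    // org.apache.spark.sql.AnalysisException: char/varchar type can only be used in the
    // table schema. You can set spark.sql.legacy.charVarcharAsString to true ...

    // Likewise for UDF result types, replacing the MatchError shown above:
    spark.udf.register("abcd", () => "12345", org.apache.spark.sql.types.VarcharType(2))

    // The legacy conf restores the Spark 3.0 behavior (treat char/varchar as string):
    spark.conf.set("spark.sql.legacy.charVarcharAsString", "true")
    spark.read.schema("a varchar(2)").text("./README.md").printSchema()
    // root
    //  |-- a: string (nullable = true)
    ```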
    Closes #30586 from yaooqinn/SPARK-33641.

    Authored-by: Kent Yao <yaooq...@hotmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit da72b87374a7be5416b99ed016dc2fc9da0ed88a)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/expressions/ExprUtils.scala |   6 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  19 +---
 .../spark/sql/catalyst/parser/ParseDriver.scala    |   5 -
 .../sql/catalyst/parser/ParserInterface.scala      |   6 --
 .../spark/sql/catalyst/util/CharVarcharUtils.scala |  38 ++++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  13 +++
 .../org/apache/spark/sql/types/VarcharType.scala   |   2 +-
 .../sql/catalyst/parser/DataTypeParserSuite.scala  |  14 +--
 .../catalyst/parser/TableSchemaParserSuite.scala   |   4 +-
 .../org/apache/spark/sql/types/DataTypeSuite.scala |  10 ++
 .../main/scala/org/apache/spark/sql/Column.scala   |   2 +-
 .../org/apache/spark/sql/DataFrameReader.scala     |   7 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  10 +-
 .../org/apache/spark/sql/UDFRegistration.scala     |  73 ++++++++-----
 .../sql/execution/datasources/jdbc/JdbcUtils.scala |   7 +-
 .../scala/org/apache/spark/sql/functions.scala     |  12 ++-
 .../apache/spark/sql/CharVarcharTestSuite.scala    | 114 ++++++++++++++-------
 .../spark/sql/SparkSessionExtensionSuite.scala     |   3 -
 .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala |   5 +-
 .../spark/sql/hive/client/HiveClientImpl.scala     |   2 +-
 20 files changed, 226 insertions(+), 126 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 56bd3d7..b45bbe4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -21,7 +21,7 @@ import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
 import java.util.Locale
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
 import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -30,7 +30,9 @@ object ExprUtils {
   def evalTypeExpr(exp: Expression): DataType = {
     if (exp.foldable) {
       exp.eval() match {
-        case s: UTF8String if s != null => DataType.fromDDL(s.toString)
+        case s: UTF8String if s != null =>
+          val dataType = DataType.fromDDL(s.toString)
+          CharVarcharUtils.failIfHasCharVarchar(dataType)
         case _ => throw new AnalysisException(
           s"The expression '${exp.sql}' is not a valid schema string.")
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 12c5e0d..a22383c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -95,19 +95,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
   }
 
   override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
-    visitSparkDataType(ctx.dataType)
+    typedVisit[DataType](ctx.dataType)
   }
 
   override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
-    val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
-      StructType(visitColTypeList(ctx.colTypeList)))
+    val schema = StructType(visitColTypeList(ctx.colTypeList))
     withOrigin(ctx)(schema)
   }
 
-  def parseRawDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
-    typedVisit[DataType](ctx.dataType())
-  }
-
   /* ********************************************************************************************
    * Plan parsing
    * ******************************************************************************************** */
@@ -1550,7 +1545,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
    * Create a [[Cast]] expression.
    */
   override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
-    Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType))
+    val rawDataType = typedVisit[DataType](ctx.dataType())
+    val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
+    Cast(expression(ctx.expression), dataType)
   }
 
   /**
@@ -2229,12 +2226,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
   /* ********************************************************************************************
    * DataType parsing
    * ******************************************************************************************** */
-  /**
-   * Create a Spark DataType.
-   */
-  private def visitSparkDataType(ctx: DataTypeContext): DataType = {
-    CharVarcharUtils.replaceCharVarcharWithString(typedVisit(ctx))
-  }
 
   /**
    * Resolve/create a primitive type.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index ac3fbbf..d08be46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -39,11 +39,6 @@ abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with
     astBuilder.visitSingleDataType(parser.singleDataType())
   }
 
-  /** Similar to `parseDataType`, but without CHAR/VARCHAR replacement. */
-  override def parseRawDataType(sqlText: String): DataType = parse(sqlText) { parser =>
-    astBuilder.parseRawDataType(parser.singleDataType())
-  }
-
   /** Creates Expression for a given SQL string. */
   override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
     astBuilder.visitSingleExpression(parser.singleExpression())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index d724933..77e357a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -70,10 +70,4 @@ trait ParserInterface {
    */
   @throws[ParseException]("Text cannot be parsed to a DataType")
   def parseDataType(sqlText: String): DataType
-
-  /**
-   * Parse a string to a raw [[DataType]] without CHAR/VARCHAR replacement.
-   */
-  @throws[ParseException]("Text cannot be parsed to a DataType")
-  def parseRawDataType(sqlText: String): DataType
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index 0cbe5ab..b551d96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.catalyst.util
 
 import scala.collection.mutable
 
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-object CharVarcharUtils {
+object CharVarcharUtils extends Logging {
 
   private val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING"
 
@@ -53,6 +56,19 @@ object CharVarcharUtils {
   }
 
   /**
+   * Validate the given [[DataType]] to fail if it is char or varchar types or contains nested ones
+   */
+  def failIfHasCharVarchar(dt: DataType): DataType = {
+    if (!SQLConf.get.charVarcharAsString && hasCharVarchar(dt)) {
+      throw new AnalysisException("char/varchar type can only be used in the table schema. " +
+        s"You can set ${SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key} to true, so that Spark" +
+        s" treat them as string type as same as Spark 3.0 and earlier")
+    } else {
+      replaceCharVarcharWithString(dt)
+    }
+  }
+
+  /**
    * Replaces CharType/VarcharType with StringType recursively in the given data type.
    */
   def replaceCharVarcharWithString(dt: DataType): DataType = dt match {
@@ -70,6 +86,24 @@ object CharVarcharUtils {
   }
 
   /**
+   * Replaces CharType/VarcharType with StringType recursively in the given data type, with a
+   * warning message if it has char or varchar types
+   */
+  def replaceCharVarcharWithStringForCast(dt: DataType): DataType = {
+    if (SQLConf.get.charVarcharAsString) {
+      replaceCharVarcharWithString(dt)
+    } else if (hasCharVarchar(dt)) {
+      logWarning("The Spark cast operator does not support char/varchar type and simply treats" +
+        " them as string type. Please use string type directly to avoid confusion. Otherwise," +
+        s" you can set ${SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key} to true, so that Spark treat" +
+        s" them as string type as same as Spark 3.0 and earlier")
+      replaceCharVarcharWithString(dt)
+    } else {
+      dt
+    }
+  }
+
+  /**
   * Removes the metadata entry that contains the original type string of CharType/VarcharType from
   * the given attribute's metadata.
   */
@@ -85,7 +119,7 @@ object CharVarcharUtils {
    */
   def getRawType(metadata: Metadata): Option[DataType] = {
     if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) {
-      Some(CatalystSqlParser.parseRawDataType(
+      Some(CatalystSqlParser.parseDataType(
         metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
     } else {
       None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0254782..ea301ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2954,6 +2954,17 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val LEGACY_CHAR_VARCHAR_AS_STRING =
+    buildConf("spark.sql.legacy.charVarcharAsString")
+      .internal()
+      .doc("When true, Spark will not fail if user uses char and varchar type directly in those" +
+        " APIs that accept or parse data types as parameters, e.g." +
+        " `SparkSession.read.schema(...)`, `SparkSession.udf.register(...)` but treat them as" +
+        " string type as Spark 3.0 and earlier.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -3602,6 +3613,8 @@ class SQLConf extends Serializable with Logging {
 
   def disabledJdbcConnectionProviders: String = getConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST)
 
+  def charVarcharAsString: Boolean = getConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
index 8d78640..2e30820 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala
@@ -32,6 +32,6 @@ case class VarcharType(length: Int) extends AtomicType {
 
   override def defaultSize: Int = length
   override def typeName: String = s"varchar($length)"
-  override def toString: String = s"CharType($length)"
+  override def toString: String = s"VarcharType($length)"
   private[spark] override def asNullable: VarcharType = this
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 655b1d2..b9f9840 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -56,10 +56,10 @@ class DataTypeParserSuite extends SparkFunSuite {
   checkDataType("DATE", DateType)
   checkDataType("timestamp", TimestampType)
   checkDataType("string", StringType)
-  checkDataType("ChaR(5)", StringType)
-  checkDataType("ChaRacter(5)", StringType)
-  checkDataType("varchAr(20)", StringType)
-  checkDataType("cHaR(27)", StringType)
+  checkDataType("ChaR(5)", CharType(5))
+  checkDataType("ChaRacter(5)", CharType(5))
+  checkDataType("varchAr(20)", VarcharType(20))
+  checkDataType("cHaR(27)", CharType(27))
   checkDataType("BINARY", BinaryType)
   checkDataType("void", NullType)
   checkDataType("interval", CalendarIntervalType)
@@ -103,9 +103,9 @@ class DataTypeParserSuite extends SparkFunSuite {
       StructType(
         StructField("deciMal", DecimalType.USER_DEFAULT, true) ::
         StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) ::
-      StructField("MAP", MapType(TimestampType, StringType), true) ::
+      StructField("MAP", MapType(TimestampType, VarcharType(10)), true) ::
       StructField("arrAy", ArrayType(DoubleType, true), true) ::
-      StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
+      StructField("anotherArray", ArrayType(CharType(9), true), true) :: Nil)
   )
   // Use backticks to quote column names having special characters.
   checkDataType(
@@ -113,7 +113,7 @@ class DataTypeParserSuite extends SparkFunSuite {
     StructType(
       StructField("x+y", IntegerType, true) ::
       StructField("!@#$%^&*()", StringType, true) ::
-      StructField("1_2.345<>:\"", StringType, true) :: Nil)
+      StructField("1_2.345<>:\"", VarcharType(20), true) :: Nil)
   )
   // Empty struct.
   checkDataType("strUCt<>", StructType(Nil))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
index 95851d4..5519f01 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.types._
 
 class TableSchemaParserSuite extends SparkFunSuite {
@@ -69,8 +68,7 @@ class TableSchemaParserSuite extends SparkFunSuite {
         StructField("arrAy", ArrayType(DoubleType)) ::
         StructField("anotherArray", ArrayType(CharType(9))) :: Nil)) :: Nil)
 
-    assert(parse(tableSchemaString) ===
-      CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedDataType))
+    assert(parse(tableSchemaString) === expectedDataType)
   }
 
   // Negative cases
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 9442a3e..8c2e5db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -249,6 +249,12 @@ class DataTypeSuite extends SparkFunSuite {
   checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false))
   checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false))
 
+  checkDataTypeFromJson(CharType(1))
+  checkDataTypeFromDDL(CharType(1))
+
+  checkDataTypeFromJson(VarcharType(10))
+  checkDataTypeFromDDL(VarcharType(11))
+
   val metadata = new MetadataBuilder()
     .putString("name", "age")
     .build()
@@ -310,6 +316,10 @@ class DataTypeSuite extends SparkFunSuite {
   checkDefaultSize(MapType(IntegerType, StringType, true), 24)
   checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 12)
   checkDefaultSize(structType, 20)
+  checkDefaultSize(CharType(5), 5)
+  checkDefaultSize(CharType(100), 100)
+  checkDefaultSize(VarcharType(5), 5)
+  checkDefaultSize(VarcharType(10), 10)
 
   def checkEqualsIgnoreCompatibleNullability(
       from: DataType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 86ba813..4ef23d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1185,7 +1185,7 @@ class Column(val expr: Expression) extends Logging {
    * @since 1.3.0
    */
   def cast(to: DataType): Column = withExpr {
-    Cast(expr, CharVarcharUtils.replaceCharVarcharWithString(to))
+    Cast(expr, CharVarcharUtils.replaceCharVarcharWithStringForCast(to))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 007df18..b94c42a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -73,7 +73,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def schema(schema: StructType): DataFrameReader = {
-    this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema))
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+    this.userSpecifiedSchema = Option(replaced)
     this
   }
 
@@ -89,7 +90,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 2.3.0
    */
   def schema(schemaString: String): DataFrameReader = {
-    this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString))
+    val rawSchema = StructType.fromDDL(schemaString)
+    val schema = CharVarcharUtils.failIfHasCharVarchar(rawSchema).asInstanceOf[StructType]
+    this.userSpecifiedSchema = Option(schema)
     this
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index db5ad52..83fb744 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.ExternalCommandRunner
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.ExternalCommandExecutor
@@ -347,9 +348,10 @@ class SparkSession private(
    */
   @DeveloperApi
   def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(replaced)
     val toRow = encoder.createSerializer()
     val catalystRows = rowRDD.map(toRow)
     internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema)
@@ -365,7 +367,8 @@ class SparkSession private(
    */
   @DeveloperApi
   def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
-    createDataFrame(rowRDD.rdd, schema)
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+    createDataFrame(rowRDD.rdd, replaced)
   }
 
 /**
@@ -378,7 +381,8 @@ class SparkSession private(
    */
   @DeveloperApi
   def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive {
-    Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala.toSeq))
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+    Dataset.ofRows(self, LocalRelation.fromExternalRows(replaced.toAttributes, rows.asScala.toSeq))
   }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index cceb385..237cfe1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
@@ -162,9 +163,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
          | * @since $version
          | */
          |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
+         |  val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
          |  val func = $funcCall
          |  def builder(e: Seq[Expression]) = if (e.length == $i) {
-         |    ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+         |    ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
          |  } else {
          |    throw new AnalysisException("Invalid number of arguments for function " + name +
          |      ". Expected: $i; Found: " + e.length)
@@ -753,9 +755,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 2.3.0
    */
   def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = () => f.asInstanceOf[UDF0[Any]].call()
     def builder(e: Seq[Expression]) = if (e.length == 0) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 0; Found: " + e.length)
@@ -768,9 +771,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
    */
   def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
     def builder(e: Seq[Expression]) = if (e.length == 1) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 1; Found: " + e.length)
@@ -783,9 +787,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
    */
   def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 2) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 2; Found: " + e.length)
@@ -798,9 +803,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
    */
   def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 3) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 3; Found: " + e.length)
@@ -813,9 +819,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
    */
   def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 4) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 4; Found: " + e.length)
@@ -828,9 +835,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
    */
   def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 5) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 5; Found: " + e.length)
@@ -843,9 +851,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
    */
   def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 6) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 6; Found: " + e.length)
@@ -858,9 +867,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
    */
   def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 7) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 7; Found: " + e.length)
@@ -873,9 +883,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 8) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 8; Found: " + e.length)
@@ -888,9 +899,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 9) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 9; Found: " + e.length)
@@ -903,9 +915,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 10) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 10; Found: " + e.length)
@@ -918,9 +931,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 11) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 11; Found: " + e.length)
@@ -933,9 +947,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 12) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 12; Found: " + e.length)
@@ -948,9 +963,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 13) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 13; Found: " + e.length)
@@ -963,9 +979,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 14) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 14; Found: " + e.length)
@@ -978,9 +995,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 15) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 15; Found: " + e.length)
@@ -993,9 +1011,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 16) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 16; Found: " + e.length)
@@ -1008,9 +1027,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 17) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 17; Found: " + e.length)
@@ -1023,9 +1043,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 18) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 18; Found: " + e.length)
@@ -1038,9 +1059,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 19) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 19; Found: " + e.length)
@@ -1053,9 +1075,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 20) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 20; Found: " + e.length)
@@ -1068,9 +1091,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 21) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 21; Found: " + e.length)
@@ -1083,9 +1107,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 1.3.0
   */
   def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(returnType)
     val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
     def builder(e: Seq[Expression]) = if (e.length == 22) {
-      ScalaUDF(func, returnType, e, Nil, udfName = Some(name))
+      ScalaUDF(func, replaced, e, Nil, udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 22; Found: " + e.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 216fb02..f997e57b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
@@ -761,10 +761,7 @@ object JdbcUtils extends Logging {
       schema: StructType,
       caseSensitive: Boolean,
       createTableColumnTypes: String): Map[String, String] = {
-    val parsedSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
-    val userSchema = StructType(parsedSchema.map { field =>
-      field.copy(dataType = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType))
-    })
+    val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
     val nameEquality = if (caseSensitive) {
       org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9861d21..5b1ee2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
-import org.apache.spark.sql.catalyst.util.TimestampFormatter
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TimestampFormatter}
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.internal.SQLConf
@@ -4009,7 +4009,7 @@ object functions {
    * @since 2.2.0
   */
  def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
-    JsonToStructs(schema, options, e.expr)
+    JsonToStructs(CharVarcharUtils.failIfHasCharVarchar(schema), options, e.expr)
   }
 
   /**
@@ -4040,8 +4040,9 @@ object functions {
    * @group collection_funcs
    * @since 2.2.0
   */
-  def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column =
-    from_json(e, schema, options.asScala.toMap)
+  def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column = {
+    from_json(e, CharVarcharUtils.failIfHasCharVarchar(schema), options.asScala.toMap)
+  }
 
   /**
    * Parses a column containing a JSON string into a `StructType` with the specified schema.
@@ -4393,7 +4394,8 @@ object functions {
    * @since 3.0.0
   */
   def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
-    CsvToStructs(schema, options, e.expr)
+    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+    CsvToStructs(replaced, options, e.expr)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index abb1327..fcd334b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.SimpleInsertSource
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
-import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 
 // The base trait for char/varchar tests that need to be run with different table implementations.
 trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
@@ -435,55 +435,91 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession {
       assert(df.schema.map(_.dataType) == Seq(StringType))
     }
 
-    assertNoCharType(spark.range(1).select($"id".cast("char(5)")))
-    assertNoCharType(spark.range(1).select($"id".cast(CharType(5))))
-    assertNoCharType(spark.range(1).selectExpr("CAST(id AS CHAR(5))"))
-    assertNoCharType(sql("SELECT CAST(id AS CHAR(5)) FROM range(1)"))
+    val logAppender = new LogAppender("The Spark cast operator does not support char/varchar" +
+      " type and simply treats them as string type. Please use string type directly to avoid" +
+      " confusion.")
+    withLogAppender(logAppender) {
+      assertNoCharType(spark.range(1).select($"id".cast("char(5)")))
+      assertNoCharType(spark.range(1).select($"id".cast(CharType(5))))
+      assertNoCharType(spark.range(1).selectExpr("CAST(id AS CHAR(5))"))
+      assertNoCharType(sql("SELECT CAST(id AS CHAR(5)) FROM range(1)"))
+    }
   }
 
-  test("user-specified schema in functions") {
-    val df = sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')""")
-    checkAnswer(df, Row(Row("str")))
-    val schema = df.schema.head.dataType.asInstanceOf[StructType]
-    assert(schema.map(_.dataType) == Seq(StringType))
+  def failWithInvalidCharUsage[T](fn: => T): Unit = {
+    val e = intercept[AnalysisException](fn)
+    assert(e.getMessage contains "char/varchar type can only be used in the table schema")
   }
 
-  test("user-specified schema in DataFrameReader: file source from Dataset") {
-    val ds = spark.range(10).map(_.toString)
-    val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
-    assert(df1.schema.map(_.dataType) == Seq(StringType))
-    val df2 = spark.read.schema("id char(5)").csv(ds)
-    assert(df2.schema.map(_.dataType) == Seq(StringType))
+  test("invalidate char/varchar in functions") {
+    failWithInvalidCharUsage(sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')"""))
+    withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+      val df = sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')""")
+      checkAnswer(df, Row(Row("str")))
+      val schema = df.schema.head.dataType.asInstanceOf[StructType]
+      assert(schema.map(_.dataType) == Seq(StringType))
+    }
   }
 
-  test("user-specified schema in DataFrameReader: DSV1") {
-    def checkSchema(df: DataFrame): Unit = {
-      val relations = df.queryExecution.analyzed.collect {
-        case l: LogicalRelation => l.relation
-      }
-      assert(relations.length == 1)
-      assert(relations.head.schema.map(_.dataType) == Seq(StringType))
+  test("invalidate char/varchar in SparkSession createDataframe") {
+    val df = spark.range(10).map(_.toString).toDF()
+    val schema = new StructType().add("id", CharType(5))
+    failWithInvalidCharUsage(spark.createDataFrame(df.collectAsList(), schema))
+    failWithInvalidCharUsage(spark.createDataFrame(df.rdd, schema))
+    failWithInvalidCharUsage(spark.createDataFrame(df.toJavaRDD, schema))
+    withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+      val df1 = spark.createDataFrame(df.collectAsList(), schema)
+      checkAnswer(df1, df)
+      assert(df1.schema.head.dataType === StringType)
     }
-
-    checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
-      .format(classOf[SimpleInsertSource].getName).load())
-    checkSchema(spark.read.schema("id char(5)")
-      .format(classOf[SimpleInsertSource].getName).load())
   }
 
-  test("user-specified schema in DataFrameReader: DSV2") {
-    def checkSchema(df: DataFrame): Unit = {
-      val tables = df.queryExecution.analyzed.collect {
-        case d: DataSourceV2Relation => d.table
+  test("invalidate char/varchar in spark.read.schema") {
+    failWithInvalidCharUsage(spark.read.schema(new StructType().add("id", CharType(5))))
+    failWithInvalidCharUsage(spark.read.schema("id char(5)"))
+    withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+      val ds = spark.range(10).map(_.toString)
+      val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
+      assert(df1.schema.map(_.dataType) == Seq(StringType))
+      val df2 = spark.read.schema("id char(5)").csv(ds)
+      assert(df2.schema.map(_.dataType) == Seq(StringType))
+
+      def checkSchema(df: DataFrame): Unit = {
+        val schemas = df.queryExecution.analyzed.collect {
+          case l: LogicalRelation => l.relation.schema
+          case d: DataSourceV2Relation => d.table.schema()
+        }
+        assert(schemas.length == 1)
+        assert(schemas.head.map(_.dataType) == Seq(StringType))
       }
-      assert(tables.length == 1)
-      assert(tables.head.schema.map(_.dataType) == Seq(StringType))
-    }
 
-    checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
-      .format(classOf[SchemaRequiredDataSource].getName).load())
-    checkSchema(spark.read.schema("id char(5)")
-      .format(classOf[SchemaRequiredDataSource].getName).load())
+      // user-specified schema in DataFrameReader: DSV1
+      checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
+        .format(classOf[SimpleInsertSource].getName).load())
+      checkSchema(spark.read.schema("id char(5)")
+        .format(classOf[SimpleInsertSource].getName).load())
+
+      // user-specified schema in DataFrameReader: DSV2
+      checkSchema(spark.read.schema(new StructType().add("id", CharType(5)))
+        .format(classOf[SchemaRequiredDataSource].getName).load())
+      checkSchema(spark.read.schema("id char(5)")
+        .format(classOf[SchemaRequiredDataSource].getName).load())
+    }
+  }
+
+  test("invalidate char/varchar in udf's result type") {
+    failWithInvalidCharUsage(spark.udf.register("testchar", () => "B", VarcharType(1)))
+    failWithInvalidCharUsage(spark.udf.register("testchar2", (x: String) => x, VarcharType(1)))
+    withSQLConf((SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key, "true")) {
+      spark.udf.register("testchar", () => "B", VarcharType(1))
+      spark.udf.register("testchar2", (x: String) => x, VarcharType(1))
+      val df1 = spark.sql("select testchar()")
+      checkAnswer(df1, Row("B"))
+      assert(df1.schema.head.dataType === StringType)
+      val df2 = spark.sql("select testchar2('abc')")
+      checkAnswer(df2, Row("abc"))
+      assert(df2.schema.head.dataType === StringType)
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index f02d204..ea276bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -384,9 +384,6 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
 
   override def parseDataType(sqlText: String): DataType =
     delegate.parseDataType(sqlText)
-
-  override def parseRawDataType(sqlText: String): DataType =
-    delegate.parseRawDataType(sqlText)
 }
 
 object MyExtensions {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index fb46c2f..1a28523 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -390,14 +390,13 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
       .foldLeft(new StructType())((schema, colType) => schema.add(colType._1, colType._2))
     val createTableColTypes =
       colTypes.map { case (col, dataType) => s"$col $dataType" }.mkString(", ")
-    val df = spark.createDataFrame(sparkContext.parallelize(Seq(Row.empty)), schema)
 
     val expectedSchemaStr =
       colTypes.map { case (col, dataType) => s""""$col" $dataType """ }.mkString(", ")
 
     assert(JdbcUtils.schemaString(
-      df.schema,
-      df.sqlContext.conf.caseSensitiveAnalysis,
+      schema,
+      spark.sqlContext.conf.caseSensitiveAnalysis,
       url1,
       Option(createTableColTypes)) == expectedSchemaStr)
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index bada131..34befb8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -985,7 +985,7 @@ private[hive] object HiveClientImpl extends Logging {
   /** Get the Spark SQL native DataType from Hive's FieldSchema. */
   private def getSparkSQLDataType(hc: FieldSchema): DataType = {
     try {
-      CatalystSqlParser.parseRawDataType(hc.getType)
+      CatalystSqlParser.parseDataType(hc.getType)
     } catch {
       case e: ParseException =>
         throw new SparkException(
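A closing note for reviewers: the one public API that intentionally keeps accepting char/varchar is casting. `Column.cast` and SQL `CAST` now route through `replaceCharVarcharWithStringForCast`, which falls back to string and logs a warning instead of failing. A minimal sketch of what that implies (a hypothetical session, mirroring the `assertNoCharType` test above):

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{CharType, StringType}

// Allowed, but treated as a plain cast to string, plus a log warning:
val df = spark.range(1).select(col("id").cast(CharType(5)))
assert(df.schema.head.dataType == StringType)
// log: The Spark cast operator does not support char/varchar type and simply treats
//      them as string type. Please use string type directly to avoid confusion. ...
```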