This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new e0c8f14ce53 [SPARK-44531][CONNECT][SQL] Move encoder inference to sql/api
e0c8f14ce53 is described below

commit e0c8f14ce53080e2863c076b7912239bee35003e
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Wed Jul 26 07:15:27 2023 -0400

    [SPARK-44531][CONNECT][SQL] Move encoder inference to sql/api

    ### What changes were proposed in this pull request?
    This PR moves encoder inference (ScalaReflection/RowEncoder/JavaTypeInference) into sql/api.

    ### Why are the changes needed?
    We want to use encoder inference in the Spark Connect Scala client. The client's dependency on catalyst is going away, so we need to move this code.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing tests.

    Closes #42134 from hvanhovell/SPARK-44531.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 071feabbd4325504332679dfa620bc5ee4359370)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
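[Editor's note] The patch is easiest to read as a call-site migration: the RowEncoder(schema) factories move off the RowEncoder object (which lands in sql/api and keeps only the inference half, encoderFor, returning an AgnosticEncoder[Row] with no catalyst dependency) and onto ExpressionEncoder in catalyst. A minimal sketch of the before/after usage, with a hypothetical two-column schema chosen purely for illustration; the names `schema`, `toRow`, and `rowEncoder` below are illustrative, not from the patch:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    // Hypothetical schema, for illustration only.
    val schema = new StructType().add("id", IntegerType).add("name", StringType)

    // Before this patch, call sites built the Row serializer via the catalyst factory:
    //   val toRow = RowEncoder(schema).createSerializer()
    // After this patch, the same factory lives on ExpressionEncoder:
    val toRow = ExpressionEncoder(schema).createSerializer()
    val internalRow = toRow(Row(1, "a"))

    // The RowEncoder object that moves to sql/api now only performs encoder
    // inference, returning a catalyst-free AgnosticEncoder[Row]:
    val rowEncoder = RowEncoder.encoderFor(schema)

This is the mechanical rewrite applied across the mllib, catalyst, and core call sites in the diff below.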
---
 .../spark/ml/source/image/ImageFileFormat.scala    |  4 +-
 .../spark/ml/source/libsvm/LibSVMRelation.scala    |  4 +-
 project/MimaExcludes.scala                         |  3 +
 sql/api/pom.xml                                    |  4 ++
 .../java/org/apache/spark/sql/types/DataTypes.java |  0
 .../apache/spark/sql/types/SQLUserDefinedType.java |  0
 .../scala/org/apache/spark/sql/SqlApiConf.scala    |  2 +
 .../spark/sql/catalyst/JavaTypeInference.scala     |  9 ++-
 .../spark/sql/catalyst/ScalaReflection.scala       | 13 ++--
 .../apache/spark/sql/catalyst/WalkedTypePath.scala |  0
 .../spark/sql/catalyst/encoders/RowEncoder.scala   | 19 ++----
 .../spark/sql/errors/DataTypeErrorsBase.scala      |  8 ++-
 .../apache/spark/sql/errors/EncoderErrors.scala    | 74 ++++++++++++++++++++++
 sql/catalyst/pom.xml                               |  5 --
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |  8 ++-
 .../spark/sql/catalyst/plans/logical/object.scala  |  8 +--
 .../spark/sql/errors/QueryExecutionErrors.scala    | 60 +-----------------
 .../org/apache/spark/sql/internal/SQLConf.scala    |  2 +-
 .../spark/sql/CalendarIntervalBenchmark.scala      |  4 +-
 .../scala/org/apache/spark/sql/HashBenchmark.scala |  4 +-
 .../spark/sql/UnsafeProjectionBenchmark.scala      |  4 +-
 .../sql/catalyst/encoders/RowEncoderSuite.scala    | 50 +++++++--------
 .../expressions/HashExpressionsSuite.scala         |  4 +-
 .../expressions/ObjectExpressionsSuite.scala       |  2 +-
 .../optimizer/ObjectSerializerPruningSuite.scala   |  4 +-
 .../catalyst/util/ArrayDataIndexedSeqSuite.scala   |  4 +-
 .../spark/sql/catalyst/util/UnsafeArraySuite.scala |  6 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  6 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  2 +-
 .../spark/sql/execution/SparkStrategies.scala      |  4 +-
 .../execution/datasources/DataSourceStrategy.scala |  4 +-
 .../sql/execution/datasources/jdbc/JdbcUtils.scala |  4 +-
 .../execution/datasources/v2/V2CommandExec.scala   |  4 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       |  4 +-
 .../execution/streaming/MicroBatchExecution.scala  |  4 +-
 .../sql/execution/streaming/sources/memory.scala   |  4 +-
 .../spark/sql/DataFrameSessionWindowingSuite.scala |  4 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      |  8 +--
 .../spark/sql/DataFrameTimeWindowingSuite.scala    |  4 +-
 .../spark/sql/DatasetOptimizationSuite.scala       |  4 +-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  8 +--
 .../spark/sql/execution/GroupedIteratorSuite.scala |  8 +--
 .../binaryfile/BinaryFileFormatSuite.scala         |  4 +-
 .../streaming/sources/ForeachBatchSinkSuite.scala  |  4 +-
 .../apache/spark/sql/streaming/StreamTest.scala    |  4 +-
 45 files changed, 205 insertions(+), 184 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
index 206ce6f0675..bf6e6b8eec0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.mapreduce.Job
 import org.apache.spark.ml.image.ImageSchema
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
@@ -90,7 +90,7 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
       if (requiredSchema.isEmpty) {
         filteredResult.map(_ => emptyUnsafeRow)
       } else {
-        val toRow = RowEncoder(requiredSchema).createSerializer()
+        val toRow = ExpressionEncoder(requiredSchema).createSerializer()
         filteredResult.map(row => toRow(row))
       }
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index f4c5e3eece2..3581693c05b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -167,7 +167,7 @@ private[libsvm] class LibSVMFileFormat
         LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
       }

-      val toRow = RowEncoder(dataSchema).createSerializer()
+      val toRow = ExpressionEncoder(dataSchema).createSerializer()
       val fullOutput = dataSchema.map { f =>
         AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
       }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6d527610231..0fdf6ad534d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,9 @@ object MimaExcludes {

   // Exclude rules for 3.5.x from 3.4.0
   lazy val v35excludes = defaultExcludes ++ Seq(
+    // [SPARK-44531][CONNECT][SQL] Move encoder inference to sql/api
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DataTypes"),
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.SQLUserDefinedType"),
     // [SPARK-43165][SQL] Move canWrite to DataTypeUtils
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.canWrite"),
     // [SPARK-43195][CORE] Remove unnecessary serializable wrapper in HadoopFSUtils
diff --git a/sql/api/pom.xml b/sql/api/pom.xml
index 4119ee11f54..a00a7024bf3 100644
--- a/sql/api/pom.xml
+++ b/sql/api/pom.xml
@@ -35,6 +35,10 @@
   </properties>

   <dependencies>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-reflect</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.scala-lang.modules</groupId>
       <artifactId>scala-parser-combinators_${scala.binary.version}</artifactId>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/api/src/main/java/org/apache/spark/sql/types/DataTypes.java
similarity index 100%
rename from sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
rename to sql/api/src/main/java/org/apache/spark/sql/types/DataTypes.java
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/api/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
similarity index 100%
rename from sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
rename to sql/api/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala
index 48efa510666..66398495d90 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala
@@ -41,6 +41,7 @@ private[sql] trait SqlApiConf {
   def timestampType: AtomicType
   def allowNegativeScaleOfDecimalEnabled: Boolean
   def charVarcharAsString: Boolean
+  def datetimeJava8ApiEnabled: Boolean
 }

 private[sql] object SqlApiConf {
@@ -76,4 +77,5 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
   override def timestampType: AtomicType = TimestampType
   override def allowNegativeScaleOfDecimalEnabled: Boolean = false
   override def charVarcharAsString: Boolean = false
+  override def datetimeJava8ApiEnabled: Boolean = false
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
similarity index 96%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 36b98737a20..ec9e704c66d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -24,10 +24,9 @@ import javax.annotation.Nonnull
 import scala.annotation.tailrec
 import scala.reflect.ClassTag

-import org.apache.spark.SPARK_DOC_ROOT
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, P [...]
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.EncoderErrors
 import org.apache.spark.sql.types._

 /**
@@ -116,7 +115,7 @@
     case c: Class[_] =>
       if (seenTypeSet.contains(c)) {
-        throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(c)
+        throw EncoderErrors.cannotHaveCircularReferencesInBeanClassError(c)
       }

       // TODO: we should only collect properties that have getter and setter. However, some tests
@@ -139,7 +138,7 @@
       JavaBeanEncoder(ClassTag(c), fields)

     case _ =>
-      throw QueryExecutionErrors.cannotFindEncoderForTypeError(t.toString, SPARK_DOC_ROOT)
+      throw EncoderErrors.cannotFindEncoderForTypeError(t.toString)
   }

   def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
@@ -197,7 +196,7 @@
         }
       }
     }
-    throw QueryExecutionErrors.unreachableError()
+    throw EncoderErrors.unreachableError()
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
similarity index 97%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9f2548c3789..1f366393d25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -26,12 +26,11 @@ import scala.util.{Failure, Success}

 import org.apache.commons.lang3.reflect.ConstructorUtils

-import org.apache.spark.SPARK_DOC_ROOT
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.EncoderErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval

@@ -378,13 +377,13 @@ object ScalaReflection extends ScalaReflection {
       case t if definedByConstructorParams(t) =>
         if (seenTypeSet.contains(t)) {
-          throw QueryExecutionErrors.cannotHaveCircularReferencesInClassError(t.toString)
+          throw EncoderErrors.cannotHaveCircularReferencesInClassError(t.toString)
         }

         val params = getConstructorParameters(t).map { case (fieldName, fieldType) =>
           if (SourceVersion.isKeyword(fieldName) ||
               !SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) {
-            throw QueryExecutionErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(
+            throw EncoderErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(
               fieldName, path)
           }
@@ -397,7 +396,7 @@ object ScalaReflection extends ScalaReflection {
         }
         ProductEncoder(ClassTag(getClassFromType(t)), params)
       case _ =>
-        throw QueryExecutionErrors.cannotFindEncoderForTypeError(tpe.toString, SPARK_DOC_ROOT)
+        throw EncoderErrors.cannotFindEncoderForTypeError(tpe.toString)
     }
   }
 }
@@ -478,7 +477,7 @@ trait ScalaReflection extends Logging {
    */
   private def getCompanionConstructor(tpe: Type): Symbol = {
     def throwUnsupportedOperation = {
-      throw QueryExecutionErrors.cannotFindConstructorForTypeError(tpe.toString)
+      throw EncoderErrors.cannotFindConstructorForTypeError(tpe.toString)
     }
     tpe.typeSymbol.asClass.companion match {
       case NoSymbol => throwUnsupportedOperation
@@ -501,7 +500,7 @@ trait ScalaReflection extends Logging {
     val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
       s => s.isMethod && s.asMethod.isPrimaryConstructor)
     if (primaryConstructorSymbol.isEmpty) {
-      throw QueryExecutionErrors.primaryConstructorNotFoundError(tpe.getClass)
+      throw EncoderErrors.primaryConstructorNotFoundError(tpe.getClass)
     } else {
       primaryConstructorSymbol.get.asMethod.paramLists
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
similarity index 100%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
similarity index 88%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 78243894544..5ab9c3bfc48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -20,10 +20,9 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.collection.mutable
 import scala.reflect.classTag

-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Row, SqlApiConf}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMont [...]
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.errors.EncoderErrors
 import org.apache.spark.sql.types._

 /**
@@ -59,14 +58,6 @@ import org.apache.spark.sql.types._
  * }}}
  */
 object RowEncoder {
-  def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
-    ExpressionEncoder(encoderFor(schema, lenient))
-  }
-
-  def apply(schema: StructType): ExpressionEncoder[Row] = {
-    apply(schema, lenient = false)
-  }
-
   def encoderFor(schema: StructType): AgnosticEncoder[Row] = {
     encoderFor(schema, lenient = false)
   }
@@ -89,10 +80,10 @@ object RowEncoder {
     case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
     case BinaryType => BinaryEncoder
     case StringType => StringEncoder
-    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
+    case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
     case TimestampType => TimestampEncoder(lenient)
     case TimestampNTZType => LocalDateTimeEncoder
-    case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
+    case DateType if SqlApiConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
     case DateType => DateEncoder(lenient)
     case CalendarIntervalType => CalendarIntervalEncoder
     case _: DayTimeIntervalType => DayTimeIntervalEncoder
@@ -106,7 +97,7 @@ object RowEncoder {
         annotation.udt()
       } else {
         UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
-          throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
+          throw EncoderErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
         }
       }
       UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala
index 4a8847959c2..0f4cb88dcf1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.errors

 import java.util.Locale

-import org.apache.spark.QueryContext
+import org.apache.spark.{QueryContext, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.trees.SQLQueryContext
 import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils, SparkStringUtils}
 import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection}
@@ -73,4 +73,10 @@ private[sql] trait DataTypeErrorsBase {
   def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = {
     if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext])
   }
+
+  def unreachableError(err: String = ""): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "_LEGACY_ERROR_TEMP_2028",
+      messageParameters = Map("err" -> err))
+  }
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/EncoderErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/EncoderErrors.scala
new file mode 100644
index 00000000000..e70a5124ce7
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/EncoderErrors.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.errors
+
+import org.apache.spark.{SparkBuildInfo, SparkException, SparkRuntimeException, SparkUnsupportedOperationException}
+import org.apache.spark.sql.catalyst.WalkedTypePath
+import org.apache.spark.sql.types.UserDefinedType
+
+object EncoderErrors extends DataTypeErrorsBase {
+  def userDefinedTypeNotAnnotatedAndRegisteredError(udt: UserDefinedType[_]): Throwable = {
+    new SparkException(
+      errorClass = "_LEGACY_ERROR_TEMP_2155",
+      messageParameters = Map(
+        "userClass" -> udt.userClass.getName),
+      cause = null)
+  }
+
+  def cannotFindEncoderForTypeError(typeName: String): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "ENCODER_NOT_FOUND",
+      messageParameters = Map(
+        "typeName" -> typeName,
+        "docroot" -> SparkBuildInfo.spark_doc_root))
+  }
+
+  def cannotHaveCircularReferencesInBeanClassError(
+      clazz: Class[_]): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2138",
+      messageParameters = Map("clazz" -> clazz.toString()))
+  }
+
+  def cannotFindConstructorForTypeError(tpe: String): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2144",
+      messageParameters = Map(
+        "tpe" -> tpe))
+  }
+
+  def cannotHaveCircularReferencesInClassError(t: String): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2139",
+      messageParameters = Map("t" -> t))
+  }
+
+  def cannotUseInvalidJavaIdentifierAsFieldNameError(
+      fieldName: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "_LEGACY_ERROR_TEMP_2140",
+      messageParameters = Map(
+        "fieldName" -> fieldName,
+        "walkedTypePath" -> walkedTypePath.toString()))
+  }
+
+  def primaryConstructorNotFoundError(cls: Class[_]): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "_LEGACY_ERROR_TEMP_2021",
+      messageParameters = Map("cls" -> cls.toString()))
+  }
+}
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index e19b03bfbe8..242bf304708 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -35,11 +35,6 @@
   </properties>

   <dependencies>
-    <dependency>
-      <groupId>org.scala-lang</groupId>
-      <artifactId>scala-reflect</artifactId>
-    </dependency>
-
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 83a018bafe7..ff72b5a0d96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag

-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
@@ -58,6 +58,12 @@ object ExpressionEncoder {
       enc.clsTag)
   }

+  def apply(schema: StructType): ExpressionEncoder[Row] = apply(schema, lenient = false)
+
+  def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
+    apply(RowEncoder.encoderFor(schema, lenient))
+  }
+
   // TODO: improve error message for java bean encoder.
   def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
     apply(JavaTypeInference.encoderFor(beanClass))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 35b0bd4363b..0abbbae93c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -149,8 +149,8 @@ object MapPartitionsInR {
       broadcastVars,
       encoder.schema,
       schema,
-      CatalystSerde.generateObjAttr(RowEncoder(schema)),
-      deserialized))(RowEncoder(schema))
+      CatalystSerde.generateObjAttr(ExpressionEncoder(schema)),
+      deserialized))(ExpressionEncoder(schema))
     }
   }
 }
@@ -606,8 +606,8 @@ object FlatMapGroupsInR {
         UnresolvedDeserializer(valueDeserializer, dataAttributes),
         groupingAttributes,
         dataAttributes,
-        CatalystSerde.generateObjAttr(RowEncoder(schema)),
-        child))(RowEncoder(schema))
+        CatalystSerde.generateObjAttr(ExpressionEncoder(schema)),
+        child))(ExpressionEncoder(schema))
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 2d0e29b1032..5b1eaa20d22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -33,8 +33,8 @@ import org.apache.spark._
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.memory.SparkOutOfMemoryError
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{TableIdentifier, WalkedTypePath}
 import org.apache.spark.sql.catalyst.ScalaReflection.Schema
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
 import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
@@ -479,12 +479,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       messageParameters = Map("cls" -> cls))
   }

-  def primaryConstructorNotFoundError(cls: Class[_]): SparkRuntimeException = {
-    new SparkRuntimeException(
-      errorClass = "_LEGACY_ERROR_TEMP_2021",
-      messageParameters = Map("cls" -> cls.toString()))
-  }
-
   def unsupportedNaturalJoinTypeError(joinType: JoinType): SparkException = {
     SparkException.internalError(
       s"Unsupported natural join type ${joinType.toString}")
@@ -526,12 +520,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       messageParameters = Map("op" -> op.toString(), "pos" -> pos))
   }

-  def unreachableError(err: String = ""): SparkRuntimeException = {
-    new SparkRuntimeException(
-      errorClass = "_LEGACY_ERROR_TEMP_2028",
-      messageParameters = Map("err" -> err))
-  }
-
   def unsupportedRoundingMode(roundMode: BigDecimal.RoundingMode.Value): SparkException = {
     DataTypeErrors.unsupportedRoundingMode(roundMode)
   }
@@ -1446,37 +1434,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       messageParameters = Map.empty)
   }

-  def cannotHaveCircularReferencesInBeanClassError(
-      clazz: Class[_]): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "_LEGACY_ERROR_TEMP_2138",
-      messageParameters = Map("clazz" -> clazz.toString()))
-  }
-
-  def cannotHaveCircularReferencesInClassError(t: String): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "_LEGACY_ERROR_TEMP_2139",
-      messageParameters = Map("t" -> t))
-  }
-
-  def cannotUseInvalidJavaIdentifierAsFieldNameError(
-      fieldName: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "_LEGACY_ERROR_TEMP_2140",
-      messageParameters = Map(
-        "fieldName" -> fieldName,
-        "walkedTypePath" -> walkedTypePath.toString()))
-  }
-
-  def cannotFindEncoderForTypeError(
-      typeName: String, docroot: String): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "ENCODER_NOT_FOUND",
-      messageParameters = Map(
-        "typeName" -> typeName,
-        "docroot" -> docroot))
-  }
-
   def attributesForTypeUnsupportedError(schema: Schema): SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException(
       errorClass = "_LEGACY_ERROR_TEMP_2142",
@@ -1484,13 +1441,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       "schema" -> schema.toString()))
   }

-  def cannotFindConstructorForTypeError(tpe: String): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "_LEGACY_ERROR_TEMP_2144",
-      messageParameters = Map(
-        "tpe" -> tpe))
-  }
-
   def paramExceedOneCharError(paramName: String): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "_LEGACY_ERROR_TEMP_2145",
@@ -1569,14 +1519,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       "innerCls" -> innerCls.getName))
   }

-  def userDefinedTypeNotAnnotatedAndRegisteredError(udt: UserDefinedType[_]): Throwable = {
-    new SparkException(
-      errorClass = "_LEGACY_ERROR_TEMP_2155",
-      messageParameters = Map(
-        "userClass" -> udt.userClass.getName),
-      cause = null)
-  }
-
   def unsupportedOperandTypeForSizeFunctionError(
       dataType: DataType): SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 00bb6f77ef3..d4987e3443f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4727,7 +4727,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def streamingSessionWindowMergeSessionInLocalPartition: Boolean =
     getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION)

-  def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)
+  override def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

   def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala
index f9ab7455778..043e2b01378 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/CalendarIntervalBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql

 import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.types.{CalendarIntervalType, DataType, StructType}
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -44,7 +44,7 @@ object CalendarIntervalBenchmark extends BenchmarkBase {
     assert(schema.head.dataType.isInstanceOf[CalendarIntervalType])
     runBenchmark(name) {
       val generator = RandomDataGenerator.forType(schema, nullable = false).get
-      val toRow = RowEncoder(schema).createSerializer()
+      val toRow = ExpressionEncoder(schema).createSerializer()
       val intervals =
         (1 to numRows).map(_ => toRow(generator().asInstanceOf[Row]).copy().getInterval(0))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
index 13ab7e2a705..e515b771c96 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql

 import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -42,7 +42,7 @@ object HashBenchmark extends BenchmarkBase {
   def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = {
     runBenchmark(name) {
       val generator = RandomDataGenerator.forType(schema, nullable = false).get
-      val toRow = RowEncoder(schema).createSerializer()
+      val toRow = ExpressionEncoder(schema).createSerializer()
       val attrs = DataTypeUtils.toAttributes(schema)
       val safeProjection = GenerateSafeProjection.generate(attrs, attrs)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
index b7704eb211f..186ae44a108 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql

 import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.types._
@@ -39,7 +39,7 @@ object UnsafeProjectionBenchmark extends BenchmarkBase {
   def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = {
     val generator = RandomDataGenerator.forType(schema, nullable = false).get
-    val toRow = RowEncoder(schema).createSerializer()
+    val toRow = ExpressionEncoder(schema).createSerializer()
     (1 to numRows).map(_ => toRow(generator().asInstanceOf[Row]).copy()).toArray
   }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index b133b38a559..b82760d8eb1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -151,7 +151,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       .add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
       .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)

-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     val javaDecimal = new java.math.BigDecimal("1234.5678")
     val scalaDecimal = BigDecimal("1234.5678")

@@ -167,7 +167,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {

   test("RowEncoder should preserve decimal precision and scale") {
     val schema = new StructType().add("decimal", DecimalType(10, 5), false)
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     val decimal = Decimal("67123.45")
     val input = Row(decimal)
     val row = toRow(encoder, input)
@@ -183,7 +183,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {

   private def testDecimalOverflow(schema: StructType, row: Row): Unit = {
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       intercept[Exception] {
         toRow(encoder, row)
       } match {
@@ -196,14 +196,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     }

     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       assert(roundTrip(encoder, row).get(0) == null)
     }
   }

   test("RowEncoder should preserve schema nullability") {
     val schema = new StructType().add("int", IntegerType, nullable = false)
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     assert(encoder.serializer.length == 1)
     assert(encoder.serializer.head.dataType == IntegerType)
     assert(encoder.serializer.head.nullable == false)
@@ -219,7 +219,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
         new StructType().add("int", IntegerType, nullable = false),
         nullable = false),
       nullable = false)
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     assert(encoder.serializer.length == 1)
     assert(encoder.serializer.head.dataType ==
       new StructType()
@@ -240,7 +240,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       .add("longPrimitiveArray", ArrayType(LongType, false))
       .add("floatPrimitiveArray", ArrayType(FloatType, false))
       .add("doublePrimitiveArray", ArrayType(DoubleType, false))
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     val input = Seq(
       Array(true, false),
       Array(1.toByte, 64.toByte, Byte.MaxValue),
@@ -261,7 +261,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       .add("array", ArrayType(IntegerType))
       .add("nestedArray", ArrayType(ArrayType(StringType)))
       .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     val input = Row(
       Array(1, 2, null),
       Array(Array("abc", null), null),
@@ -274,7 +274,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {

   test("RowEncoder should throw RuntimeException if input row object is null") {
     val schema = new StructType().add("int", IntegerType)
-    val encoder = RowEncoder(schema)
+    val encoder = ExpressionEncoder(schema)
     val e = intercept[RuntimeException](toRow(encoder, null))
     assert(e.getMessage.contains("Null value appeared in non-nullable field"))
     assert(e.getMessage.contains("top level Product or row object"))
@@ -283,14 +283,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
   test("RowEncoder should validate external type") {
     val e1 = intercept[RuntimeException] {
       val schema = new StructType().add("a", IntegerType)
-      val encoder = RowEncoder(schema)
+      val encoder = ExpressionEncoder(schema)
       toRow(encoder, Row(1.toShort))
     }
     assert(e1.getMessage.contains("java.lang.Short is not a valid external type"))

     val e2 = intercept[RuntimeException] {
       val schema = new StructType().add("a", StringType)
-      val encoder = RowEncoder(schema)
+      val encoder = ExpressionEncoder(schema)
       toRow(encoder, Row(1))
     }
     assert(e2.getMessage.contains("java.lang.Integer is not a valid external type"))
@@ -298,14 +298,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     val e3 = intercept[RuntimeException] {
       val schema = new StructType().add("a",
         new StructType().add("b", IntegerType).add("c", StringType))
-      val encoder = RowEncoder(schema)
+      val encoder = ExpressionEncoder(schema)
       toRow(encoder, Row(1 -> "a"))
     }
     assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))

     val e4 = intercept[RuntimeException] {
       val schema = new StructType().add("a", ArrayType(TimestampType))
-      val encoder = RowEncoder(schema)
+      val encoder = ExpressionEncoder(schema)
       toRow(encoder, Row(Array("a")))
     }
     assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
@@ -314,7 +314,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
   private def roundTripArray[T](dt: DataType, nullable: Boolean, data: Array[T]): Unit = {
     val schema = new StructType().add("a", ArrayType(dt, nullable))
     test(s"RowEncoder should return WrappedArray with properly typed array for $schema") {
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       val result = fromRow(encoder, toRow(encoder, Row(data))).getAs[mutable.WrappedArray[_]](0)
       assert(result.array.getClass === data.getClass)
       assert(result === data)
@@ -328,14 +328,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     val udtSQLType = new StructType().add("a", IntegerType)
     val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass")
     val schema = new StructType().add("pythonUDT", pythonUDT, true)
-    val encoder = RowEncoder(schema)
+    val encoder = ExpressionEncoder(schema)
     assert(encoder.serializer(0).dataType == pythonUDT.sqlType)
   }

   test("encoding/decoding TimestampType to/from java.time.Instant") {
     withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
       val schema = new StructType().add("t", TimestampType)
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
       val row = toRow(encoder, Row(instant))
       assert(row.getLong(0) === DateTimeUtils.instantToMicros(instant))
@@ -346,7 +346,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {

   test("SPARK-35664: encoding/decoding TimestampNTZType to/from java.time.LocalDateTime") {
     val schema = new StructType().add("t", TimestampNTZType)
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     val localDateTime = java.time.LocalDateTime.parse("2019-02-26T16:56:00")
     val row = toRow(encoder, Row(localDateTime))
     assert(row.getLong(0) === DateTimeUtils.localDateTimeToMicros(localDateTime))
@@ -357,7 +357,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
   test("encoding/decoding DateType to/from java.time.LocalDate") {
     withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
       val schema = new StructType().add("d", DateType)
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       val localDate = java.time.LocalDate.parse("2019-02-27")
       val row = toRow(encoder, Row(localDate))
       assert(row.getInt(0) === DateTimeUtils.localDateToDays(localDate))
@@ -369,7 +369,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
   test("SPARK-34605: encoding/decoding DayTimeIntervalType to/from java.time.Duration") {
     dayTimeIntervalTypes.foreach { dayTimeIntervalType =>
       val schema = new StructType().add("d", dayTimeIntervalType)
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       val duration = java.time.Duration.ofDays(1)
       val row = toRow(encoder, Row(duration))
       assert(row.getLong(0) === IntervalUtils.durationToMicros(duration))
@@ -381,7 +381,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
   test("SPARK-34615: encoding/decoding YearMonthIntervalType to/from java.time.Period") {
     yearMonthIntervalTypes.foreach { yearMonthIntervalType =>
       val schema = new StructType().add("p", yearMonthIntervalType)
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       val period = java.time.Period.ofMonths(1)
       val row = toRow(encoder, Row(period))
       assert(row.getInt(0) === IntervalUtils.periodToMonths(period))
@@ -398,7 +398,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     test("RowEncoder should preserve array nullability: " +
       s"ArrayType($elementType, containsNull = $containsNull), nullable = $nullable") {
       val schema = new StructType().add("array", ArrayType(elementType, containsNull), nullable)
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       assert(encoder.serializer.length == 1)
       assert(encoder.serializer.head.dataType == ArrayType(elementType, containsNull))
       assert(encoder.serializer.head.nullable == nullable)
@@ -416,7 +416,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       s"nullable = $nullable") {
       val schema = new StructType().add(
         "map", MapType(keyType, valueType, valueContainsNull), nullable)
-      val encoder = RowEncoder(schema).resolveAndBind()
+      val encoder = ExpressionEncoder(schema).resolveAndBind()
       assert(encoder.serializer.length == 1)
       assert(encoder.serializer.head.dataType == MapType(keyType, valueType, valueContainsNull))
       assert(encoder.serializer.head.nullable == nullable)
@@ -427,7 +427,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     test(s"encode/decode: ${schema.simpleString}") {
       Seq(false, true).foreach { java8Api =>
         withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
-          val encoder = RowEncoder(schema).resolveAndBind()
+          val encoder = ExpressionEncoder(schema).resolveAndBind()
           val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get

           var input: Row = null
@@ -458,7 +458,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       .add("t1", TimestampType)
       .add("d0", DateType)
       .add("d1", DateType)
-    val encoder = RowEncoder(schema, lenient = true).resolveAndBind()
+    val encoder = ExpressionEncoder(schema, lenient = true).resolveAndBind()
     val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
     val ld = java.time.LocalDate.parse("2022-03-08")
     val row = encoder.createSerializer().apply(
@@ -478,7 +478,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       .add("headers", ArrayType(new StructType()
         .add("key", StringType)
         .add("value", BinaryType)))
-    val encoder = RowEncoder(schema, lenient = true).resolveAndBind()
+    val encoder = ExpressionEncoder(schema, lenient = true).resolveAndBind()
     val data = Row(mutable.WrappedArray.make(Array(Row("key", "value".getBytes))))
     val row = encoder.createSerializer()(data)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 5897dee5a4d..6f5f22a84ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.exceptions.TestFailedException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData, IntervalUtils}
 import org.apache.spark.sql.types.{ArrayType, StructType, _}
@@ -731,7 +731,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

   private def testHash(inputSchema: StructType): Unit = {
     val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
-    val toRow = RowEncoder(inputSchema).createSerializer()
+    val toRow = ExpressionEncoder(inputSchema).createSerializer()
     val seed = scala.util.Random.nextInt()
     test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") {
       for (_ <- 1 to 10) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 73da5f4d3af..3a662e68d58 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -483,7 +483,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     testTypes.foreach { dt =>
       genSchema(dt).map { schema =>
         val row = RandomDataGenerator.randomRow(random, schema)
-        val toRow = RowEncoder(schema).createSerializer()
+        val toRow = ExpressionEncoder(schema).createSerializer()
         val internalRow = toRow(row)
         val lambda = LambdaVariable("dummy", schema(0).dataType, schema(0).nullable, id = 0)
         checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
index 3dd58dc9fc1..a1039b051ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.TypeTag

 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -115,7 +115,7 @@ class ObjectSerializerPruningSuite extends PlanTest {
   test("SPARK-32652: Prune nested serializers: RowEncoder") {
     withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
       val testRelation = LocalRelation($"i".struct(StructType.fromDDL("a int, b string")), $"j".int)
-      val rowEncoder = RowEncoder(new StructType()
+      val rowEncoder = ExpressionEncoder(new StructType()
         .add("i", new StructType().add("a", "int").add("b", "string"))
         .add("j", "int"))
       val serializerObject = CatalystSerde.serialize(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
index b015829e672..50667c5df8c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random

 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.RandomDataGenerator
-import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeProjection}
 import org.apache.spark.sql.types._

@@ -75,7 +75,7 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite {
     arrayTypes.foreach { dt =>
       val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil)
       val row = RandomDataGenerator.randomRow(random, schema)
-      val toRow = RowEncoder(schema).createSerializer()
+      val toRow = ExpressionEncoder(schema).createSerializer()
       val internalRow = toRow(row)

       val unsafeRowConverter = UnsafeProjection.create(schema)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
index 34e133095d6..1801094f6d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.TypeTag
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
@@ -128,7 +128,7 @@ class UnsafeArraySuite extends SparkFunSuite {
     val decimal = decimalArray(0)
     val schema = new StructType().add(
       "array", ArrayType(DecimalType(decimal.precision, decimal.scale)))
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     val externalRow = Row(decimalArray)
     val ir = encoder.createSerializer().apply(externalRow)
@@ -141,7 +141,7 @@ class UnsafeArraySuite extends SparkFunSuite {
     }

     val schema = new StructType().add("array", ArrayType(CalendarIntervalType))
-    val encoder = RowEncoder(schema).resolveAndBind()
+    val encoder = ExpressionEncoder(schema).resolveAndBind()
     val externalRow = Row(calendarintervalArray)
     val ir = encoder.createSerializer().apply(externalRow)
     val unsafeCalendar = ir.getArray(0)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 7eef2e9bbac..7b2259a6d99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -90,7 +90,7 @@ private[sql] object Dataset {
     sparkSession.withActive {
       val qe = sparkSession.sessionState.executePlan(logicalPlan)
       qe.assertAnalyzed()
-      new Dataset[Row](qe, RowEncoder(qe.analyzed.schema))
+      new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
     }

   /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
@@ -98,7 +98,7 @@ private[sql] object Dataset {
     : DataFrame = sparkSession.withActive {
     val qe = new QueryExecution(sparkSession, logicalPlan, tracker)
     qe.assertAnalyzed()
-    new Dataset[Row](qe, RowEncoder(qe.analyzed.schema))
+    new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
   }
 }

@@ -464,7 +464,7 @@ class Dataset[T] private[sql](
    */
   // This is declared with parentheses to prevent the Scala compiler from treating
   // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
-  def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema))
+  def toDF(): DataFrame = new Dataset[Row](queryExecution, ExpressionEncoder(schema))

   /**
    * Returns a new Dataset where each record has been mapped on to the specified type. The
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 1cf122eaa68..27ae10b3d59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -360,7 +360,7 @@ class SparkSession private(
     val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
-    val encoder = RowEncoder(replaced)
+    val encoder = ExpressionEncoder(replaced)
     val toRow = encoder.createSerializer()
     val catalystRows = rowRDD.map(toRow)
     internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5c93c72e36d..903565a6d59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, AnalysisException, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper, NormalizeFloatingNumbers}
@@ -776,7 +776,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil

       case MemoryPlan(sink, output) =>
-        val encoder = RowEncoder(DataTypeUtils.fromAttributes(output))
+        val encoder = ExpressionEncoder(DataTypeUtils.fromAttributes(output))
         val toRow = encoder.createSerializer()
         LocalTableScanExec(output, sink.allData.map(r => toRow(r).copy())) :: Nil

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 3ddb897f708..5e6e0ad0392 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Quali
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -742,7 +742,7 @@ object DataSourceStrategy
       rdd: RDD[Row]): RDD[InternalRow] = {
     if (relation.needConversion) {
       val toRow =
-        RowEncoder(DataTypeUtils.fromAttributes(output), lenient = true).createSerializer()
+        ExpressionEncoder(DataTypeUtils.fromAttributes(output), lenient = true).createSerializer()
       rdd.mapPartitions { iterator =>
         iterator.map(toRow)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 3335f21a0d3..daeecaeb90a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData} @@ -327,7 +327,7 @@ object JdbcUtils extends Logging with SQLConfHelper { dialect: JdbcDialect): Iterator[Row] = { val inputMetrics = Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics) - val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer() + val fromRow = ExpressionEncoder(schema).resolveAndBind().createDeserializer() val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics) internalRows.map(fromRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala index 42ad83c9821..7a6db856692 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema} import org.apache.spark.sql.catalyst.trees.LeafLike import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -65,7 +65,7 @@ abstract class V2CommandExec extends SparkPlan { } private lazy val rowSerializer = { - RowEncoder(DataTypeUtils.fromAttributes(output)).resolveAndBind().createSerializer() + ExpressionEncoder(DataTypeUtils.fromAttributes(output)).resolveAndBind().createSerializer() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index b05a2d130d3..32d010b00d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -20,7 +20,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.Distribution @@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasWithStateExec( private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) override protected val stateEncoder: ExpressionEncoder[Any] = - RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] override def output: 
Seq[Attribute] = outAttributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 691cea9edde..010ac75a73d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream} @@ -723,7 +723,7 @@ class MicroBatchExecution( markMicroBatchExecutionStart() val nextBatch = - new Dataset(lastExecution, RowEncoder(lastExecution.analyzed.schema)) + new Dataset(lastExecution, ExpressionEncoder(lastExecution.analyzed.schema)) val batchSinkProgress: Option[StreamWriterCommitProgress] = reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(lastExecution) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala index 27a39bccfdd..93ee615291f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils @@ -179,7 +179,7 @@ class MemoryDataWriter(partition: Int, schema: StructType) private val data = mutable.Buffer[Row]() - private val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer() + private val fromRow = ExpressionEncoder(schema).resolveAndBind().createDeserializer() override def write(row: InternalRow): Unit = { data.append(fromRow(row)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala index c9880682270..1ac1dda374f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -21,7 +21,7 @@ import java.time.LocalDateTime import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter, LogicalPlan, Project} import 
org.apache.spark.sql.functions._ @@ -480,7 +480,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession df <- Seq(df1, df2) nullable <- Seq(true, false) } { - val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder( + val dfWithDesiredNullability = new DataFrame(df.queryExecution, ExpressionEncoder( StructType(df.schema.fields.map(_.copy(nullable = nullable))))) // session window without dynamic gap val windowedProject = dfWithDesiredNullability diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 39db165148c..e1c355dc019 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, Uuid} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException @@ -2888,7 +2888,7 @@ class DataFrameSuite extends QueryTest val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c") .repartition($"a", $"b").sortWithinPartitions("a", "b") - implicit val valueEncoder = RowEncoder(df1.schema) + implicit val valueEncoder = ExpressionEncoder(df1.schema) val df3 = df1.groupBy("a", "b").as[GroupByKey, Row] .cogroup(df2.groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) => @@ -2912,7 +2912,7 @@ class DataFrameSuite extends QueryTest val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a1", "b", "c") .repartition($"a1", $"b").sortWithinPartitions("a1", "b") - implicit val valueEncoder = RowEncoder(df1.schema) + implicit val valueEncoder = ExpressionEncoder(df1.schema) val groupedDataset1 = df1.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row] val groupedDataset2 = df2.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row] @@ -2930,7 +2930,7 @@ class DataFrameSuite extends QueryTest test("groupBy.as: throw AnalysisException for unresolved grouping expr") { val df = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c") - implicit val valueEncoder = RowEncoder(df.schema) + implicit val valueEncoder = ExpressionEncoder(df.schema) checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 367cdbe8447..6ee173bc6af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.time.LocalDateTime -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter} import org.apache.spark.sql.functions._ @@ -595,7 +595,7 @@ class 
DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { df <- Seq(df1, df2) nullable <- Seq(true, false) } { - val dfWithDesiredNullability = new DataFrame(df.queryExecution, RowEncoder( + val dfWithDesiredNullability = new DataFrame(df.queryExecution, ExpressionEncoder( StructType(df.schema.fields.map(_.copy(nullable = nullable))))) // tumbling windows val windowedProject = dfWithDesiredNullability diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala index 5b8c80b471b..81d7de856f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.metrics.source.CodegenMetrics -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression} import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject import org.apache.spark.sql.functions._ @@ -199,7 +199,7 @@ class DatasetOptimizationSuite extends QueryTest with SharedSparkSession { test("SPARK-32652: Pruned nested serializers: RowEncoder") { val df = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("i", "j") - val encoder = RowEncoder(new StructType().add("s", df.schema)) + val encoder = ExpressionEncoder(new StructType().add("s", df.schema)) val query = df.map(row => Row(row))(encoder).select("s.i") testSerializer(query, Seq(Seq("i"))) checkAnswer(query, Seq(Row("a"), Row("b"), Row("c"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c09c6d18b66..a021b049cf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.TestUtils.withListener import org.apache.spark.internal.config.MAX_RESULT_SIZE import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample} -import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes} import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide @@ -1322,7 +1322,7 @@ class DatasetSuite extends QueryTest } else { Row(l) } - })(RowEncoder(schema)) + })(ExpressionEncoder(schema)) val message = intercept[Exception] { df.collect() @@ -1375,7 +1375,7 @@ class DatasetSuite extends QueryTest test("SPARK-15381: physical object operator should define `reference` correctly") { val df = Seq(1 -> 2).toDF("a", "b") - checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), Row(2, 1)) + checkAnswer(df.map(row => row)(ExpressionEncoder(df.schema)).select("b", "a"), Row(2, 1)) } private def checkShowString[T](ds: Dataset[T], expected: String): Unit = { @@ -2157,7 +2157,7 @@ class DatasetSuite extends QueryTest test("SPARK-26233: serializer should enforce decimal precision and scale") { val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8)))) - val encoder = RowEncoder(s) + val encoder = 
ExpressionEncoder(s) implicit val uEnc = encoder val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111))) checkAnswer(df.groupBy(col("a")).agg(first(col("b"))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index 8b27a98e2b9..1cb35d303a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -28,7 +28,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("basic") { val schema = new StructType().add("i", IntegerType).add("s", StringType) - val encoder = RowEncoder(schema).resolveAndBind() + val encoder = ExpressionEncoder(schema).resolveAndBind() val toRow = encoder.createSerializer() val fromRow = encoder.createDeserializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) @@ -48,7 +48,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("group by 2 columns") { val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType) - val encoder = RowEncoder(schema).resolveAndBind() + val encoder = ExpressionEncoder(schema).resolveAndBind() val toRow = encoder.createSerializer() val fromRow = encoder.createDeserializer() @@ -77,7 +77,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("do nothing to the value iterator") { val schema = new StructType().add("i", IntegerType).add("s", StringType) - val encoder = RowEncoder(schema).resolveAndBind() + val encoder = ExpressionEncoder(schema).resolveAndBind() val toRow = encoder.createSerializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(toRow), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala index 1d2e467c94c..0b6fdef4f74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala @@ -29,7 +29,7 @@ import org.mockito.Mockito.{mock, when} import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf.SOURCES_BINARY_FILE_MAX_LENGTH @@ -306,7 +306,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { ) val partitionedFile = mock(classOf[PartitionedFile]) when(partitionedFile.toPath).thenReturn(new Path(file.toURI)) - val encoder = RowEncoder(requiredSchema).resolveAndBind() + val encoder = 
     encoder.createDeserializer().apply(reader(partitionedFile).next())
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
index 9a0003e9e5c..b92fa4cf3a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 import scala.language.implicitConversions

 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.SerializeFromObjectExec
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
@@ -49,7 +49,7 @@ class ForeachBatchSinkSuite extends StreamTest {
     val mem = MemoryStream[Int]
     val ds = mem.toDF.selectExpr("value + 1 as value")

-    val tester = new ForeachBatchTester[Row](mem)(RowEncoder.apply(ds.schema))
+    val tester = new ForeachBatchTester[Row](mem)(ExpressionEncoder(ds.schema))
     val writer = (df: DataFrame, batchId: Long) =>
       tester.record(batchId, df.selectExpr("value + 1"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index a17788afc5a..cb7995abcd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -32,7 +32,7 @@ import org.scalatest.time.SpanSugar._

 import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.AllTuples
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
@@ -147,7 +147,7 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
   private def createToExternalRowConverter[A : Encoder](): A => Row = {
     val encoder = encoderFor[A]
     val toInternalRow = encoder.createSerializer()
-    val toExternalRow = RowEncoder(encoder.schema).resolveAndBind().createDeserializer()
+    val toExternalRow = ExpressionEncoder(encoder.schema).resolveAndBind().createDeserializer()
     toExternalRow.compose(toInternalRow)
   }
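Every hunk above is the same mechanical substitution: the RowEncoder factory call is swapped for the ExpressionEncoder factory at each call site, while the surrounding resolveAndBind/createSerializer/createDeserializer usage is left untouched. A minimal sketch of the migrated call-site pattern; the schema and sample data below are illustrative placeholders, not taken from this patch:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    // Illustrative schema; any StructType works the same way.
    val schema = new StructType().add("i", IntegerType).add("s", StringType)

    // Before this patch: val encoder = RowEncoder(schema).resolveAndBind()
    val encoder = ExpressionEncoder(schema).resolveAndBind()

    // Serializer: external Row -> Catalyst InternalRow.
    val toRow = encoder.createSerializer()
    val internalRow = toRow(Row(1, "a"))

    // Deserializer: InternalRow -> external Row, round-tripping the value.
    val fromRow = encoder.createDeserializer()
    val roundTripped: Row = fromRow(internalRow)

Only the factory call changes; the serializer and deserializer behave the same at each touched call site.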