This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new f853afd  [SPARK-36931][SQL] Support reading and writing ANSI intervals from/to ORC datasources
f853afd is described below

commit f853afdc035273f772dc47f5476be6cf205d0941
Author: Kousuke Saruta <saru...@oss.nttdata.com>
AuthorDate: Fri Oct 8 10:49:11 2021 +0300

    [SPARK-36931][SQL] Support reading and writing ANSI intervals from/to ORC datasources

    ### What changes were proposed in this pull request?
    This PR aims to support reading and writing ANSI intervals from/to ORC datasources.
    Year-month and day-time intervals are mapped to ORC's `int` and `bigint` types respectively. To preserve the Catalyst types, this change adds a `spark.sql.catalyst.type` attribute to the corresponding ORC type information. The value of the attribute is the string returned by `YearMonthIntervalType.typeName` or `DayTimeIntervalType.typeName`.

    ### Why are the changes needed?
    For better usability. There is no reason to prohibit reading and writing ANSI intervals from/to ORC datasources.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    New tests.

    Closes #34184 from sarutak/ansi-interval-orc-source.

    Authored-by: Kousuke Saruta <saru...@oss.nttdata.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
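
To make the round trip concrete, here is a minimal sketch of the behavior this commit enables, assuming a Spark build that includes it and `spark-shell` implicits in scope (the output path and column names are illustrative):

```scala
import java.time.{Duration, Period}

// java.time.Period maps to YearMonthIntervalType, and java.time.Duration
// maps to DayTimeIntervalType in the DataFrame schema.
val df = Seq((Period.ofYears(1).plusMonths(2), Duration.ofDays(1).plusMinutes(3)))
  .toDF("ym", "dt")

// The interval columns are stored as ORC int/bigint, each annotated with
// the spark.sql.catalyst.type attribute described above.
df.write.orc("/tmp/ansi-intervals.orc")

// Reading the file back restores the ANSI interval types instead of
// surfacing plain int/bigint columns.
spark.read.orc("/tmp/ansi-intervals.orc").printSchema()
// root
//  |-- ym: interval year to month (nullable = true)
//  |-- dt: interval day to second (nullable = true)
```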

---
 .../sql/execution/datasources/DataSource.scala     |  3 +-
 .../datasources/orc/OrcDeserializer.scala          | 10 ++--
 .../execution/datasources/orc/OrcFileFormat.scala  |  8 ++-
 .../datasources/orc/OrcOutputWriter.scala          |  1 +
 .../execution/datasources/orc/OrcSerializer.scala  |  4 +-
 .../sql/execution/datasources/orc/OrcUtils.scala   | 62 +++++++++++++++++++++-
 .../execution/datasources/v2/orc/OrcTable.scala    |  2 -
 .../datasources/CommonFileDataSourceSuite.scala    |  2 +-
 .../execution/datasources/orc/OrcSourceSuite.scala | 49 ++++++++++++++++-
 9 files changed, 123 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 32913c6..9936126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -581,7 +581,8 @@ case class DataSource(
 
   // TODO: Remove the Set below once all the built-in datasources support ANSI interval types
   private val writeAllowedSources: Set[Class[_]] =
-    Set(classOf[ParquetFileFormat], classOf[CSVFileFormat], classOf[JsonFileFormat])
+    Set(classOf[ParquetFileFormat], classOf[CSVFileFormat],
+      classOf[JsonFileFormat], classOf[OrcFileFormat])
 
   private def disallowWritingIntervals(
       dataTypes: Seq[DataType],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
index fa8977f..1476083 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -86,10 +86,10 @@ class OrcDeserializer(
     case ShortType => (ordinal, value) =>
       updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get)
 
-    case IntegerType => (ordinal, value) =>
+    case IntegerType | _: YearMonthIntervalType => (ordinal, value) =>
       updater.setInt(ordinal, value.asInstanceOf[IntWritable].get)
 
-    case LongType => (ordinal, value) =>
+    case LongType | _: DayTimeIntervalType => (ordinal, value) =>
       updater.setLong(ordinal, value.asInstanceOf[LongWritable].get)
 
     case FloatType => (ordinal, value) =>
@@ -197,8 +197,10 @@ class OrcDeserializer(
       case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
       case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
       case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
-      case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
-      case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+      case IntegerType | _: YearMonthIntervalType =>
+        UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+      case LongType | _: DayTimeIntervalType =>
+        UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
       case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
       case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
       case _ => new GenericArrayData(new Array[Any](length))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index c4ffdb4..26af2c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -27,7 +27,7 @@
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.input.FileSplit
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.orc.{OrcUtils => _, _}
-import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
+import org.apache.orc.OrcConf.COMPRESS
 import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce._
@@ -45,6 +45,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 private[sql] object OrcFileFormat {
   def getQuotedSchemaString(dataType: DataType): String = dataType match {
+    case _: DayTimeIntervalType => LongType.catalogString
+    case _: YearMonthIntervalType => IntegerType.catalogString
     case _: AtomicType => dataType.catalogString
     case StructType(fields) =>
       fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}")
@@ -90,8 +92,6 @@ class OrcFileFormat
 
     val conf = job.getConfiguration
 
-    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema))
-
     conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
 
     conf.asInstanceOf[JobConf]
@@ -233,8 +233,6 @@ class OrcFileFormat
   }
 
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case st: StructType => st.forall { f => supportDataType(f.dataType) }
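
As the new `getQuotedSchemaString` cases above show, the ORC-facing schema string never mentions intervals; the Catalyst type name travels separately. For reference, a quick sketch of the strings involved (the comments show the values these calls return; `YearMonthIntervalType()` and `DayTimeIntervalType()` are the default year-to-month and day-to-second variants):

```scala
import org.apache.spark.sql.types._

IntegerType.catalogString        // "int"    -- ORC-facing type for year-month intervals
LongType.catalogString           // "bigint" -- ORC-facing type for day-time intervals
YearMonthIntervalType().typeName // "interval year to month"
DayTimeIntervalType().typeName   // "interval day to second"
```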

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
index 6f21573..fe057e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -44,6 +44,7 @@ private[sql] class OrcOutputWriter(
     }
     val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc")
     val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration)
+    options.setSchema(OrcUtils.orcTypeDescription(dataSchema))
     val writer = OrcFile.createWriter(filename, options)
     val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer)
     OrcUtils.addSparkVersionMetadata(writer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
index ac32be2..9a1eb8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -88,7 +88,7 @@ class OrcSerializer(dataSchema: StructType) {
         (getter, ordinal) => new ShortWritable(getter.getShort(ordinal))
       }
 
-    case IntegerType =>
+    case IntegerType | _: YearMonthIntervalType =>
       if (reuseObj) {
         val result = new IntWritable()
         (getter, ordinal) =>
@@ -99,7 +99,7 @@ class OrcSerializer(dataSchema: StructType) {
       }
 
 
-    case LongType =>
+    case LongType | _: DayTimeIntervalType =>
       if (reuseObj) {
         val result = new LongWritable()
         (getter, ordinal) =>
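
The OrcUtils diff below contains both halves of the attribute mechanism: `orcTypeDescription` tags the written ORC type with the Catalyst type name, and the reading path prefers that tag over the raw ORC type. A sketch of the idea in isolation, using only the orc-core `TypeDescription` calls that appear in the diff (the literals mirror the constants used there):

```scala
import org.apache.orc.TypeDescription

// Write side: ORC stores a plain bigint, annotated with the Catalyst type name.
val orcType = TypeDescription.fromString("bigint")
orcType.setAttribute("spark.sql.catalyst.type", "interval day to second")

// Read side: prefer the attribute when present; otherwise fall back to the
// raw ORC type string, as the new `case _ =>` branch below does.
val catalystTypeName = Option(orcType.getAttributeValue("spark.sql.catalyst.type"))
  .getOrElse(orcType.toString)
// catalystTypeName == "interval day to second"
```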

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index ec57375..475448a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -50,6 +50,8 @@ object OrcUtils extends Logging {
     "LZ4" -> ".lz4",
     "LZO" -> ".lzo")
 
+  val CATALYST_TYPE_ATTRIBUTE_NAME = "spark.sql.catalyst.type"
+
   def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
     val origPath = new Path(pathStr)
     val fs = origPath.getFileSystem(conf)
@@ -93,7 +95,13 @@ object OrcUtils extends Logging {
       case Category.STRUCT => toStructType(orcType)
       case Category.LIST => toArrayType(orcType)
       case Category.MAP => toMapType(orcType)
-      case _ => CatalystSqlParser.parseDataType(orcType.toString)
+      case _ =>
+        val catalystTypeAttrValue = orcType.getAttributeValue(CATALYST_TYPE_ATTRIBUTE_NAME)
+        if (catalystTypeAttrValue != null) {
+          CatalystSqlParser.parseDataType(catalystTypeAttrValue)
+        } else {
+          CatalystSqlParser.parseDataType(orcType.toString)
+        }
     }
   }
 
@@ -265,9 +273,61 @@ object OrcUtils extends Logging {
       s"array<${orcTypeDescriptionString(a.elementType)}>"
     case m: MapType =>
       s"map<${orcTypeDescriptionString(m.keyType)},${orcTypeDescriptionString(m.valueType)}>"
+    case _: DayTimeIntervalType => LongType.catalogString
+    case _: YearMonthIntervalType => IntegerType.catalogString
     case _ => dt.catalogString
   }
 
+  def orcTypeDescription(dt: DataType): TypeDescription = {
+    def getInnerTypeDecription(dt: DataType): Option[TypeDescription] = {
+      dt match {
+        case y: YearMonthIntervalType =>
+          val typeDesc = orcTypeDescription(IntegerType)
+          typeDesc.setAttribute(
+            CATALYST_TYPE_ATTRIBUTE_NAME, y.typeName)
+          Some(typeDesc)
+        case d: DayTimeIntervalType =>
+          val typeDesc = orcTypeDescription(LongType)
+          typeDesc.setAttribute(
+            CATALYST_TYPE_ATTRIBUTE_NAME, d.typeName)
+          Some(typeDesc)
+        case _ => None
+      }
+    }
+
+    dt match {
+      case s: StructType =>
+        val result = new TypeDescription(TypeDescription.Category.STRUCT)
+        s.fields.foreach { f =>
+          getInnerTypeDecription(f.dataType) match {
+            case Some(t) => result.addField(f.name, t)
+            case None => result.addField(f.name, orcTypeDescription(f.dataType))
+          }
+        }
+        result
+      case a: ArrayType =>
+        val result = new TypeDescription(TypeDescription.Category.LIST)
+        getInnerTypeDecription(a.elementType) match {
+          case Some(t) => result.addChild(t)
+          case None => result.addChild(orcTypeDescription(a.elementType))
+        }
+        result
+      case m: MapType =>
+        val result = new TypeDescription(TypeDescription.Category.MAP)
+        getInnerTypeDecription(m.keyType) match {
+          case Some(t) => result.addChild(t)
+          case None => result.addChild(orcTypeDescription(m.keyType))
+        }
+        getInnerTypeDecription(m.valueType) match {
+          case Some(t) => result.addChild(t)
+          case None => result.addChild(orcTypeDescription(m.valueType))
+        }
+        result
+      case other =>
+        TypeDescription.fromString(other.catalogString)
+    }
+  }
+
   /**
    * Returns the result schema to read from ORC file. In addition, It sets
    * the schema string to 'orc.mapred.input.schema' so ORC reader can use later.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
index 628b0a1..9cc4525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
@@ -49,8 +49,6 @@ case class OrcTable(
   }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case st: StructType => st.forall { f => supportsDataType(f.dataType) }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
index 28d0967..854463d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
@@ -36,7 +36,7 @@ trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite =>
   protected def inputDataset: Dataset[_] = spark.createDataset(Seq("abc"))(Encoders.STRING)
 
   test(s"SPARK-36349: disallow saving of ANSI intervals to $dataSourceFormat") {
-    if (!Set("parquet", "csv", "json").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
+    if (!Set("parquet", "csv", "json", "orc").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
       Seq("INTERVAL '1' DAY", "INTERVAL '1' YEAR").foreach { i =>
         withTempPath { dir =>
           val errMsg = intercept[AnalysisException] {
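
The suite below also exercises intervals nested in complex types, where `orcTypeDescription` has to annotate the inner int/bigint children. The same check can be reproduced by hand; a sketch under the same `spark-shell` assumptions as above (illustrative path):

```scala
// Intervals inside struct/array/map survive the round trip, including their
// start and end fields (e.g. DAY TO MINUTE), because typeName strings such
// as "interval day to minute" are parsed back on read.
val df = spark.sql(
  """SELECT
    |  named_struct('i', INTERVAL '1-2' YEAR TO MONTH) AS a,
    |  array(INTERVAL '1 2:3' DAY TO MINUTE) AS b,
    |  map('key', INTERVAL '10' YEAR) AS c""".stripMargin)
df.write.orc("/tmp/nested-intervals.orc")

val readBack = spark.read.orc("/tmp/nested-intervals.orc")
readBack.schema("b").dataType // ArrayType whose element is DayTimeIntervalType(DAY, MINUTE)
```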

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index d077814..8ffccd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.io.File
 import java.nio.charset.StandardCharsets.UTF_8
 import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
 import java.util.Locale
 
 import org.apache.hadoop.conf.Configuration
@@ -35,13 +36,14 @@ import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException}
 import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY}
 import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, SchemaMergeUtils}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtilsBase}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 case class OrcData(intField: Int, stringField: String)
 
-abstract class OrcSuite extends OrcTest with BeforeAndAfterAll with CommonFileDataSourceSuite {
+abstract class OrcSuite
+  extends OrcTest with BeforeAndAfterAll with CommonFileDataSourceSuite with SQLTestUtilsBase {
   import testImplicits._
 
   override protected def dataSourceFormat = "orc"
@@ -806,6 +808,49 @@ abstract class OrcSourceSuite extends OrcSuite with SharedSparkSession {
             StructField("456", StringType) :: Nil))))))
     }
   }
+
+  Seq(true, false).foreach { vecReaderEnabled =>
+    Seq(true, false).foreach { vecReaderNestedColEnabled =>
+      test("SPARK-36931: Support reading and writing ANSI intervals (" +
+        s"${SQLConf.ORC_VECTORIZED_READER_ENABLED.key}=$vecReaderEnabled, " +
+        s"${SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key}=$vecReaderNestedColEnabled)") {
+
+        withSQLConf(
+          SQLConf.ORC_VECTORIZED_READER_ENABLED.key ->
+            vecReaderEnabled.toString,
+          SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key ->
+            vecReaderNestedColEnabled.toString) {
+          Seq(
+            YearMonthIntervalType() -> ((i: Int) => Period.of(i, i, 0)),
+            DayTimeIntervalType() -> ((i: Int) => Duration.ofDays(i).plusSeconds(i))
+          ).foreach { case (it, f) =>
+            val data = (1 to 10).map(i => Row(i, f(i)))
+            val schema = StructType(Array(StructField("d", IntegerType, false),
+              StructField("i", it, false)))
+            withTempPath { file =>
+              val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+              df.write.orc(file.getCanonicalPath)
+              val df2 = spark.read.orc(file.getCanonicalPath)
+              checkAnswer(df2, df.collect().toSeq)
+            }
+          }
+
+          // Tests for ANSI intervals in complex types.
+          withTempPath { file =>
+            val df = spark.sql(
+              """SELECT
+                | named_struct('interval', interval '1-2' year to month) a,
+                | array(interval '1 2:3' day to minute) b,
+                | map('key', interval '10' year) c,
+                | map(interval '20' second, 'value') d""".stripMargin)
+            df.write.orc(file.getCanonicalPath)
+            val df2 = spark.read.orc(file.getCanonicalPath)
+            checkAnswer(df2, df.collect().toSeq)
+          }
+        }
+      }
+    }
+  }
 }
 
 class OrcSourceV1Suite extends OrcSourceSuite {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org