This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f853afd  [SPARK-36931][SQL] Support reading and writing ANSI intervals from/to ORC datasources
f853afd is described below

commit f853afdc035273f772dc47f5476be6cf205d0941
Author: Kousuke Saruta <saru...@oss.nttdata.com>
AuthorDate: Fri Oct 8 10:49:11 2021 +0300

    [SPARK-36931][SQL] Support reading and writing ANSI intervals from/to ORC datasources
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support reading and writing ANSI intervals from/to ORC datasources.
    
    Year-month and day-time intervals are mapped to ORC's `int` and `bigint` respectively.
    To preserve the Catalyst types, this change adds a `spark.sql.catalyst.type` attribute to each ORC type description.
    The value of the attribute is the value returned by `YearMonthIntervalType.typeName` or `DayTimeIntervalType.typeName`.
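    
    For example, a round trip through ORC should now preserve the interval types. A minimal sketch (illustrative only, not part of this patch; assumes an active `SparkSession` named `spark`, and the column names and output path are hypothetical):
    
    ```scala
    // Illustrative sketch: relies on the java.time.Period/Duration encoders
    // available since Spark 3.2; the output path below is hypothetical.
    import java.time.{Duration, Period}
    import spark.implicits._

    val df = Seq((Period.ofYears(1).plusMonths(2), Duration.ofDays(1).plusHours(2)))
      .toDF("ym", "dt")                      // interval year to month, interval day to second
    df.write.orc("/tmp/ansi-intervals-orc")  // stored as ORC int/bigint plus spark.sql.catalyst.type
    spark.read.orc("/tmp/ansi-intervals-orc").printSchema()
    // should report the ANSI interval types rather than int/bigint
    ```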
    
    ### Why are the changes needed?
    
    For better usability. There is no reason to prohibit reading or writing ANSI intervals from/to ORC datasources.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests.
    
    Closes #34184 from sarutak/ansi-interval-orc-source.
    
    Authored-by: Kousuke Saruta <saru...@oss.nttdata.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../sql/execution/datasources/DataSource.scala     |  3 +-
 .../datasources/orc/OrcDeserializer.scala          | 10 ++--
 .../execution/datasources/orc/OrcFileFormat.scala  |  8 ++-
 .../datasources/orc/OrcOutputWriter.scala          |  1 +
 .../execution/datasources/orc/OrcSerializer.scala  |  4 +-
 .../sql/execution/datasources/orc/OrcUtils.scala   | 62 +++++++++++++++++++++-
 .../execution/datasources/v2/orc/OrcTable.scala    |  2 -
 .../datasources/CommonFileDataSourceSuite.scala    |  2 +-
 .../execution/datasources/orc/OrcSourceSuite.scala | 49 ++++++++++++++++-
 9 files changed, 123 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 32913c6..9936126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -581,7 +581,8 @@ case class DataSource(
 
   // TODO: Remove the Set below once all the built-in datasources support ANSI interval types
   private val writeAllowedSources: Set[Class[_]] =
-    Set(classOf[ParquetFileFormat], classOf[CSVFileFormat], classOf[JsonFileFormat])
+    Set(classOf[ParquetFileFormat], classOf[CSVFileFormat],
+      classOf[JsonFileFormat], classOf[OrcFileFormat])
 
   private def disallowWritingIntervals(
       dataTypes: Seq[DataType],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
index fa8977f..1476083 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -86,10 +86,10 @@ class OrcDeserializer(
       case ShortType => (ordinal, value) =>
         updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get)
 
-      case IntegerType => (ordinal, value) =>
+      case IntegerType | _: YearMonthIntervalType => (ordinal, value) =>
         updater.setInt(ordinal, value.asInstanceOf[IntWritable].get)
 
-      case LongType => (ordinal, value) =>
+      case LongType | _: DayTimeIntervalType => (ordinal, value) =>
         updater.setLong(ordinal, value.asInstanceOf[LongWritable].get)
 
       case FloatType => (ordinal, value) =>
@@ -197,8 +197,10 @@ class OrcDeserializer(
     case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
     case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
     case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
-    case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
-    case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+    case IntegerType | _: YearMonthIntervalType =>
+      UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+    case LongType | _: DayTimeIntervalType =>
+      UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
     case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
     case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
     case _ => new GenericArrayData(new Array[Any](length))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index c4ffdb4..26af2c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.input.FileSplit
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.orc.{OrcUtils => _, _}
-import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
+import org.apache.orc.OrcConf.COMPRESS
 import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce._
 
@@ -45,6 +45,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
 private[sql] object OrcFileFormat {
 
   def getQuotedSchemaString(dataType: DataType): String = dataType match {
+    case _: DayTimeIntervalType => LongType.catalogString
+    case _: YearMonthIntervalType => IntegerType.catalogString
     case _: AtomicType => dataType.catalogString
     case StructType(fields) =>
       fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}")
@@ -90,8 +92,6 @@ class OrcFileFormat
 
     val conf = job.getConfiguration
 
-    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema))
-
     conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
 
     conf.asInstanceOf[JobConf]
@@ -233,8 +233,6 @@ class OrcFileFormat
   }
 
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
index 6f21573..fe057e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -44,6 +44,7 @@ private[sql] class OrcOutputWriter(
     }
     val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc")
     val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration)
+    options.setSchema(OrcUtils.orcTypeDescription(dataSchema))
     val writer = OrcFile.createWriter(filename, options)
     val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer)
     OrcUtils.addSparkVersionMetadata(writer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
index ac32be2..9a1eb8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -88,7 +88,7 @@ class OrcSerializer(dataSchema: StructType) {
         (getter, ordinal) => new ShortWritable(getter.getShort(ordinal))
       }
 
-    case IntegerType =>
+    case IntegerType | _: YearMonthIntervalType =>
       if (reuseObj) {
         val result = new IntWritable()
         (getter, ordinal) =>
@@ -99,7 +99,7 @@ class OrcSerializer(dataSchema: StructType) {
       }
 
 
-    case LongType =>
+    case LongType | _: DayTimeIntervalType =>
       if (reuseObj) {
         val result = new LongWritable()
         (getter, ordinal) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index ec57375..475448a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -50,6 +50,8 @@ object OrcUtils extends Logging {
     "LZ4" -> ".lz4",
     "LZO" -> ".lzo")
 
+  val CATALYST_TYPE_ATTRIBUTE_NAME = "spark.sql.catalyst.type"
+
   def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
     val origPath = new Path(pathStr)
     val fs = origPath.getFileSystem(conf)
@@ -93,7 +95,13 @@ object OrcUtils extends Logging {
         case Category.STRUCT => toStructType(orcType)
         case Category.LIST => toArrayType(orcType)
         case Category.MAP => toMapType(orcType)
-        case _ => CatalystSqlParser.parseDataType(orcType.toString)
+        case _ =>
+          val catalystTypeAttrValue = orcType.getAttributeValue(CATALYST_TYPE_ATTRIBUTE_NAME)
+          if (catalystTypeAttrValue != null) {
+            CatalystSqlParser.parseDataType(catalystTypeAttrValue)
+          } else {
+            CatalystSqlParser.parseDataType(orcType.toString)
+          }
       }
     }
 
@@ -265,9 +273,61 @@ object OrcUtils extends Logging {
       s"array<${orcTypeDescriptionString(a.elementType)}>"
     case m: MapType =>
       s"map<${orcTypeDescriptionString(m.keyType)},${orcTypeDescriptionString(m.valueType)}>"
+    case _: DayTimeIntervalType => LongType.catalogString
+    case _: YearMonthIntervalType => IntegerType.catalogString
     case _ => dt.catalogString
   }
 
+  def orcTypeDescription(dt: DataType): TypeDescription = {
+    def getInnerTypeDecription(dt: DataType): Option[TypeDescription] = {
+      dt match {
+        case y: YearMonthIntervalType =>
+          val typeDesc = orcTypeDescription(IntegerType)
+          typeDesc.setAttribute(
+            CATALYST_TYPE_ATTRIBUTE_NAME, y.typeName)
+          Some(typeDesc)
+        case d: DayTimeIntervalType =>
+          val typeDesc = orcTypeDescription(LongType)
+          typeDesc.setAttribute(
+            CATALYST_TYPE_ATTRIBUTE_NAME, d.typeName)
+          Some(typeDesc)
+        case _ => None
+      }
+    }
+
+    dt match {
+      case s: StructType =>
+        val result = new TypeDescription(TypeDescription.Category.STRUCT)
+        s.fields.foreach { f =>
+          getInnerTypeDecription(f.dataType) match {
+            case Some(t) => result.addField(f.name, t)
+            case None => result.addField(f.name, 
orcTypeDescription(f.dataType))
+          }
+        }
+        result
+      case a: ArrayType =>
+        val result = new TypeDescription(TypeDescription.Category.LIST)
+        getInnerTypeDecription(a.elementType) match {
+          case Some(t) => result.addChild(t)
+          case None => result.addChild(orcTypeDescription(a.elementType))
+        }
+        result
+      case m: MapType =>
+        val result = new TypeDescription(TypeDescription.Category.MAP)
+        getInnerTypeDecription(m.keyType) match {
+          case Some(t) => result.addChild(t)
+          case None => result.addChild(orcTypeDescription(m.keyType))
+        }
+        getInnerTypeDecription(m.valueType) match {
+          case Some(t) => result.addChild(t)
+          case None => result.addChild(orcTypeDescription(m.valueType))
+        }
+        result
+      case other =>
+        TypeDescription.fromString(other.catalogString)
+    }
+  }
+
   /**
    * Returns the result schema to read from ORC file. In addition, It sets
    * the schema string to 'orc.mapred.input.schema' so ORC reader can use later.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
index 628b0a1..9cc4525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
@@ -49,8 +49,6 @@ case class OrcTable(
     }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case st: StructType => st.forall { f => supportsDataType(f.dataType) }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
index 28d0967..854463d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
@@ -36,7 +36,7 @@ trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite =>
   protected def inputDataset: Dataset[_] = spark.createDataset(Seq("abc"))(Encoders.STRING)
 
   test(s"SPARK-36349: disallow saving of ANSI intervals to $dataSourceFormat") {
-    if (!Set("parquet", "csv", "json").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
+    if (!Set("parquet", "csv", "json", "orc").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
       Seq("INTERVAL '1' DAY", "INTERVAL '1' YEAR").foreach { i =>
         withTempPath { dir =>
           val errMsg = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index d077814..8ffccd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.io.File
 import java.nio.charset.StandardCharsets.UTF_8
 import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
 import java.util.Locale
 
 import org.apache.hadoop.conf.Configuration
@@ -35,13 +36,14 @@ import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException}
 import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY}
 import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, SchemaMergeUtils}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtilsBase}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 case class OrcData(intField: Int, stringField: String)
 
-abstract class OrcSuite extends OrcTest with BeforeAndAfterAll with CommonFileDataSourceSuite {
+abstract class OrcSuite
+  extends OrcTest with BeforeAndAfterAll with CommonFileDataSourceSuite with SQLTestUtilsBase {
   import testImplicits._
 
   override protected def dataSourceFormat = "orc"
@@ -806,6 +808,49 @@ abstract class OrcSourceSuite extends OrcSuite with SharedSparkSession {
                 StructField("456", StringType) :: Nil))))))
     }
   }
+
+  Seq(true, false).foreach { vecReaderEnabled =>
+    Seq(true, false).foreach { vecReaderNestedColEnabled =>
+      test("SPARK-36931: Support reading and writing ANSI intervals (" +
+      s"${SQLConf.ORC_VECTORIZED_READER_ENABLED.key}=$vecReaderEnabled, " +
+      s"${SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key}=$vecReaderNestedColEnabled)") {
+
+        withSQLConf(
+          SQLConf.ORC_VECTORIZED_READER_ENABLED.key ->
+            vecReaderEnabled.toString,
+          SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key ->
+            vecReaderNestedColEnabled.toString) {
+          Seq(
+            YearMonthIntervalType() -> ((i: Int) => Period.of(i, i, 0)),
+            DayTimeIntervalType() -> ((i: Int) => Duration.ofDays(i).plusSeconds(i))
+          ).foreach { case (it, f) =>
+            val data = (1 to 10).map(i => Row(i, f(i)))
+            val schema = StructType(Array(StructField("d", IntegerType, false),
+              StructField("i", it, false)))
+            withTempPath { file =>
+              val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+              df.write.orc(file.getCanonicalPath)
+              val df2 = spark.read.orc(file.getCanonicalPath)
+              checkAnswer(df2, df.collect().toSeq)
+            }
+          }
+
+          // Tests for ANSI intervals in complex types.
+          withTempPath { file =>
+            val df = spark.sql(
+              """SELECT
+                |  named_struct('interval', interval '1-2' year to month) a,
+                |  array(interval '1 2:3' day to minute) b,
+                |  map('key', interval '10' year) c,
+                |  map(interval '20' second, 'value') d""".stripMargin)
+            df.write.orc(file.getCanonicalPath)
+            val df2 = spark.read.orc(file.getCanonicalPath)
+            checkAnswer(df2, df.collect().toSeq)
+          }
+        }
+      }
+    }
+  }
 }
 
 class OrcSourceV1Suite extends OrcSourceSuite {
