This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c77c0d4  [SPARK-36825][SQL] Read/write dataframes with ANSI intervals from/to parquet files
c77c0d4 is described below

commit c77c0d41e13ba85b9a6cf713cefebbb7170f53c2
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Fri Sep 24 09:55:11 2021 +0300

    [SPARK-36825][SQL] Read/write dataframes with ANSI intervals from/to parquet files
    
    ### What changes were proposed in this pull request?
    Allow saving and loading of ANSI intervals (`YearMonthIntervalType` and `DayTimeIntervalType`) to/from the Parquet datasource. After the changes, Spark saves ANSI intervals as primitive physical Parquet types without any modifications:
    - year-month intervals as `INT32`
    - day-time intervals as `INT64`

    To load the values back as intervals, Spark stores the interval type information in the extra key `org.apache.spark.sql.parquet.row.metadata`:
    ```
    $ java -jar parquet-tools-1.12.0.jar meta ./part-...-c000.snappy.parquet
    
    creator:     parquet-mr version 1.12.1 (build 2a5c06c58fa987f85aa22170be14d927d5ff6e7d)
    extra:       org.apache.spark.version = 3.3.0
    extra:       org.apache.spark.sql.parquet.row.metadata = {"type":"struct","fields":[...,{"name":"i","type":"interval year to month","nullable":false,"metadata":{}}]}
    
    file schema: spark_schema
    
--------------------------------------------------------------------------------
    ...
    i:           REQUIRED INT32 R:0 D:0
    ```
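
    The physical types and the Spark-specific footer key can also be inspected programmatically. A minimal sketch using the `parquet-hadoop` footer API (the file path is illustrative, not part of this PR):
    ```scala
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.parquet.hadoop.ParquetFileReader
    import org.apache.parquet.hadoop.util.HadoopInputFile

    // Hypothetical location of a Parquet file written with an ANSI interval column.
    val input = HadoopInputFile.fromPath(
      new Path("/tmp/parquet_interval/part-00000.snappy.parquet"), new Configuration())
    val reader = ParquetFileReader.open(input)
    try {
      // Physical schema: the interval column is expected to show up as INT32 or INT64.
      println(reader.getFooter.getFileMetaData.getSchema)
      // Extra key/value metadata, including org.apache.spark.sql.parquet.row.metadata.
      println(reader.getFooter.getFileMetaData.getKeyValueMetaData)
    } finally {
      reader.close()
    }
    ```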
    
    **Note:** This PR focuses on supporting ANSI intervals in the Parquet datasource when they are written or read as `Dataset` columns.
    
    ### Why are the changes needed?
    To improve the user experience with Spark SQL. At the moment, users can create ANSI intervals inside Spark or parallelize Java collections of `Period`/`Duration` objects, but they cannot save the intervals to any built-in datasource. After the changes, users can save datasets/dataframes with year-month/day-time interval columns and load them back later with Apache Spark.
    
    For example:
    ```scala
    scala> sql("select date'today' - date'2021-01-01' as diff").write.parquet("/Users/maximgekk/tmp/parquet_interval")

    scala> val readback = spark.read.parquet("/Users/maximgekk/tmp/parquet_interval")
    readback: org.apache.spark.sql.DataFrame = [diff: interval day]
    
    scala> readback.printSchema
    root
     |-- diff: interval day (nullable = true)
    
    scala> readback.show
    +------------------+
    |              diff|
    +------------------+
    |INTERVAL '264' DAY|
    +------------------+
    ```
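
    The `Period`/`Duration` route mentioned above works the same way. A minimal sketch, assuming a running `spark` session and the `java.time.Period`/`java.time.Duration` encoders available since Spark 3.2 (the path and column names are illustrative):
    ```scala
    import java.time.{Duration, Period}
    import spark.implicits._

    // Period values map to year-month interval columns, Duration values to day-time intervals.
    val df = Seq((Period.ofYears(1).plusMonths(2), Duration.ofDays(3).plusHours(4)))
      .toDF("ym", "dt")
    df.write.mode("overwrite").parquet("/tmp/parquet_interval_java_time")

    // Expected round-trip schema:
    //  |-- ym: interval year to month (nullable = true)
    //  |-- dt: interval day to second (nullable = true)
    spark.read.parquet("/tmp/parquet_interval_java_time").printSchema()
    ```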
    
    ### Does this PR introduce _any_ user-facing change?
    In some sense, yes. Before the changes, users got an error when saving ANSI intervals as dataframe columns to parquet files; after the changes, the operation completes successfully.
    
    ### How was this patch tested?
    1. By running the existing test suites:
    ```
    $ build/sbt "test:testOnly *ParquetFileFormatV2Suite"
    $ build/sbt "test:testOnly *FileBasedDataSourceSuite"
    $ build/sbt "sql/test:testOnly *JsonV2Suite"
    ```
    2. Added new tests:
    ```
    $ build/sbt "sql/test:testOnly *ParquetIOSuite"
    $ build/sbt "sql/test:testOnly *ParquetSchemaSuite"
    ```
    
    Closes #34057 from MaxGekk/ansi-interval-save-parquet.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../apache/spark/sql/catalyst/util/TypeUtils.scala |  7 ++---
 .../parquet/ParquetVectorUpdaterFactory.java       |  8 +++---
 .../sql/execution/datasources/DataSource.scala     | 13 ++++++---
 .../datasources/parquet/ParquetFileFormat.scala    |  2 --
 .../datasources/parquet/ParquetRowConverter.scala  |  3 ++-
 .../parquet/ParquetSchemaConverter.scala           |  4 +--
 .../datasources/parquet/ParquetWriteSupport.scala  |  4 +--
 .../spark/sql/FileBasedDataSourceSuite.scala       |  4 ++-
 .../datasources/CommonFileDataSourceSuite.scala    | 31 ++++++++++++----------
 .../datasources/parquet/ParquetIOSuite.scala       | 20 ++++++++++++++
 .../datasources/parquet/ParquetSchemaSuite.scala   | 18 +++++++++++++
 11 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index f4c0f3b..1a8de4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -110,14 +110,15 @@ object TypeUtils {
   }
 
   def failWithIntervalType(dataType: DataType): Unit = {
-    invokeOnceForInterval(dataType) {
+    invokeOnceForInterval(dataType, forbidAnsiIntervals = true) {
       throw QueryCompilationErrors.cannotUseIntervalTypeInTableSchemaError()
     }
   }
 
-  def invokeOnceForInterval(dataType: DataType)(f: => Unit): Unit = {
+  def invokeOnceForInterval(dataType: DataType, forbidAnsiIntervals: Boolean)(f: => Unit): Unit = {
     def isInterval(dataType: DataType): Boolean = dataType match {
-      case CalendarIntervalType | _: AnsiIntervalType => true
+      case _: AnsiIntervalType => forbidAnsiIntervals
+      case CalendarIntervalType => true
       case _ => false
     }
     if (dataType.existsRecursively(isInterval)) f
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 39de909..d02045b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -31,9 +31,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime;
 import org.apache.spark.sql.execution.datasources.DataSourceUtils;
 import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.*;
 
 import java.math.BigInteger;
 import java.time.ZoneId;
@@ -88,6 +86,8 @@ public class ParquetVectorUpdaterFactory {
             boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
             return new IntegerWithRebaseUpdater(failIfRebase);
           }
+        } else if (sparkType instanceof YearMonthIntervalType) {
+          return new IntegerUpdater();
         }
         break;
       case INT64:
@@ -117,6 +117,8 @@ public class ParquetVectorUpdaterFactory {
            final boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
             return new LongAsMicrosRebaseUpdater(failIfRebase);
           }
+        } else if (sparkType instanceof DayTimeIntervalType) {
+          return new LongUpdater();
         }
         break;
       case FLOAT:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index ad850cf..0707af4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -512,12 +512,13 @@ case class DataSource(
       physicalPlan: SparkPlan,
       metrics: Map[String, SQLMetric]): BaseRelation = {
     val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
-    disallowWritingIntervals(outputColumns.map(_.dataType))
     providingInstance() match {
       case dataSource: CreatableRelationProvider =>
+        disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = true)
         dataSource.createRelation(
            sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
       case format: FileFormat =>
+        disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false)
         val cmd = planForWritingFileFormat(format, mode, data)
         val resolvedPartCols = cmd.partitionColumns.map { col =>
          // The partition columns created in `planForWritingFileFormat` should always be
@@ -547,11 +548,12 @@ case class DataSource(
   * Returns a logical plan to write the given [[LogicalPlan]] out to this [[DataSource]].
    */
   def planForWriting(mode: SaveMode, data: LogicalPlan): LogicalPlan = {
-    disallowWritingIntervals(data.schema.map(_.dataType))
     providingInstance() match {
       case dataSource: CreatableRelationProvider =>
+        disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = true)
         SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
       case format: FileFormat =>
+        disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = false)
         DataSource.validateSchema(data.schema)
         planForWritingFileFormat(format, mode, data)
       case _ =>
@@ -577,8 +579,11 @@ case class DataSource(
       checkEmptyGlobPath, checkFilesExist, enableGlobbing = globPaths)
   }
 
-  private def disallowWritingIntervals(dataTypes: Seq[DataType]): Unit = {
-    dataTypes.foreach(TypeUtils.invokeOnceForInterval(_) {
+  private def disallowWritingIntervals(
+      dataTypes: Seq[DataType],
+      forbidAnsiIntervals: Boolean): Unit = {
+    val isParquet = providingClass == classOf[ParquetFileFormat]
+    dataTypes.foreach(TypeUtils.invokeOnceForInterval(_, forbidAnsiIntervals || !isParquet) {
       throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
     })
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index d3ac077..586952a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -373,8 +373,6 @@ class ParquetFileFormat
   }
 
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: AnsiIntervalType => false
-
     case _: AtomicType => true
 
     case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 583b4ba..1967066 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -265,7 +265,8 @@ private[parquet] class ParquetRowConverter(
           override def addInt(value: Int): Unit =
             updater.setLong(Integer.toUnsignedLong(value))
         }
-      case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType =>
+      case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType |
+        _: AnsiIntervalType =>
         new ParquetPrimitiveConverter(updater)
 
       case ByteType =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index f3ecd79..e91a3ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -374,10 +374,10 @@ class SparkToParquetSchemaConverter(
         Types.primitive(INT32, repetition)
           .as(LogicalTypeAnnotation.intType(16, true)).named(field.name)
 
-      case IntegerType =>
+      case IntegerType | _: YearMonthIntervalType =>
         Types.primitive(INT32, repetition).named(field.name)
 
-      case LongType =>
+      case LongType | _: DayTimeIntervalType =>
         Types.primitive(INT64, repetition).named(field.name)
 
       case FloatType =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index d0cd02f..e4e0078 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -180,11 +180,11 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addInteger(dateRebaseFunc(row.getInt(ordinal)))
 
-      case IntegerType =>
+      case IntegerType | _: YearMonthIntervalType =>
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addInteger(row.getInt(ordinal))
 
-      case LongType =>
+      case LongType | _: DayTimeIntervalType =>
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addLong(row.getLong(ordinal))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 910f159..3f2f12d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -382,7 +382,9 @@ class FileBasedDataSourceSuite extends QueryTest
             msg.toLowerCase(Locale.ROOT).contains(msg2))
         }
 
-        withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1List) {
+        withSQLConf(
+          SQLConf.USE_V1_SOURCE_LIST.key -> useV1List,
+          SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") {
           // write path
           Seq("csv", "json", "parquet", "orc").foreach { format =>
             val msg = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
index e59bc05..39e00e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala
@@ -36,22 +36,25 @@ trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite =>
   protected def inputDataset: Dataset[_] = spark.createDataset(Seq("abc"))(Encoders.STRING)
 
   test(s"SPARK-36349: disallow saving of ANSI intervals to $dataSourceFormat") 
{
-    Seq("INTERVAL '1' DAY", "INTERVAL '1' YEAR").foreach { i =>
-      withTempPath { dir =>
-        val errMsg = intercept[AnalysisException] {
-          spark.sql(s"SELECT $i").write.format(dataSourceFormat).save(dir.getAbsolutePath)
-        }.getMessage
-        assert(errMsg.contains("Cannot save interval data type into external storage"))
+    if (!Set("parquet").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
+      Seq("INTERVAL '1' DAY", "INTERVAL '1' YEAR").foreach { i =>
+        withTempPath { dir =>
+          val errMsg = intercept[AnalysisException] {
+            spark.sql(s"SELECT $i").write.format(dataSourceFormat).save(dir.getAbsolutePath)
+          }.getMessage
+          assert(errMsg.contains("Cannot save interval data type into external storage"))
+        }
       }
-    }
 
-    // Check all built-in file-based datasources except of libsvm which requires particular schema.
-    if (!Set("libsvm").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
-      Seq("INTERVAL DAY TO SECOND", "INTERVAL YEAR TO MONTH").foreach { it =>
-        val errMsg = intercept[AnalysisException] {
-          spark.sql(s"CREATE TABLE t (i $it) USING $dataSourceFormat")
-        }.getMessage
-        assert(errMsg.contains("data source does not support"))
+      // Check all built-in file-based datasources except of libsvm which
+      // requires particular schema.
+      if (!Set("libsvm").contains(dataSourceFormat.toLowerCase(Locale.ROOT))) {
+        Seq("INTERVAL DAY TO SECOND", "INTERVAL YEAR TO MONTH").foreach { it =>
+          val errMsg = intercept[AnalysisException] {
+            spark.sql(s"CREATE TABLE t (i $it) USING $dataSourceFormat")
+          }.getMessage
+          assert(errMsg.contains("data source does not support"))
+        }
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index e03a50b..e59b499 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
+import java.time.{Duration, Period}
 import java.util.Locale
 
 import scala.collection.JavaConverters._
@@ -1060,6 +1061,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
       }
     }
   }
+
+  test("SPARK-36825: year-month/day-time intervals written and read as 
INT32/INT64") {
+    Seq(
+      YearMonthIntervalType() -> ((i: Int) => Period.of(i, i, 0)),
+      DayTimeIntervalType() -> ((i: Int) => Duration.ofDays(i).plusSeconds(i))
+    ).foreach { case (it, f) =>
+      val data = (1 to 10).map(i => Row(i, f(i)))
+      val schema = StructType(Array(StructField("d", IntegerType, false),
+        StructField("i", it, false)))
+      withTempPath { file =>
+        val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+        df.write.parquet(file.getCanonicalPath)
+        withAllParquetReaders {
+          val df2 = spark.read.parquet(file.getCanonicalPath)
+          checkAnswer(df2, df.collect().toSeq)
+        }
+      }
+    }
+  }
 }
 
 class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index fcc08ee..1da8574 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -1029,6 +1029,24 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
     writeLegacyParquetFormat = true,
     outputTimestampType = SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS)
 
+  testCatalystToParquet(
+    "SPARK-36825: Year-month interval written and read as INT32",
+    StructType(Seq(StructField("f1", YearMonthIntervalType()))),
+    """message root {
+      |  optional INT32 f1;
+      |}
+    """.stripMargin,
+    writeLegacyParquetFormat = false)
+
+  testCatalystToParquet(
+    "SPARK-36825: Day-time interval written and read as INT64",
+    StructType(Seq(StructField("f1", DayTimeIntervalType()))),
+    """message root {
+      |  optional INT64 f1;
+      |}
+    """.stripMargin,
+    writeLegacyParquetFormat = false)
+
   private def testSchemaClipping(
       testName: String,
       parquetSchema: String,
