This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 49f9e74973f [SPARK-45481][SPARK-45664][SPARK-45711][SQL][FOLLOWUP] Avoid magic strings copy from parquet|orc|avro compression codes
49f9e74973f is described below

commit 49f9e74973faadeddfab944d822dd3bcd6365c5b
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Tue Oct 31 11:44:59 2023 -0700

    [SPARK-45481][SPARK-45664][SPARK-45711][SQL][FOLLOWUP] Avoid magic strings copy from parquet|orc|avro compression codes
    
    ### What changes were proposed in this pull request?
    This PR follows up https://github.com/apache/spark/pull/43562, https://github.com/apache/spark/pull/43528 and https://github.com/apache/spark/pull/43308.
    The aim of this PR is to avoid copying magic strings from the `parquet|orc|avro` compression codecs.
    
    This PR also simplifies some test cases.
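    
    For illustration, the core change is a one-line swap like the following sketch (a minimal fragment mirroring the Avro test changes in the diff below; `spark` stands for an active `SparkSession`):
    
    ```scala
    import org.apache.spark.sql.avro.AvroCompressionCodec.SNAPPY
    import org.apache.spark.sql.internal.SQLConf
    
    // Before: the codec name is a magic string that has to stay in sync with the enum.
    spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "snappy")
    
    // After: derive the name from the shared enum, keeping a single source of truth.
    spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, SNAPPY.lowerCaseName())
    ```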
    
    ### Why are the changes needed?
    Avoid copying magic strings from the `parquet|orc|avro` compression codecs.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    
    ### How was this patch tested?
    Existing test cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    'No'.
    
    Closes #43604 from beliefer/parquet_orc_avro.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../org/apache/spark/sql/avro/AvroSuite.scala      | 29 +++++++----------
 .../execution/datasources/orc/OrcSourceSuite.scala | 36 +++++++++-------------
 .../apache/spark/sql/internal/SQLConfSuite.scala   | 13 ++++----
 .../spark/sql/hive/execution/HiveDDLSuite.scala    |  2 +-
 4 files changed, 34 insertions(+), 46 deletions(-)

diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index d618c0035fb..f4a88bd0db2 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUp
 import org.apache.spark.TestUtils.assertExceptionMsg
 import org.apache.spark.sql._
 import org.apache.spark.sql.TestingUDT.IntervalData
+import org.apache.spark.sql.avro.AvroCompressionCodec._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
@@ -680,24 +681,18 @@ abstract class AvroSuite
       val zstandardDir = s"$dir/zstandard"
 
       val df = spark.read.format("avro").load(testAvro)
-      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key,
-        AvroCompressionCodec.UNCOMPRESSED.lowerCaseName())
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, UNCOMPRESSED.lowerCaseName())
       df.write.format("avro").save(uncompressDir)
-      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key,
-        AvroCompressionCodec.BZIP2.lowerCaseName())
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, BZIP2.lowerCaseName())
       df.write.format("avro").save(bzip2Dir)
-      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key,
-        AvroCompressionCodec.XZ.lowerCaseName())
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, XZ.lowerCaseName())
       df.write.format("avro").save(xzDir)
-      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key,
-        AvroCompressionCodec.DEFLATE.lowerCaseName())
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, DEFLATE.lowerCaseName())
       spark.conf.set(SQLConf.AVRO_DEFLATE_LEVEL.key, "9")
       df.write.format("avro").save(deflateDir)
-      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key,
-        AvroCompressionCodec.SNAPPY.lowerCaseName())
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, SNAPPY.lowerCaseName())
       df.write.format("avro").save(snappyDir)
-      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key,
-        AvroCompressionCodec.ZSTANDARD.lowerCaseName())
+      spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, ZSTANDARD.lowerCaseName())
       df.write.format("avro").save(zstandardDir)
 
       val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir))
@@ -2132,7 +2127,7 @@ abstract class AvroSuite
         val reader = new DataFileReader(file, new GenericDatumReader[Any]())
         val r = reader.getMetaString("avro.codec")
         r
-      }.map(v => if (v == "null") "uncompressed" else v).headOption
+      }.map(v => if (v == "null") UNCOMPRESSED.lowerCaseName() else v).headOption
     }
     def checkCodec(df: DataFrame, dir: String, codec: String): Unit = {
       val subdir = s"$dir/$codec"
@@ -2143,11 +2138,9 @@ abstract class AvroSuite
       val path = dir.toString
       val df = spark.read.format("avro").load(testAvro)
 
-      checkCodec(df, path, "uncompressed")
-      checkCodec(df, path, "deflate")
-      checkCodec(df, path, "snappy")
-      checkCodec(df, path, "bzip2")
-      checkCodec(df, path, "xz")
+      AvroCompressionCodec.values().foreach { codec =>
+        checkCodec(df, path, codec.lowerCaseName())
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index 4abcb4a7ef1..1e98099361d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -36,6 +36,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException}
 import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY}
 import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, SchemaMergeUtils}
+import org.apache.spark.sql.execution.datasources.orc.OrcCompressionCodec._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtilsBase}
 import org.apache.spark.sql.types._
@@ -324,38 +325,31 @@ abstract class OrcSuite
 
   test("SPARK-18433: Improve DataSource option keys to be more 
case-insensitive") {
     val conf = spark.sessionState.conf
-    val option = new OrcOptions(
-      Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> OrcCompressionCodec.NONE.name()), conf)
+    val option =
+      new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> NONE.name()), conf)
     assert(option.compressionCodec == OrcCompressionCodec.NONE.name())
   }
 
   test("SPARK-21839: Add SQL config for ORC compression") {
     val conf = spark.sessionState.conf
     // Test if the default of spark.sql.orc.compression.codec is snappy
-    assert(new OrcOptions(
-      Map.empty[String, String], conf).compressionCodec == OrcCompressionCodec.SNAPPY.name())
+    assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == SNAPPY.name())
 
     // OrcOptions's parameters have a higher priority than SQL configuration.
     // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec`
-    withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") {
-      assert(new OrcOptions(
-        Map.empty[String, String], conf).compressionCodec == OrcCompressionCodec.NONE.name())
-      val zlibCodec = OrcCompressionCodec.ZLIB.lowerCaseName()
-      val lzoCodec = OrcCompressionCodec.LZO.lowerCaseName()
+    withSQLConf(SQLConf.ORC_COMPRESSION.key -> UNCOMPRESSED.lowerCaseName()) {
+      assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == NONE.name())
+      val zlibCodec = ZLIB.lowerCaseName()
       val map1 = Map(COMPRESS.getAttribute -> zlibCodec)
-      val map2 = Map(COMPRESS.getAttribute -> zlibCodec, "compression" -> lzoCodec)
-      assert(new OrcOptions(map1, conf).compressionCodec == OrcCompressionCodec.ZLIB.name())
-      assert(new OrcOptions(map2, conf).compressionCodec == OrcCompressionCodec.LZO.name())
+      val map2 = Map(COMPRESS.getAttribute -> zlibCodec, "compression" -> LZO.lowerCaseName())
+      assert(new OrcOptions(map1, conf).compressionCodec == ZLIB.name())
+      assert(new OrcOptions(map2, conf).compressionCodec == LZO.name())
     }
 
     // Test all the valid options of spark.sql.orc.compression.codec
     OrcCompressionCodec.values().map(_.name()).foreach { c =>
       withSQLConf(SQLConf.ORC_COMPRESSION.key -> c) {
-        val expected = if (c == OrcCompressionCodec.UNCOMPRESSED.name()) {
-          OrcCompressionCodec.NONE.name()
-        } else {
-          c
-        }
+        val expected = OrcCompressionCodec.valueOf(c).getCompressionKind.name()
         assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == expected)
       }
     }
@@ -556,20 +550,20 @@ abstract class OrcSuite
   test("SPARK-35612: Support LZ4 compression in ORC data source") {
     withTempPath { dir =>
       val path = dir.getAbsolutePath
-      spark.range(3).write.option("compression", "lz4").orc(path)
+      spark.range(3).write.option("compression", LZ4.lowerCaseName()).orc(path)
       checkAnswer(spark.read.orc(path), Seq(Row(0), Row(1), Row(2)))
       val files = OrcUtils.listOrcFiles(path, spark.sessionState.newHadoopConf())
-      assert(files.nonEmpty && files.forall(_.getName.contains("lz4")))
+      assert(files.nonEmpty && 
files.forall(_.getName.contains(LZ4.lowerCaseName())))
     }
   }
 
   test("SPARK-33978: Write and read a file with ZSTD compression") {
     withTempPath { dir =>
       val path = dir.getAbsolutePath
-      spark.range(3).write.option("compression", "zstd").orc(path)
+      spark.range(3).write.option("compression", 
ZSTD.lowerCaseName()).orc(path)
       checkAnswer(spark.read.orc(path), Seq(Row(0), Row(1), Row(2)))
       val files = OrcUtils.listOrcFiles(path, spark.sessionState.newHadoopConf())
-      assert(files.nonEmpty && files.forall(_.getName.contains("zstd")))
+      assert(files.nonEmpty && 
files.forall(_.getName.contains(ZSTD.lowerCaseName())))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 822c0642f2b..cc4669641a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.{SPARK_DOC_ROOT, SparkNoSuchElementException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.MIT
+import org.apache.spark.sql.execution.datasources.parquet.ParquetCompressionCodec.{GZIP, LZO}
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.sql.test.{SharedSparkSession, TestSQLContext}
 import org.apache.spark.util.Utils
@@ -368,7 +369,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession {
 
     assert(spark.conf.get(fallback.key) ===
       SQLConf.PARQUET_COMPRESSION.defaultValue.get)
-    assert(spark.conf.get(fallback.key, "lzo") === "lzo")
+    assert(spark.conf.get(fallback.key, LZO.lowerCaseName()) === LZO.lowerCaseName())
 
     val displayValue = spark.sessionState.conf.getAllDefinedConfs
       .find { case (key, _, _, _) => key == fallback.key }
@@ -376,17 +377,17 @@ class SQLConfSuite extends QueryTest with SharedSparkSession {
       .get
     assert(displayValue === fallback.defaultValueString)
 
-    spark.conf.set(SQLConf.PARQUET_COMPRESSION, "gzip")
-    assert(spark.conf.get(fallback.key) === "gzip")
+    spark.conf.set(SQLConf.PARQUET_COMPRESSION, GZIP.lowerCaseName())
+    assert(spark.conf.get(fallback.key) === GZIP.lowerCaseName())
 
-    spark.conf.set(fallback, "lzo")
-    assert(spark.conf.get(fallback.key) === "lzo")
+    spark.conf.set(fallback, LZO.lowerCaseName())
+    assert(spark.conf.get(fallback.key) === LZO.lowerCaseName())
 
     val newDisplayValue = spark.sessionState.conf.getAllDefinedConfs
       .find { case (key, _, _, _) => key == fallback.key }
       .map { case (_, v, _, _) => v }
       .get
-    assert(newDisplayValue === "lzo")
+    assert(newDisplayValue === LZO.lowerCaseName())
 
     SQLConf.unregister(fallback)
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 55cbf591303..91ac21652e1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -1926,7 +1926,7 @@ class HiveDDLSuite
         checkAnswer(spark.table("t"), Row(1))
         // Check if this is compressed as ZLIB.
         val maybeOrcFile = path.listFiles().find(_.getName.startsWith("part"))
-        assertCompression(maybeOrcFile, "orc", "ZLIB")
+        assertCompression(maybeOrcFile, "orc", OrcCompressionCodec.ZLIB.name())
 
         sql("CREATE TABLE t2 USING HIVE AS SELECT 1 AS c1, 'a' AS c2")
         val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t2"))

