Github user fjh100456 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20087#discussion_r162779218
  
    --- Diff: 
sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala 
---
    @@ -0,0 +1,354 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.hive
    +
    +import java.io.File
    +
    +import scala.collection.JavaConverters._
    +
    +import org.apache.hadoop.fs.Path
    +import org.apache.orc.OrcConf.COMPRESS
    +import org.apache.parquet.hadoop.ParquetOutputFormat
    +import org.scalatest.BeforeAndAfterAll
    +
    +import org.apache.spark.sql.execution.datasources.orc.OrcOptions
    +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, 
ParquetTest}
    +import org.apache.spark.sql.hive.orc.OrcFileOperator
    +import org.apache.spark.sql.hive.test.TestHiveSingleton
    +import org.apache.spark.sql.internal.SQLConf
    +
    +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest 
with BeforeAndAfterAll {
    +  import spark.implicits._
    +
    +  override def beforeAll(): Unit = {
    +    super.beforeAll()
    +    (0 until 
maxRecordNum).toDF("a").createOrReplaceTempView("table_source")
    +  }
    +
    +  override def afterAll(): Unit = {
    +    try {
    +      spark.catalog.dropTempView("table_source")
    +    } finally {
    +      super.afterAll()
    +    }
    +  }
    +
    +  private val maxRecordNum = 50
    +
    +  private def getConvertMetastoreConfName(format: String): String = 
format.toLowerCase match {
    +    case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key
    +    case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key
    +  }
    +
    +  private def getSparkCompressionConfName(format: String): String = 
format.toLowerCase match {
    +    case "parquet" => SQLConf.PARQUET_COMPRESSION.key
    +    case "orc" => SQLConf.ORC_COMPRESSION.key
    +  }
    +
    +  private def getHiveCompressPropName(format: String): String = 
format.toLowerCase match {
    +    case "parquet" => ParquetOutputFormat.COMPRESSION
    +    case "orc" => COMPRESS.getAttribute
    +  }
    +
    +  private def normalizeCodecName(format: String, name: String): String = {
    +    format.toLowerCase match {
    +      case "parquet" => ParquetOptions.getParquetCompressionCodecName(name)
    +      case "orc" => OrcOptions.getORCCompressionCodecName(name)
    +    }
    +  }
    +
    +  private def getTableCompressionCodec(path: String, format: String): 
Seq[String] = {
    +    val hadoopConf = spark.sessionState.newHadoopConf()
    +    val codecs = format.toLowerCase match {
    +      case "parquet" => for {
    +        footer <- readAllFootersWithoutSummaryFiles(new Path(path), 
hadoopConf)
    +        block <- footer.getParquetMetadata.getBlocks.asScala
    +        column <- block.getColumns.asScala
    +      } yield column.getCodec.name()
    +      case "orc" => new File(path).listFiles().filter { file =>
    +        file.isFile && !file.getName.endsWith(".crc") && file.getName != 
"_SUCCESS"
    +      }.map { orcFile =>
    +        
OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString
    +      }.toSeq
    +    }
    +    codecs.distinct
    +  }
    +
    +  private def createTable(
    +      rootDir: File,
    +      tableName: String,
    +      isPartitioned: Boolean,
    +      format: String,
    +      compressionCodec: Option[String]): Unit = {
    +    val tblProperties = compressionCodec match {
    +      case Some(prop) => 
s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')"
    +      case _ => ""
    +    }
    +    val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" 
else ""
    +    sql(
    +      s"""
    +        |CREATE TABLE $tableName(a int)
    +        |$partitionCreate
    +        |STORED AS $format
    +        |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName'
    +        |$tblProperties
    +      """.stripMargin)
    +  }
    +
    +  private def writeDataToTable(
    +      tableName: String,
    +      partitionValue: Option[String]): Unit = {
    +    val partitionInsert = partitionValue.map(p => s"partition 
(p='$p')").mkString
    +    sql(
    +      s"""
    +        |INSERT INTO TABLE $tableName
    +        |$partitionInsert
    +        |SELECT * FROM table_source
    +      """.stripMargin)
    +  }
    +
    +  private def writeDateToTableUsingCTAS(
    +      rootDir: File,
    +      tableName: String,
    +      partitionValue: Option[String],
    +      format: String,
    +      compressionCodec: Option[String]): Unit = {
    +    val partitionCreate = partitionValue.map(p => s"PARTITIONED BY 
(p)").mkString
    +    val compressionOption = compressionCodec.map { codec =>
    +      s",'${getHiveCompressPropName(format)}'='$codec'"
    +    }.mkString
    +    val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString
    +    sql(
    +      s"""
    +        |CREATE TABLE $tableName
    +        |USING $format
    +        
|OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' 
$compressionOption)
    +        |$partitionCreate
    +        |AS SELECT * $partitionSelect FROM table_source
    +      """.stripMargin)
    +  }
    +
    +  private def getPreparedTablePath(
    +      tmpDir: File,
    +      tableName: String,
    +      isPartitioned: Boolean,
    +      format: String,
    +      compressionCodec: Option[String],
    +      usingCTAS: Boolean): String = {
    +    val partitionValue = if (isPartitioned) Some("test") else None
    +    if (usingCTAS) {
    +      writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, 
compressionCodec)
    +    } else {
    +      createTable(tmpDir, tableName, isPartitioned, format, 
compressionCodec)
    +      writeDataToTable(tableName, partitionValue)
    +    }
    +    getTablePartitionPath(tmpDir, tableName, partitionValue)
    +  }
    +
    +  private def getTableSize(path: String): Long = {
    +    val dir = new File(path)
    +    val files = dir.listFiles().filter(_.getName.startsWith("part-"))
    +    files.map(_.length()).sum
    +  }
    +
    +  private def getTablePartitionPath(
    +      dir: File,
    +      tableName: String,
    +      partitionValue: Option[String]) = {
    +    val partitionPath = partitionValue.map(p => s"p=$p").mkString
    +    s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath"
    +  }
    +
    +  private def getUncompressedDataSizeByFormat(
    +      format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = {
    +    var totalSize = 0L
    +    val tableName = s"tbl_$format"
    +    val codecName = normalizeCodecName(format, "uncompressed")
    +    withSQLConf(getSparkCompressionConfName(format) -> codecName) {
    +      withTempDir { tmpDir =>
    +        withTable(tableName) {
    +          val compressionCodec = Option(codecName)
    +          val path = getPreparedTablePath(
    +            tmpDir, tableName, isPartitioned, format, compressionCodec, 
usingCTAS)
    +          totalSize = getTableSize(path)
    +        }
    +      }
    +    }
    +    assert(totalSize > 0L)
    +    totalSize
    +  }
    +
    +  private def checkCompressionCodecForTable(
    +      format: String,
    +      isPartitioned: Boolean,
    +      compressionCodec: Option[String],
    +      usingCTAS: Boolean)
    +      (assertion: (String, Long) => Unit): Unit = {
    +    val tableName =
    +      if (usingCTAS) s"tbl_$format$isPartitioned" else 
s"tbl_$format${isPartitioned}_CAST"
    +    withTempDir { tmpDir =>
    +      withTable(tableName) {
    +        val path = getPreparedTablePath(
    +          tmpDir, tableName, isPartitioned, format, compressionCodec, 
usingCTAS)
    +        val relCompressionCodecs = getTableCompressionCodec(path, format)
    +        assert(relCompressionCodecs.length == 1)
    +        val tableSize = getTableSize(path)
    +        assertion(relCompressionCodecs.head, tableSize)
    +      }
    +    }
    +  }
    +
    +  private def checkTableCompressionCodecForCodecs(
    +      format: String,
    +      isPartitioned: Boolean,
    +      convertMetastore: Boolean,
    +      usingCTAS: Boolean,
    +      compressionCodecs: List[String],
    +      tableCompressionCodecs: List[String])
    +      (assertionCompressionCodec: (Option[String], String, String, Long) 
=> Unit): Unit = {
    +    withSQLConf(getConvertMetastoreConfName(format) -> 
convertMetastore.toString) {
    +      tableCompressionCodecs.foreach { tableCompression =>
    +        compressionCodecs.foreach { sessionCompressionCodec =>
    +          withSQLConf(getSparkCompressionConfName(format) -> 
sessionCompressionCodec) {
    +            // 'tableCompression = null' means no table-level compression
    +            val compression = Option(tableCompression)
    +            checkCompressionCodecForTable(format, isPartitioned, 
compression, usingCTAS) {
    +              case (realCompressionCodec, tableSize) =>
    +                assertionCompressionCodec(
    +                  compression, sessionCompressionCodec, 
realCompressionCodec, tableSize)
    +            }
    +          }
    +        }
    +      }
    +    }
    +  }
    +
    +  // When the amount of data is small, compressed data size may be larger 
than uncompressed one,
    +  // so we just check the difference when compressionCodec is not NONE or 
UNCOMPRESSED.
    +  private def checkTableSize(
    +      format: String,
    +      compressionCodec: String,
    +      isPartitioned: Boolean,
    +      convertMetastore: Boolean,
    +      usingCTAS: Boolean,
    +      tableSize: Long): Boolean = {
    +    val uncompressedSize = getUncompressedDataSizeByFormat(format, 
isPartitioned, usingCTAS)
    +    compressionCodec match {
    +      case "UNCOMPRESSED" if format == "parquet" => tableSize == 
uncompressedSize
    +      case "NONE" if format == "orc" => tableSize == uncompressedSize
    +      case _ => tableSize != uncompressedSize
    +    }
    +  }
    +
    +  def checkForTableWithCompressProp(format: String, compressCodecs: 
List[String]): Unit = {
    +    Seq(true, false).foreach { isPartitioned =>
    +      Seq(true, false).foreach { convertMetastore =>
    +        // TODO: Also verify CTAS(usingCTAS=true) cases when the 
bug(SPARK-22926) is fixed.
    +        Seq(false).foreach { usingCTAS =>
    +          checkTableCompressionCodecForCodecs(
    +            format,
    +            isPartitioned,
    +            convertMetastore,
    +            usingCTAS,
    +            compressionCodecs = compressCodecs,
    +            tableCompressionCodecs = compressCodecs) {
    +            case (tableCodec, sessionCodec, realCodec, tableSize) =>
    +              // For non-partitioned table and when convertMetastore is 
true, Expect session-level
    +              // take effect, and in other cases expect table-level take 
effect
    +              // TODO: It should always be table-level taking effect when 
the bug(SPARK-22926)
    +              // is fixed
    +              val expectCodec =
    +                if (convertMetastore && !isPartitioned) sessionCodec else 
tableCodec.get
    +              assert(expectCodec == realCodec)
    +              assert(checkTableSize(
    +                format, expectCodec, isPartitioned, convertMetastore, 
usingCTAS, tableSize))
    +          }
    +        }
    +      }
    +    }
    +  }
    +
    +  def checkForTableWithoutCompressProp(format: String, compressCodecs: 
List[String]): Unit = {
    +    Seq(true, false).foreach { isPartitioned =>
    +      Seq(true, false).foreach { convertMetastore =>
    +        // TODO: Also verify CTAS(usingCTAS=true) cases when the 
bug(SPARK-22926) is fixed.
    +        Seq(false).foreach { usingCTAS =>
    +          checkTableCompressionCodecForCodecs(
    +            format,
    +            isPartitioned,
    +            convertMetastore,
    +            usingCTAS,
    +            compressionCodecs = compressCodecs,
    +            tableCompressionCodecs = List(null)) {
    +            case
    +              (tableCodec, sessionCodec, realCodec, tableSize) =>
    --- End diff --
    
    Oops, I made a mistake. Thank you!


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to