This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 6c88d7c  [SPARK-32646][SQL][3.0][TEST-HADOOP2.7][TEST-HIVE1.2] ORC predicate pushdown should work with case-insensitive analysis
6c88d7c is described below

commit 6c88d7c1259ea9fe89f5c8190c683bba506d528e
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Tue Aug 25 04:42:39 2020 +0000

    [SPARK-32646][SQL][3.0][TEST-HADOOP2.7][TEST-HIVE1.2] ORC predicate pushdown should work with case-insensitive analysis

    ### What changes were proposed in this pull request?

    This PR fixes ORC predicate pushdown under case-insensitive analysis. With case-insensitive analysis enabled, the field names in pushed-down predicates no longer need to match the physical field names in ORC files in exact letter case.

    ### Why are the changes needed?

    Currently, ORC predicate pushdown doesn't work with case-insensitive analysis: a predicate "a < 0" cannot be pushed down to an ORC file whose field name is "A". Parquet predicate pushdown already handles this case, so ORC predicate pushdown should work with case-insensitive analysis too.

    ### Does this PR introduce _any_ user-facing change?

    Yes. After this PR, ORC predicate pushdown works under case-insensitive analysis.

    ### How was this patch tested?

    Unit tests.

    Closes #29513 from viirya/fix-orc-pushdown-3.0.

    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/datasources/orc/OrcFileFormat.scala  | 16 +++--
 .../execution/datasources/orc/OrcFiltersBase.scala | 35 ++++++++++-
 .../sql/execution/datasources/orc/OrcUtils.scala   | 14 +++++
 .../v2/orc/OrcPartitionReaderFactory.scala         | 22 ++++++-
 .../sql/execution/datasources/v2/orc/OrcScan.scala |  2 +-
 .../datasources/v2/orc/OrcScanBuilder.scala        |  9 +--
 .../sql/execution/datasources/orc/OrcFilters.scala | 72 ++++++++++++----------
 .../execution/datasources/orc/OrcFilterSuite.scala | 49 ++++++++++++++-
 .../sql/execution/datasources/orc/OrcFilters.scala | 70 +++++++++++----------
 .../execution/datasources/orc/OrcFilterSuite.scala | 49 ++++++++++++++-
 10 files changed, 253 insertions(+), 85 deletions(-)
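For orientation, here is a minimal sketch (not part of the commit) of the scenario the description above refers to; the output path and session settings are illustrative assumptions:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: assumes a Spark 3.0 build that includes this patch.
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("SPARK-32646-sketch")
  .config("spark.sql.caseSensitive", "false")     // case-insensitive analysis (the default)
  .config("spark.sql.orc.filterPushdown", "true") // ORC pushdown (the default in 3.0)
  .getOrCreate()
import spark.implicits._

// Write an ORC file whose physical field name is upper-case "A".
val path = "/tmp/spark-32646-orc" // illustrative path
Seq(-1, 0, 1).toDF("A").write.mode("overwrite").orc(path)

// Read it back with a lower-case user-specified schema. Under case-insensitive
// analysis "a" resolves to the physical field "A"; before this patch the
// predicate "a < 0" was silently not pushed down to the ORC reader.
val df = spark.read.schema("a INT").orc(path).filter($"a" < 0)
df.explain() // PushedFilters should now include LessThan(a,0)
df.show()    // one row: -1
```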
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 4dff1ec..69badb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -153,11 +153,6 @@ class OrcFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    if (sparkSession.sessionState.conf.orcFilterPushDown) {
-      OrcFilters.createFilter(dataSchema, filters).foreach { f =>
-        OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
-      }
-    }
     val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
     val sqlConf = sparkSession.sessionState.conf
@@ -169,6 +164,8 @@ class OrcFileFormat
     val broadcastedConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
     val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown
+    val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

     (file: PartitionedFile) => {
       val conf = broadcastedConf.value.value
@@ -186,6 +183,15 @@ class OrcFileFormat
       if (resultedColPruneInfo.isEmpty) {
         Iterator.empty
       } else {
+        // ORC predicate pushdown
+        if (orcFilterPushDown) {
+          OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
+            OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+              OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+            }
+          }
+        }
+
         val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
         val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
           dataSchema, resultSchema, partitionSchema, conf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
index e673309..4554899 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
@@ -17,14 +17,45 @@

 package org.apache.spark.sql.execution.datasources.orc

+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
 import org.apache.spark.sql.sources.{And, Filter}
-import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType}
+import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructType}

 /**
  * Methods that can be shared when upgrading the built-in Hive.
  */
 trait OrcFiltersBase {

+  case class OrcPrimitiveField(fieldName: String, fieldType: DataType)
+
+  protected[sql] def getDataTypeMap(
+      schema: StructType,
+      caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
+    val fields = schema.flatMap { f =>
+      if (isSearchableType(f.dataType)) {
+        Some(quoteIfNeeded(f.name) -> OrcPrimitiveField(quoteIfNeeded(f.name), f.dataType))
+      } else {
+        None
+      }
+    }
+
+    if (caseSensitive) {
+      fields.toMap
+    } else {
+      // Don't consider ambiguity here: if more than one field matches in case-insensitive
+      // mode, just skip pushdown for these fields; they would trigger an exception when read.
+      // See SPARK-25175.
+      val dedupPrimitiveFields = fields
+        .groupBy(_._1.toLowerCase(Locale.ROOT))
+        .filter(_._2.size == 1)
+        .mapValues(_.head._2)
+      CaseInsensitiveMap(dedupPrimitiveFields)
+    }
+  }
+
   private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = {
     filters match {
       case Seq() => None
@@ -40,7 +71,7 @@ trait OrcFiltersBase {
    * Return true if this is a searchable type in ORC.
    * Both CharType and VarcharType are cleaned at AstBuilder.
    */
-  protected[sql] def isSearchableType(dataType: DataType) = dataType match {
+  private def isSearchableType(dataType: DataType) = dataType match {
     case BinaryType => false
     case _: AtomicType => true
     case _ => false
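The new `getDataTypeMap` above is where case-insensitive matching happens: searchable primitive fields are deduplicated by lower-cased name, so ambiguous columns are skipped rather than mis-resolved (SPARK-25175). A standalone sketch of the same dedup technique, with illustrative names and no Spark dependencies:

```scala
import java.util.Locale

// Sketch of the dedup-by-lower-case idea used in getDataTypeMap: keep a
// lower-cased key only when exactly one physical field matches it, so
// ambiguous names (e.g. "id" vs "ID") never get a pushed filter.
case class Field(name: String, tpe: String)

def dedupByLowerCase(fields: Seq[Field]): Map[String, Field] =
  fields
    .groupBy(_.name.toLowerCase(Locale.ROOT))
    .collect { case (lower, Seq(single)) => lower -> single }

val m = dedupByLowerCase(Seq(Field("A", "int"), Field("id", "int"), Field("ID", "bigint")))
assert(m.get("a").map(_.name) == Some("A")) // unique match: pushdown allowed
assert(!m.contains("id"))                   // ambiguous: pushdown skipped
```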
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index e102539..be36432 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -92,6 +92,20 @@ object OrcUtils extends Logging {
     }
   }

+  def readCatalystSchema(
+      file: Path,
+      conf: Configuration,
+      ignoreCorruptFiles: Boolean): Option[StructType] = {
+    readSchema(file, conf, ignoreCorruptFiles) match {
+      case Some(schema) =>
+        Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])
+
+      case None =>
+        // Field names are empty, or a `FileFormatException` was thrown but
+        // ignoreCorruptFiles is true.
+        None
+    }
+  }
+
   /**
    * Reads ORC file schemas in multi-threaded manner, using native version of ORC.
    * This is visible for testing.
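`readCatalystSchema` reads the per-file physical schema so the filter can be rebuilt against the exact field-name casing stored in each file. Roughly, it boils down to the following sketch (assuming the ORC core and Spark Catalyst classes on the classpath; error handling and the ignoreCorruptFiles branch are elided, and the function name is illustrative):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.orc.OrcFile
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.StructType

// Sketch: read one ORC file's physical schema and convert it to a Catalyst
// StructType. TypeDescription.toString renders e.g. "struct<A:int>", which
// the Catalyst parser turns back into a StructType with the original casing.
def physicalSchemaOf(file: Path, conf: Configuration): StructType = {
  val reader = OrcFile.createReader(file, OrcFile.readerOptions(conf))
  CatalystSqlParser.parseDataType(reader.getSchema.toString).asInstanceOf[StructType]
}
```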
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 7f25f7bd..1f38128 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -31,9 +31,10 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
 import org.apache.spark.sql.execution.datasources.PartitionedFile
-import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils}
+import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.{AtomicType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -52,10 +53,13 @@ case class OrcPartitionReaderFactory(
     broadcastedConf: Broadcast[SerializableConfiguration],
     dataSchema: StructType,
     readDataSchema: StructType,
-    partitionSchema: StructType) extends FilePartitionReaderFactory {
+    partitionSchema: StructType,
+    filters: Array[Filter]) extends FilePartitionReaderFactory {
   private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val capacity = sqlConf.orcVectorizedReaderBatchSize
+  private val orcFilterPushDown = sqlConf.orcFilterPushDown
+  private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles

   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
@@ -63,6 +67,16 @@ case class OrcPartitionReaderFactory(
       resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
   }

+  private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
+    if (orcFilterPushDown) {
+      OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
+        OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+          OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+        }
+      }
+    }
+  }
+
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value

@@ -70,6 +84,8 @@ case class OrcPartitionReaderFactory(
     val filePath = new Path(new URI(file.filePath))

+    pushDownPredicates(filePath, conf)
+
     val fs = filePath.getFileSystem(conf)
     val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
     val resultedColPruneInfo =
@@ -116,6 +132,8 @@ case class OrcPartitionReaderFactory(
     val filePath = new Path(new URI(file.filePath))

+    pushDownPredicates(filePath, conf)
+
     val fs = filePath.getFileSystem(conf)
     val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
     val resultedColPruneInfo =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 62894fa..35e3b1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -48,7 +48,7 @@ case class OrcScan(
     // The partition values are already truncated in `FileScan.partitions`.
     // We should use `readPartitionSchema` as the partition schema here.
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, readDataSchema, readPartitionSchema)
+      dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
   }

   override def equals(obj: Any): Boolean = obj match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index 9f40f5f..6a9cb25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -22,11 +22,11 @@ import scala.collection.JavaConverters._
 import org.apache.orc.mapreduce.OrcInputFormat

 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
 import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.orc.OrcFilters
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -55,12 +55,7 @@ case class OrcScanBuilder(

   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     if (sparkSession.sessionState.conf.orcFilterPushDown) {
-      OrcFilters.createFilter(schema, filters).foreach { f =>
-        // The pushed filters will be set in `hadoopConf`. After that, we can simply use the
-        // changed `hadoopConf` in executors.
-        OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
-      }
-      val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+      val dataTypeMap = OrcFilters.getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
       // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
       val newFilters = filters.filter(!_.containsNestedColumn)
       _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray
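Both read paths get the fix: the V1 `OrcFileFormat` and the DataSource V2 `OrcPartitionReaderFactory` now build the search argument per file, against that file's physical schema, instead of once against the table schema. To exercise the V2 path specifically one can empty the V1 source list, as in this sketch (reusing `spark`, `path`, and the implicits from the earlier snippet; the exact explain output is indicative only):

```scala
// Sketch: force the DataSource V2 ORC reader and check the pushed filters.
spark.conf.set("spark.sql.sources.useV1SourceList", "") // empty list = V2 for all built-ins

val dfV2 = spark.read.schema("a INT").orc(path).filter($"a" < 0)
dfV2.explain() // the scan node should report something like PushedFilters: [..., LessThan(a,0)]
```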
diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index b685639..a068347 100644
--- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -27,7 +27,7 @@ import org.apache.orc.storage.serde2.io.HiveDecimalWritable

 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._

@@ -68,7 +68,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * Create ORC filter as a SearchArgument instance.
    */
   def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
-    val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+    val dataTypeMap = getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
     // Combines all convertible filters using `And` to produce a single conjunction
     // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
     val newFilters = filters.filter(!_.containsNestedColumn)
@@ -83,7 +83,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

   def convertibleFilters(
       schema: StructType,
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       filters: Seq[Filter]): Seq[Filter] = {
     import org.apache.spark.sql.sources._

@@ -141,7 +141,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
   /**
    * Get PredicateLeafType which is corresponding to the given DataType.
    */
-  private def getPredicateLeafType(dataType: DataType) = dataType match {
+  private[sql] def getPredicateLeafType(dataType: DataType) = dataType match {
     case BooleanType => PredicateLeaf.Type.BOOLEAN
     case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
     case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
@@ -181,7 +181,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Builder = {
     import org.apache.spark.sql.sources._

@@ -217,11 +217,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildLeafSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Option[Builder] = {
     def getType(attribute: String): PredicateLeaf.Type =
-      getPredicateLeafType(dataTypeMap(attribute))
+      getPredicateLeafType(dataTypeMap(attribute).fieldType)

     import org.apache.spark.sql.sources._

@@ -231,39 +231,47 @@ private[sql] object OrcFilters extends OrcFiltersBase {
     // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
     // in order to distinguish predicate pushdown for nested columns.
     expression match {
-      case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().equals(name, getType(name), castedValue).end())
+      case EqualTo(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
+      case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
+      case LessThan(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
+      case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
+      case GreaterThan(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
+      case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
-        Some(builder.startAnd().isNull(name, getType(name)).end())
+      case IsNull(name) if dataTypeMap.contains(name) =>
+        Some(builder.startAnd()
+          .isNull(dataTypeMap(name).fieldName, getType(name)).end())

-      case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
-        Some(builder.startNot().isNull(name, getType(name)).end())
+      case IsNotNull(name) if dataTypeMap.contains(name) =>
+        Some(builder.startNot()
+          .isNull(dataTypeMap(name).fieldName, getType(name)).end())

-      case In(name, values) if isSearchableType(dataTypeMap(name)) =>
-        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
-        Some(builder.startAnd().in(name, getType(name),
+      case In(name, values) if dataTypeMap.contains(name) =>
+        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
+        Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
           castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

       case _ => None
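The builder calls above come from ORC's `SearchArgument` API (the `org.apache.orc.storage` shaded copy used by the `v1.2` profile). As a standalone illustration of the shape produced for a greater-than predicate (the same shape the new test below expects), here is a hedged sketch; the column name and literal are arbitrary:

```scala
import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder

// Sketch: "cint > 1" is expressed as NOT(cint <= 1), mirroring the
// GreaterThan case in buildLeafSearchArgument above.
val sarg: SearchArgument = newBuilder()
  .startNot()
  .lessThanEquals("cint", PredicateLeaf.Type.LONG, Long.box(1L))
  .`end`()
  .build()

println(sarg) // e.g. "leaf-0 = (LESS_THAN_EQUALS cint 1), expr = (not leaf-0)"
```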
diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index beb7232..a3c2343 100644
--- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
 import scala.collection.JavaConverters._

 import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder

 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -542,8 +543,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
         checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))

         val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
-        // TODO: ORC predicate pushdown should work under case-insensitive analysis.
-        // assert(actual.count() == 1)
+        assert(actual.count() == 1)
       }
     }

@@ -562,5 +562,50 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
+    import org.apache.spark.sql.sources._
+
+    def getOrcFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        caseSensitive: String): Option[SearchArgument] = {
+      var orcFilter: Option[SearchArgument] = None
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+        orcFilter =
+          OrcFilters.createFilter(schema, filters)
+      }
+      orcFilter
+    }
+
+    def testFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        expected: SearchArgument): Unit = {
+      val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
+      val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")
+
+      assert(caseSensitiveFilters.isEmpty)
+      assert(caseInsensitiveFilters.isDefined)
+
+      assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
+      assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
+      (0 until expected.getLeaves().size()).foreach { index =>
+        assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
+      }
+    }
+
+    val schema = StructType(Seq(StructField("cint", IntegerType)))
+    testFilter(schema, Seq(GreaterThan("CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema, Seq(
+      And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
+      newBuilder.startAnd()
+        .startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+        .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+        .`end`().build())
+  }
 }
diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index 4b64208..9f1927e 100644
--- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable

 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._

@@ -68,7 +68,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * Create ORC filter as a SearchArgument instance.
    */
   def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
-    val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+    val dataTypeMap = getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
     // Combines all convertible filters using `And` to produce a single conjunction
     // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
     val newFilters = filters.filter(!_.containsNestedColumn)
@@ -83,7 +83,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {

   def convertibleFilters(
       schema: StructType,
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       filters: Seq[Filter]): Seq[Filter] = {
     import org.apache.spark.sql.sources._

@@ -141,7 +141,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
   /**
    * Get PredicateLeafType which is corresponding to the given DataType.
    */
-  private def getPredicateLeafType(dataType: DataType) = dataType match {
+  private[sql] def getPredicateLeafType(dataType: DataType) = dataType match {
     case BooleanType => PredicateLeaf.Type.BOOLEAN
     case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
     case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
@@ -181,7 +181,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Builder = {
     import org.apache.spark.sql.sources._

@@ -217,11 +217,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * @return the builder so far.
    */
   private def buildLeafSearchArgument(
-      dataTypeMap: Map[String, DataType],
+      dataTypeMap: Map[String, OrcPrimitiveField],
       expression: Filter,
       builder: Builder): Option[Builder] = {
     def getType(attribute: String): PredicateLeaf.Type =
-      getPredicateLeafType(dataTypeMap(attribute))
+      getPredicateLeafType(dataTypeMap(attribute).fieldType)

     import org.apache.spark.sql.sources._

@@ -231,39 +231,45 @@ private[sql] object OrcFilters extends OrcFiltersBase {
     // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
     // in order to distinguish predicate pushdown for nested columns.
     expression match {
-      case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().equals(name, getType(name), castedValue).end())
+      case EqualTo(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
+      case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
+      case LessThan(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
+      case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startAnd()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
+      case GreaterThan(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
-        val castedValue = castLiteralValue(value, dataTypeMap(name))
-        Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
+      case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+        val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+        Some(builder.startNot()
+          .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())

-      case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
-        Some(builder.startAnd().isNull(name, getType(name)).end())
+      case IsNull(name) if dataTypeMap.contains(name) =>
+        Some(builder.startAnd().isNull(dataTypeMap(name).fieldName, getType(name)).end())

-      case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
-        Some(builder.startNot().isNull(name, getType(name)).end())
+      case IsNotNull(name) if dataTypeMap.contains(name) =>
+        Some(builder.startNot().isNull(dataTypeMap(name).fieldName, getType(name)).end())

-      case In(name, values) if isSearchableType(dataTypeMap(name)) =>
-        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
-        Some(builder.startAnd().in(name, getType(name),
+      case In(name, values) if dataTypeMap.contains(name) =>
+        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
+        Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
           castedValues.map(_.asInstanceOf[AnyRef]): _*).end())

       case _ => None
diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index a3e450c..cb69413 100644
--- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
 import scala.collection.JavaConverters._

 import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder

 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -543,8 +544,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
         checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))

         val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
-        // TODO: ORC predicate pushdown should work under case-insensitive analysis.
-        // assert(actual.count() == 1)
+        assert(actual.count() == 1)
       }
     }

@@ -563,5 +563,50 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
+    import org.apache.spark.sql.sources._
+
+    def getOrcFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        caseSensitive: String): Option[SearchArgument] = {
+      var orcFilter: Option[SearchArgument] = None
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+        orcFilter =
+          OrcFilters.createFilter(schema, filters)
+      }
+      orcFilter
+    }
+
+    def testFilter(
+        schema: StructType,
+        filters: Seq[Filter],
+        expected: SearchArgument): Unit = {
+      val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
+      val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")
+
+      assert(caseSensitiveFilters.isEmpty)
+      assert(caseInsensitiveFilters.isDefined)
+
+      assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
+      assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
+      (0 until expected.getLeaves().size()).foreach { index =>
+        assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
+      }
+    }
+
+    val schema = StructType(Seq(StructField("cint", IntegerType)))
+    testFilter(schema, Seq(GreaterThan("CINT", 1)),
+      newBuilder.startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+    testFilter(schema, Seq(
+      And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
+      newBuilder.startAnd()
+        .startNot()
+        .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+        .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+        .`end`().build())
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org