Repository: spark
Updated Branches:
  refs/heads/master e352de0db -> 2692bdb7d


[SPARK-11455][SQL] fix case sensitivity of partition by

Depend on `caseSensitive` to do the column name equality check, instead of just `==`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9410 from cloud-fan/partition.
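
As a sketch of the user-visible behavior this fixes (illustrative only, not
part of the commit; assumes a Spark 1.5-era `sqlContext` and an arbitrary
output path):

    import sqlContext.implicits._

    // With the default case-insensitive analysis (spark.sql.caseSensitive=false),
    // partitionBy("yEAr") should resolve against the "year" column. Before this
    // patch the plain `==` lookup missed it, and the write failed with
    // "Partition column yEAr not found in schema ...".
    Seq(2012 -> "a").toDF("year", "val")
      .write.partitionBy("yEAr").parquet("/tmp/partition-case-demo")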


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2692bdb7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2692bdb7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2692bdb7

Branch: refs/heads/master
Commit: 2692bdb7dbf36d6247f595d5fd0cb9cda89e1fdd
Parents: e352de0
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Nov 3 20:25:58 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Nov 3 20:25:58 2015 -0800

----------------------------------------------------------------------
 .../datasources/PartitioningUtils.scala         |  7 ++---
 .../datasources/ResolvedDataSource.scala        | 27 +++++++++++++++-----
 .../spark/sql/execution/datasources/rules.scala |  6 +++--
 .../org/apache/spark/sql/DataFrameSuite.scala   | 10 ++++++++
 4 files changed, 39 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2692bdb7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 628c5e1..16dc236 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -287,10 +287,11 @@ private[sql] object PartitioningUtils {
 
   def validatePartitionColumnDataTypes(
       schema: StructType,
-      partitionColumns: Array[String]): Unit = {
+      partitionColumns: Array[String],
+      caseSensitive: Boolean): Unit = {
 
-    ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field =>
-      field.dataType match {
+    ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach {
+      field => field.dataType match {
         case _: AtomicType => // OK
        case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column")
       }
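
The `AtomicType` check above is unchanged by this patch; a hedged sketch of
what it rejects (illustrative; assumes the same `sqlContext` as above):

    import sqlContext.implicits._

    // "arr" is an ArrayType, not an AtomicType, so validation throws an
    // AnalysisException along the lines of
    // "Cannot use ArrayType(...) for partition column".
    Seq((Seq(1, 2), "a")).toDF("arr", "val")
      .write.partitionBy("arr").parquet("/tmp/bad-partition-demo")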

http://git-wip-us.apache.org/repos/asf/spark/blob/2692bdb7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index 54beabb..86a306b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -99,7 +99,8 @@ object ResolvedDataSource extends Logging {
           val maybePartitionsSchema = if (partitionColumns.isEmpty) {
             None
           } else {
-            Some(partitionColumnsSchema(schema, partitionColumns))
+            Some(partitionColumnsSchema(
+              schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis))
           }
 
           val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -172,14 +173,24 @@ object ResolvedDataSource extends Logging {
 
   def partitionColumnsSchema(
       schema: StructType,
-      partitionColumns: Array[String]): StructType = {
+      partitionColumns: Array[String],
+      caseSensitive: Boolean): StructType = {
+    val equality = columnNameEquality(caseSensitive)
     StructType(partitionColumns.map { col =>
-      schema.find(_.name == col).getOrElse {
+      schema.find(f => equality(f.name, col)).getOrElse {
        throw new RuntimeException(s"Partition column $col not found in schema $schema")
       }
     }).asNullable
   }
 
+  private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
+    if (caseSensitive) {
+      org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+    } else {
+      org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+    }
+  }
+
   /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
   def apply(
       sqlContext: SQLContext,
@@ -207,14 +218,18 @@ object ResolvedDataSource extends Logging {
           path.makeQualified(fs.getUri, fs.getWorkingDirectory)
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns)
+        val caseSensitive = sqlContext.conf.caseSensitiveAnalysis
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          data.schema, partitionColumns, caseSensitive)
 
-        val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
+        val equality = columnNameEquality(caseSensitive)
+        val dataSchema = StructType(
+          data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
         val r = dataSource.createRelation(
           sqlContext,
           Array(outputPath.toString),
           Some(dataSchema.asNullable),
-          Some(partitionColumnsSchema(data.schema, partitionColumns)),
+          Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)),
           caseInsensitiveOptions)
 
        // For partitioned relation r, r.schema's column ordering can be different from the column
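
For reference, the two resolvers that `columnNameEquality` chooses between live
in the `org.apache.spark.sql.catalyst.analysis` package object; a standalone
sketch of their semantics (the names match the real ones, the definitions here
are written out for illustration):

    // caseSensitiveResolution is plain string equality; caseInsensitiveResolution
    // ignores case, which is what lets "yEAr" match "year".
    val caseSensitiveResolution: (String, String) => Boolean = (a, b) => a == b
    val caseInsensitiveResolution: (String, String) => Boolean =
      (a, b) => a.equalsIgnoreCase(b)

    assert(!caseSensitiveResolution("yEAr", "year"))
    assert(caseInsensitiveResolution("yEAr", "year"))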

http://git-wip-us.apache.org/repos/asf/spark/blob/2692bdb7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index abc016b..1a8e7ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -140,7 +140,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
           // OK
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray)
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis)
 
         // Get all input data source relations of the query.
         val srcRelations = query.collect {
@@ -190,7 +191,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
           // OK
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns)
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis)
 
       case _ => // OK
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2692bdb7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a883bcb..a9e6413 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1118,4 +1118,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       if (!allSequential) throw new SparkException("Partition should contain all sequential values")
     })
   }
+
+  test("fix case sensitivity of partition by") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        val p = path.getAbsolutePath
+        Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)
+        checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012))
+      }
+    }
+  }
 }
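
The flip side of this test (a sketch, not part of the commit; reuses the
test's `p` path): with `spark.sql.caseSensitive` set to `true`, the same write
should fail, since "yEAr" no longer resolves against the "year" column.

    sqlContext.setConf("spark.sql.caseSensitive", "true")
    // Expected: RuntimeException("Partition column yEAr not found in schema ...")
    Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)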

