This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ee9e36dd58c [SPARK-43390][SQL] DSv2 allows CTAS/RTAS to reserve schema 
nullability
ee9e36dd58c is described below

commit ee9e36dd58c824e4c965a5e2c850f64c59fbc278
Author: Cheng Pan <cheng...@apache.org>
AuthorDate: Tue May 9 17:10:01 2023 +0800

    [SPARK-43390][SQL] DSv2 allows CTAS/RTAS to reserve schema nullability
    
    ### What changes were proposed in this pull request?
    
    Add a new method `useNullableQuerySchema` to `TableCatalog`, allowing a
    DataSource implementation to declare whether the query schema's nullability
    should be preserved on CTAS/RTAS.
    
    ### Why are the changes needed?
    
    SPARK-28837 forces the table schema to be nullable on CTAS/RTAS, which seems
    too aggressive:
    
    1. Mature RDBMSs behave differently with respect to preserving schema
    nullability on CTAS/RTAS. As mentioned in #25536, PostgreSQL forces the
    schema to be nullable, whereas MySQL respects the nullability of the query's
    output schema.
    2. Some OLAP systems (e.g. ClickHouse) are performance-sensitive to nullable
    columns and impose strict restrictions on table schemas; for example, primary
    key columns are not allowed to be nullable.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this PR adds a new DSv2 API, but the default implementation preserves
    backward compatibility.
    
    ### How was this patch tested?
    
    Unit tests are updated.
    
    Closes #41070 from pan3793/SPARK-43390.
    
    Authored-by: Cheng Pan <cheng...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/connector/catalog/TableCatalog.java  |  8 +++++
 .../datasources/v2/WriteToDataSourceV2Exec.scala   | 18 ++++++----
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 39 +++++++++++++++++-----
 .../spark/sql/connector/DatasourceV2SQLBase.scala  | 21 ++++++------
 4 files changed, 60 insertions(+), 26 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
index eb442ad38bd..6cfd5ab1b6b 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
@@ -199,6 +199,14 @@ public interface TableCatalog extends CatalogPlugin {
     return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), 
partitions, properties);
   }
 
+  /**
+   * If true, mark all the fields of the query schema as nullable when 
executing
+   * CREATE/REPLACE TABLE ... AS SELECT ... and creating the table.
+   */
+  default boolean useNullableQuerySchema() {
+    return true;
+  }
+
   /**
    * Apply a set of {@link TableChange changes} to a table in the catalog.
    * <p>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 426f33129a6..4a9b85450a1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -82,7 +82,8 @@ case class CreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema), partitioning.toArray, 
properties.asJava)
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
 }
@@ -115,7 +116,8 @@ case class AtomicCreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val stagedTable = catalog.stageCreate(
-      ident, getV2Columns(query.schema), partitioning.toArray, 
properties.asJava)
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, stagedTable, writeOptions, ident, query)
   }
 }
@@ -160,7 +162,8 @@ case class ReplaceTableAsSelectExec(
       throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema), partitioning.toArray, 
properties.asJava)
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
 }
@@ -191,7 +194,7 @@ case class AtomicReplaceTableAsSelectExec(
   val properties = CatalogV2Util.convertTableProperties(tableSpec)
 
   override protected def run(): Seq[InternalRow] = {
-    val columns = getV2Columns(query.schema)
+    val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema)
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
       invalidateCache(catalog, table, ident)
@@ -555,9 +558,10 @@ case class DeltaWithMetadataWritingSparkTask(
 private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Nil
 
-  protected def getV2Columns(schema: StructType): Array[Column] = {
-    CatalogV2Util.structTypeToV2Columns(CharVarcharUtils.getRawSchema(
-      removeInternalMetadata(schema), conf).asNullable)
+  protected def getV2Columns(schema: StructType, forceNullable: Boolean): 
Array[Column] = {
+    val rawSchema = 
CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf)
+    val tableSchema = if (forceNullable) rawSchema.asNullable else rawSchema
+    CatalogV2Util.structTypeToV2Columns(tableSchema)
   }
 
   protected def writeToTable(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 053afb84c10..6f14b0971ca 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -781,13 +781,20 @@ class DataSourceV2SQLSuiteV1Filter
   }
 
   test("CreateTableAsSelect: nullable schema") {
+    registerCatalog("testcat_nullability", 
classOf[ReserveSchemaNullabilityCatalog])
+
     val basicCatalog = catalog("testcat").asTableCatalog
     val atomicCatalog = catalog("testcat_atomic").asTableCatalog
+    val reserveNullabilityCatalog = 
catalog("testcat_nullability").asTableCatalog
     val basicIdentifier = "testcat.table_name"
     val atomicIdentifier = "testcat_atomic.table_name"
+    val reserveNullabilityIdentifier = "testcat_nullability.table_name"
 
-    Seq((basicCatalog, basicIdentifier), (atomicCatalog, 
atomicIdentifier)).foreach {
-      case (catalog, identifier) =>
+    Seq(
+      (basicCatalog, basicIdentifier, true),
+      (atomicCatalog, atomicIdentifier, true),
+      (reserveNullabilityCatalog, reserveNullabilityIdentifier, 
false)).foreach {
+      case (catalog, identifier, nullable) =>
         spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT 1 i")
 
         val table = catalog.loadTable(Identifier.of(Array(), "table_name"))
@@ -795,14 +802,24 @@ class DataSourceV2SQLSuiteV1Filter
         assert(table.name == identifier)
         assert(table.partitioning.isEmpty)
         assert(table.properties == withDefaultOwnership(Map("provider" -> 
"foo")).asJava)
-        assert(table.schema == new StructType().add("i", "int"))
+        assert(table.schema == new StructType().add("i", "int", nullable))
 
         val rdd = 
spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
         checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Row(1))
 
-        sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
-        val rdd2 = 
spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
-        checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), 
Seq(Row(1), Row(null)))
+        def insertNullValueAndCheck(): Unit = {
+          sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
+          val rdd2 = 
spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+          checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), 
Seq(Row(1), Row(null)))
+        }
+        if (nullable) {
+          insertNullValueAndCheck()
+        } else {
+          val e = intercept[Exception] {
+            insertNullValueAndCheck()
+          }
+          assert(e.getMessage.contains("Null value appeared in non-nullable 
field"))
+        }
     }
   }
 
@@ -2311,7 +2328,7 @@ class DataSourceV2SQLSuiteV1Filter
 
   test("global temp view should not be masked by v2 catalog") {
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", 
classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
 
     try {
       sql("create global temp view v as select 1")
@@ -2336,7 +2353,7 @@ class DataSourceV2SQLSuiteV1Filter
 
   test("SPARK-30104: v2 catalog named global_temp will be masked") {
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", 
classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
     checkError(
       exception = intercept[AnalysisException] {
         // Since the following multi-part name starts with `globalTempDB`, it 
is resolved to
@@ -2543,7 +2560,7 @@ class DataSourceV2SQLSuiteV1Filter
       context = ExpectedContext(fragment = "testcat.abc", start = 17, stop = 
27))
 
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", 
classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
     withTempView("v") {
       sql("create global temp view v as select 1")
       checkError(
@@ -3261,3 +3278,7 @@ class FakeV2Provider extends SimpleTableProvider {
     throw new UnsupportedOperationException("Unnecessary for DDL tests")
   }
 }
+
+class ReserveSchemaNullabilityCatalog extends InMemoryCatalog {
+  override def useNullableQuerySchema(): Boolean = false
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
index 9c308515025..4ccff44fa06 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
@@ -21,26 +21,27 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, 
InMemoryPartitionTableCatalog, InMemoryTableWithV2FilterCatalog, 
StagingInMemoryTableCatalog}
-import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import 
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.test.SharedSparkSession
 
 trait DatasourceV2SQLBase
   extends QueryTest with SharedSparkSession with BeforeAndAfter {
 
+  protected def registerCatalog[T <: CatalogPlugin](name: String, clazz: 
Class[T]): Unit = {
+    spark.conf.set(s"spark.sql.catalog.$name", clazz.getName)
+  }
+
   protected def catalog(name: String): CatalogPlugin = {
     spark.sessionState.catalogManager.catalog(name)
   }
 
   before {
-    spark.conf.set("spark.sql.catalog.testcat", 
classOf[InMemoryCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testv2filter",
-      classOf[InMemoryTableWithV2FilterCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testpart", 
classOf[InMemoryPartitionTableCatalog].getName)
-    spark.conf.set(
-      "spark.sql.catalog.testcat_atomic", 
classOf[StagingInMemoryTableCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testcat2", 
classOf[InMemoryCatalog].getName)
-    spark.conf.set(
-      V2_SESSION_CATALOG_IMPLEMENTATION.key, 
classOf[InMemoryTableSessionCatalog].getName)
+    registerCatalog("testcat", classOf[InMemoryCatalog])
+    registerCatalog("testv2filter", classOf[InMemoryTableWithV2FilterCatalog])
+    registerCatalog("testpart", classOf[InMemoryPartitionTableCatalog])
+    registerCatalog("testcat_atomic", classOf[StagingInMemoryTableCatalog])
+    registerCatalog("testcat2", classOf[InMemoryCatalog])
+    registerCatalog(SESSION_CATALOG_NAME, classOf[InMemoryTableSessionCatalog])
 
     val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, 
"c"))).toDF("id", "data")
     df.createOrReplaceTempView("source")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to