This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch release-1.0
in repository https://gitbox.apache.org/repos/asf/paimon.git

commit a767a103dc400b506b1c53545f01809fedbe6e2a
Author: Zouxxyy <[email protected]>
AuthorDate: Thu Dec 26 22:45:58 2024 +0800

    [spark] Fix writing null struct col (#4787)
---
 .../spark/catalyst/analysis/PaimonAnalysis.scala   | 24 ++++++++++++++--------
 .../apache/paimon/spark/PaimonSparkTestBase.scala  | 11 ++++++++--
 .../spark/sql/InsertOverwriteTableTestBase.scala   | 22 ++++++++++++++++++++
 .../apache/paimon/spark/sql/WithTableOptions.scala |  1 +
 4 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index f567d925ea..7909838668 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -26,7 +26,7 @@ import org.apache.paimon.table.FileStoreTable
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolvedTable}
-import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, 
Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, 
LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, 
Attribute, CreateNamedStruct, CreateStruct, Expression, GetArrayItem, 
GetStructField, If, IsNull, LambdaFunction, Literal, NamedExpression, 
NamedLambdaVariable}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -206,10 +206,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
         val sourceField = source(sourceIndex)
         castStructField(parent, sourceIndex, sourceField.name, targetField)
     }
-    Alias(CreateStruct(fields), parent.name)(
-      parent.exprId,
-      parent.qualifier,
-      Option(parent.metadata))
+    structAlias(fields, parent)
   }
 
   private def addCastToStructByPosition(
@@ -234,10 +231,19 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
         val sourceField = source(i)
         castStructField(parent, i, sourceField.name, targetField)
     }
-    Alias(CreateStruct(fields), parent.name)(
-      parent.exprId,
-      parent.qualifier,
-      Option(parent.metadata))
+    structAlias(fields, parent)
+  }
+
+  private def structAlias(
+      fields: Seq[NamedExpression],
+      parent: NamedExpression): NamedExpression = {
+    val struct = CreateStruct(fields)
+    val res = if (parent.nullable) {
+      If(IsNull(parent), Literal(null, struct.dataType), struct)
+    } else {
+      struct
+    }
+    Alias(res, parent.name)(parent.exprId, parent.qualifier, 
Option(parent.metadata))
   }
 
   private def castStructField(
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
index 867b3e5e33..9a6719010e 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/PaimonSparkTestBase.scala
@@ -19,6 +19,8 @@
 package org.apache.paimon.spark
 
 import org.apache.paimon.catalog.{Catalog, Identifier}
+import org.apache.paimon.fs.FileIO
+import org.apache.paimon.fs.local.LocalFileIO
 import org.apache.paimon.spark.catalog.WithPaimonCatalog
 import org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions
 import org.apache.paimon.spark.sql.{SparkVersionSupport, WithTableOptions}
@@ -46,6 +48,8 @@ class PaimonSparkTestBase
   with WithTableOptions
   with SparkVersionSupport {
 
+  protected lazy val fileIO: FileIO = LocalFileIO.create
+
   protected lazy val tempDBDir: File = Utils.createTempDir
 
   protected def paimonCatalog: Catalog = {
@@ -64,6 +68,7 @@ class PaimonSparkTestBase
       "org.apache.spark.serializer.JavaSerializer"
     }
     super.sparkConf
+      .set("spark.sql.warehouse.dir", tempDBDir.getCanonicalPath)
       .set("spark.sql.catalog.paimon", classOf[SparkCatalog].getName)
       .set("spark.sql.catalog.paimon.warehouse", tempDBDir.getCanonicalPath)
       .set("spark.sql.extensions", 
classOf[PaimonSparkSessionExtensions].getName)
@@ -152,8 +157,10 @@ class PaimonSparkTestBase
 
   override def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
       pos: Position): Unit = {
-    println(testName)
-    super.test(testName, testTags: _*)(testFun)(pos)
+    super.test(testName, testTags: _*) {
+      println(testName)
+      testFun
+    }(pos)
   }
 
   def loadTable(tableName: String): FileStoreTable = {
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
index 977b747070..38cca371f0 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTestBase.scala
@@ -560,4 +560,26 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
     }
     checkAnswer(sql("SELECT * FROM T ORDER BY name"), Row("g", null, 
"Shanghai"))
   }
+
+  test("Paimon Insert: read and write struct with null") {
+    fileFormats {
+      format =>
+        withTable("t") {
+          sql(
+            s"CREATE TABLE t (i INT, s STRUCT<f1: INT, f2: INT>) TBLPROPERTIES 
('file.format' = '$format')")
+          sql(
+            "INSERT INTO t VALUES (1, STRUCT(1, 1)), (2, null), (3, STRUCT(1, 
null)), (4, STRUCT(null, null))")
+          if (format.equals("parquet")) {
+            // todo: fix it, see https://github.com/apache/paimon/issues/4785
+            checkAnswer(
+              sql("SELECT * FROM t ORDER BY i"),
+              Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), 
Row(4, null)))
+          } else {
+            checkAnswer(
+              sql("SELECT * FROM t ORDER BY i"),
+              Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), 
Row(4, Row(null, null))))
+          }
+        }
+    }
+  }
 }
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
index e390058baf..d5866a31b1 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/WithTableOptions.scala
@@ -27,4 +27,5 @@ trait WithTableOptions {
 
   protected val withPk: Seq[Boolean] = Seq(true, false)
 
+  protected def fileFormats(fn: String => Unit): Unit = Seq("parquet", "orc", 
"avro").foreach(fn)
 }

Reply via email to