This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 646315361 fix: Fix shuffle writing rows containing null struct fields (#1845)
646315361 is described below

commit 6463153612f24d4d8e9d5f546341014bfecc16ca
Author: Kristin Cowalcijk <b...@wherobots.com>
AuthorDate: Thu Jun 5 08:58:25 2025 +0800

    fix: Fix shuffle writing rows containing null struct fields (#1845)
---
 native/core/src/execution/shuffle/row.rs           | 67 ++++++++++++++++------
 .../comet/exec/CometColumnarShuffleSuite.scala     | 30 ++++++++++
 2 files changed, 81 insertions(+), 16 deletions(-)
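
The gist of the change: arrow-rs's StructBuilder requires every child
builder to be advanced once per outer entry, even when the entry itself is
null. The old code appended a null for a null parent row without recursing
into the child builders, leaving them shorter than the struct's validity
buffer. A minimal standalone sketch of that invariant, using only the
arrow-rs builder API (nothing Comet-specific):

    use arrow::array::{Array, BooleanBuilder, StructBuilder};
    use arrow::datatypes::{DataType, Field, Fields};

    fn main() {
        let fields = Fields::from(vec![Field::new("a", DataType::Boolean, true)]);
        let mut builder = StructBuilder::from_fields(fields, 1);

        // Mark the struct entry itself as null...
        builder.append_null();
        // ...but each child builder must still receive one slot, or
        // finish() panics because the child length no longer matches the
        // struct's validity length. This is why the patched code below
        // recurses into the children with a default (null) SparkUnsafeRow.
        builder
            .field_builder::<BooleanBuilder>(0)
            .unwrap()
            .append_null();

        let array = builder.finish();
        assert_eq!(array.len(), 1);
        assert!(array.is_null(0));
    }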

diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs
index bb1401e26..c98cc5438 100644
--- a/native/core/src/execution/shuffle/row.rs
+++ b/native/core/src/execution/shuffle/row.rs
@@ -444,25 +444,18 @@ pub(crate) fn append_field(
             // Appending value into struct field builder of Arrow struct builder.
             let field_builder = struct_builder.field_builder::<StructBuilder>(idx).unwrap();
 
-            if row.is_null_row() {
-                // The row is null.
+            let nested_row = if row.is_null_row() || row.is_null_at(idx) {
+                // The row is null, or the field in the row is null, i.e., a null nested row.
+                // Append a null value to the row builder.
                 field_builder.append_null();
+                SparkUnsafeRow::default()
             } else {
-                let is_null = row.is_null_at(idx);
+                field_builder.append(true);
+                row.get_struct(idx, fields.len())
+            };
 
-                let nested_row = if is_null {
-                    // The field in the row is null, i.e., a null nested row.
-                    // Append a null value to the row builder.
-                    field_builder.append_null();
-                    SparkUnsafeRow::default()
-                } else {
-                    field_builder.append(true);
-                    row.get_struct(idx, fields.len())
-                };
-
-                for (field_idx, field) in fields.into_iter().enumerate() {
-                    append_field(field.data_type(), field_builder, &nested_row, field_idx)?;
-                }
+            for (field_idx, field) in fields.into_iter().enumerate() {
+                append_field(field.data_type(), field_builder, &nested_row, field_idx)?;
             }
         }
         DataType::Map(field, _) => {
@@ -3302,3 +3295,45 @@ fn make_batch(arrays: Vec<ArrayRef>, row_count: usize) -> Result<RecordBatch, Ar
     let options = RecordBatchOptions::new().with_row_count(Option::from(row_count));
     RecordBatch::try_new_with_options(schema, arrays, &options)
 }
+
+#[cfg(test)]
+mod test {
+    use arrow::datatypes::Fields;
+
+    use super::*;
+
+    #[test]
+    fn test_append_null_row_to_struct_builder() {
+        let data_type = DataType::Struct(Fields::from(vec![
+            Field::new("a", DataType::Boolean, true),
+            Field::new("b", DataType::Boolean, true),
+        ]));
+        let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]);
+        let mut struct_builder = StructBuilder::from_fields(fields, 1);
+        let row = SparkUnsafeRow::default();
+        append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
+        struct_builder.append_null();
+        let struct_array = struct_builder.finish();
+        assert_eq!(struct_array.len(), 1);
+        assert!(struct_array.is_null(0));
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)] // Unaligned memory access in SparkUnsafeRow
+    fn test_append_null_struct_field_to_struct_builder() {
+        let data_type = DataType::Struct(Fields::from(vec![
+            Field::new("a", DataType::Boolean, true),
+            Field::new("b", DataType::Boolean, true),
+        ]));
+        let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]);
+        let mut struct_builder = StructBuilder::from_fields(fields, 1);
+        let mut row = SparkUnsafeRow::new_with_num_fields(1);
+        let data = [0; 8];
+        row.point_to_slice(&data);
+        append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
+        struct_builder.append_null();
+        let struct_array = struct_builder.finish();
+        assert_eq!(struct_array.len(), 1);
+        assert!(struct_array.is_null(0));
+    }
+}
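
Of the two new unit tests, the first exercises the null-parent-row path (a
default SparkUnsafeRow is treated as a null row), and the second exercises
the null-struct-field path by pointing a one-field SparkUnsafeRow at an
8-byte backing buffer via point_to_slice; that one is ignored under Miri
because of the unaligned memory accesses in SparkUnsafeRow. Assuming the
usual crate layout, both can be run with `cargo test` from native/core.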
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 02801f8bc..2de3620b8 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -19,6 +19,9 @@
 
 package org.apache.comet.exec
 
+import java.nio.file.Files
+import java.nio.file.Paths
+
 import scala.reflect.runtime.universe._
 import scala.util.Random
 
@@ -820,6 +823,33 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
     }
   }
 
+  test("columnar shuffle on null struct fields") {
+    withTempDir { dir =>
+      val testData = "{}\n"
+      val path = Paths.get(dir.toString, "test.json")
+      Files.write(path, testData.getBytes)
+
+      // Define the nested struct schema
+      val readSchema = StructType(
+        Array(
+          StructField(
+            "metaData",
+            StructType(
+              Array(StructField(
+                "format",
+                StructType(Array(StructField("provider", StringType, nullable = true))),
+                nullable = true))),
+            nullable = true)))
+
+      // Read JSON with the custom schema and repartition; this will repartition rows that
+      // contain null struct fields.
+      val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2)
+      assert(df.count() == 1)
+      val row = df.collect()(0)
+      assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null)
+    }
+  }
+
   /**
    * Checks that `df` produces the same answer as Spark does, and has the `expectedNum` Comet
    * exchange operators.
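
The Scala test reproduces the failure end to end: a single JSON document
`{}` read with the three-level nested schema yields one row whose
`metaData` struct is null at every level, and repartition(2) forces that
row through a shuffle exchange, i.e., through the native row-to-Arrow
conversion patched above. The assertions check that the row survives the
round trip with `metaData` still null.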


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
