This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f51cd6e789 Add DataFrame fill_null (#14769)
f51cd6e789 is described below

commit f51cd6e7893aa7221dac338bac0e6983ed8d6141
Author: kosiew <[email protected]>
AuthorDate: Thu Feb 27 00:35:48 2025 +0800

    Add DataFrame fill_null (#14769)
    
    * feat: add fill_null methods to DataFrame for handling null values
    
    * test: refactor fill_null tests and create helper function for null table
    
    * style: reorder imports in mod.rs for better organization
    
    * clippy lint
    
    * test: add comment to clarify test
    
    * refactor: columns Vec<String>
    
    * docs: enhance fill_null documentation with example usage
    
    * test: columns Vec<String>
    
    * docs: update fill_null documentation with detailed usage examples
---
 datafusion/core/src/dataframe/mod.rs   | 97 ++++++++++++++++++++++++++++++++--
 datafusion/core/tests/dataframe/mod.rs | 94 ++++++++++++++++++++++++++++++++
 2 files changed, 186 insertions(+), 5 deletions(-)

diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index b6949d2eea..d2aee0a161 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -51,14 +51,18 @@ use arrow::compute::{cast, concat};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::{CsvOptions, JsonOptions};
 use datafusion_common::{
-    exec_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, ParamValues,
-    SchemaError, UnnestOptions,
+    exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema,
+    DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions,
 };
-use datafusion_expr::dml::InsertOp;
-use datafusion_expr::{case, is_null, lit, SortExpr};
 use datafusion_expr::{
-    utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
+    case,
+    dml::InsertOp,
+    expr::{Alias, ScalarFunction},
+    is_null, lit,
+    utils::COUNT_STAR_EXPANSION,
+    SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE,
 };
+use datafusion_functions::core::coalesce;
 use datafusion_functions_aggregate::expr_fn::{
     avg, count, max, median, min, stddev, sum,
 };
@@ -1930,6 +1934,89 @@ impl DataFrame {
             plan,
         })
     }
+
+    /// Fill null values in specified columns with a given value
+    /// If no columns are specified (empty vector), applies to all columns
+    /// Only fills if the value can be cast to the column's type
+    ///
+    /// # Arguments
+    /// * `value` - Value to fill nulls with
+    /// * `columns` - List of column names to fill. If empty, fills all columns.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # use datafusion_common::ScalarValue;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
+    /// // Fill nulls in only columns "a" and "c":
+    /// let df = df.fill_null(ScalarValue::from(0), vec!["a".to_owned(), "c".to_owned()])?;
+    /// // Fill nulls across all columns:
+    /// let df = df.fill_null(ScalarValue::from(0), vec![])?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn fill_null(
+        &self,
+        value: ScalarValue,
+        columns: Vec<String>,
+    ) -> Result<DataFrame> {
+        let cols = if columns.is_empty() {
+            self.logical_plan()
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.as_ref().clone())
+                .collect()
+        } else {
+            self.find_columns(&columns)?
+        };
+
+        // Create projections for each column
+        let projections = self
+            .logical_plan()
+            .schema()
+            .fields()
+            .iter()
+            .map(|field| {
+                if cols.contains(field) {
+                    // Try to cast fill value to column type. If the cast fails, fallback to the original column.
+                    match value.clone().cast_to(field.data_type()) {
+                        Ok(fill_value) => Expr::Alias(Alias {
+                            expr: Box::new(Expr::ScalarFunction(ScalarFunction {
+                                func: coalesce(),
+                                args: vec![col(field.name()), lit(fill_value)],
+                            })),
+                            relation: None,
+                            name: field.name().to_string(),
+                        }),
+                        Err(_) => col(field.name()),
+                    }
+                } else {
+                    col(field.name())
+                }
+            })
+            .collect::<Vec<_>>();
+
+        self.clone().select(projections)
+    }
+
+    // Helper to find columns from names
+    fn find_columns(&self, names: &[String]) -> Result<Vec<Field>> {
+        let schema = self.logical_plan().schema();
+        names
+            .iter()
+            .map(|name| {
+                schema
+                    .field_with_name(None, name)
+                    .cloned()
+                    .map_err(|_| plan_datafusion_err!("Column '{}' not found", name))
+            })
+            .collect()
+    }
 }
 
 #[derive(Debug)]
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 2b0be37d78..b134ec54b1 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -5342,3 +5342,97 @@ async fn test_insert_into_checking() -> Result<()> {
 
     Ok(())
 }
+
+async fn create_null_table() -> Result<DataFrame> {
+    // create a DataFrame with null values
+    //    "+---+----+",
+    //    "| a | b |",
+    //    "+---+---+",
+    //    "| 1 | x |",
+    //    "|   |   |",
+    //    "| 3 | z |",
+    //    "+---+---+",
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Int32, true),
+        Field::new("b", DataType::Utf8, true),
+    ]));
+    let a_values = Int32Array::from(vec![Some(1), None, Some(3)]);
+    let b_values = StringArray::from(vec![Some("x"), None, Some("z")]);
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(a_values), Arc::new(b_values)],
+    )?;
+
+    let ctx = SessionContext::new();
+    let table = MemTable::try_new(schema.clone(), vec![vec![batch]])?;
+    ctx.register_table("t_null", Arc::new(table))?;
+    let df = ctx.table("t_null").await?;
+    Ok(df)
+}
+
+#[tokio::test]
+async fn test_fill_null() -> Result<()> {
+    let df = create_null_table().await?;
+
+    // Use fill_null to replace nulls on each column.
+    let df_filled = df
+        .fill_null(ScalarValue::Int32(Some(0)), vec!["a".to_string()])?
+        .fill_null(
+            ScalarValue::Utf8(Some("default".to_string())),
+            vec!["b".to_string()],
+        )?;
+
+    let results = df_filled.collect().await?;
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "| 0 | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_fill_null_all_columns() -> Result<()> {
+    let df = create_null_table().await?;
+
+    // Use fill_null to replace nulls on all columns.
+    // Only column "b" will be replaced since ScalarValue::Utf8(Some("default".to_string()))
+    // can be cast to Utf8.
+    let df_filled =
+        df.fill_null(ScalarValue::Utf8(Some("default".to_string())), vec![])?;
+
+    let results = df_filled.clone().collect().await?;
+
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "|   | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &results);
+
+    // Fill column "a" null values with a value that cannot be cast to Int32.
+    let df_filled = df_filled.fill_null(ScalarValue::Int32(Some(0)), vec![])?;
+
+    let results = df_filled.collect().await?;
+    let expected = [
+        "+---+---------+",
+        "| a | b       |",
+        "+---+---------+",
+        "| 1 | x       |",
+        "| 0 | default |",
+        "| 3 | z       |",
+        "+---+---------+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to