Jefffrey commented on code in PR #18837:
URL: https://github.com/apache/datafusion/pull/18837#discussion_r2548304856


##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) if (0.0..=1.0).contains(&value) => value,
+        _ => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, 
is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };
+
+    let value_expr = params.args[0].clone();
+    let input_type = match info.get_data_type(&value_expr) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };
+
+    let expected_return_type = match percentile_cont_result_type(&input_type) {
+        Some(data_type) => data_type,
+        None => return Ok(original_expr),
+    };
+
+    let udaf = match rewrite_target {
+        PercentileRewriteTarget::Min => min_udaf(),
+        PercentileRewriteTarget::Max => max_udaf(),
+    };
+
+    let mut agg_arg = value_expr;
+    if expected_return_type != input_type {
+        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), 
expected_return_type.clone()));
+    }
+
+    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+        udaf,
+        vec![agg_arg],
+        params.distinct,
+        params.filter.clone(),
+        vec![],
+        params.null_treatment,
+    ));
+    Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+    percentile_value: f64,
+    is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+    if nearly_equals_fraction(percentile_value, 0.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Max
+        } else {
+            PercentileRewriteTarget::Min
+        })
+    } else if nearly_equals_fraction(percentile_value, 1.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Min
+        } else {
+            PercentileRewriteTarget::Max
+        })
+    } else {
+        None
+    }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+    (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}

Review Comment:
   I'm personally of the mind to check directly against 0.0 and 1.0 instead of 
doing an epsilon check; I think it's more likely a user would input an expr 
like `SELECT percentile_cont(column1, 0.0)` than doing something like `SELECT 
percentile_cont(column1, expr)` where `expr` might be some math that could make 
it `0.0000001` 🤔 



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) if (0.0..=1.0).contains(&value) => value,
+        _ => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, 
is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };
+
+    let value_expr = params.args[0].clone();
+    let input_type = match info.get_data_type(&value_expr) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };
+
+    let expected_return_type = match percentile_cont_result_type(&input_type) {
+        Some(data_type) => data_type,
+        None => return Ok(original_expr),
+    };
+
+    let udaf = match rewrite_target {
+        PercentileRewriteTarget::Min => min_udaf(),
+        PercentileRewriteTarget::Max => max_udaf(),
+    };
+
+    let mut agg_arg = value_expr;
+    if expected_return_type != input_type {
+        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), 
expected_return_type.clone()));
+    }

Review Comment:
   Can we explain why this is necessary in a comment here?



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }

Review Comment:
   ```suggestion
       let [value, percentile] = take_function_args("percentile_cont", 
&params.args)?;
   ```
   
   More ergonomic this way; technically this error path should never occur as 
the signature should already guard us by now.



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) if (0.0..=1.0).contains(&value) => value,
+        _ => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, 
is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };
+
+    let value_expr = params.args[0].clone();
+    let input_type = match info.get_data_type(&value_expr) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };
+
+    let expected_return_type = match percentile_cont_result_type(&input_type) {
+        Some(data_type) => data_type,
+        None => return Ok(original_expr),
+    };
+
+    let udaf = match rewrite_target {
+        PercentileRewriteTarget::Min => min_udaf(),
+        PercentileRewriteTarget::Max => max_udaf(),
+    };
+
+    let mut agg_arg = value_expr;
+    if expected_return_type != input_type {
+        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), 
expected_return_type.clone()));
+    }
+
+    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+        udaf,
+        vec![agg_arg],
+        params.distinct,
+        params.filter.clone(),
+        vec![],
+        params.null_treatment,
+    ));
+    Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+    percentile_value: f64,
+    is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+    if nearly_equals_fraction(percentile_value, 0.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Max
+        } else {
+            PercentileRewriteTarget::Min
+        })
+    } else if nearly_equals_fraction(percentile_value, 1.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Min
+        } else {
+            PercentileRewriteTarget::Max
+        })
+    } else {
+        None
+    }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+    (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
+
+fn percentile_cont_result_type(input_type: &DataType) -> Option<DataType> {
+    if !input_type.is_numeric() {
+        return None;
+    }
+
+    let result_type = match input_type {
+        DataType::Float16 | DataType::Float32 | DataType::Float64 => 
input_type.clone(),
+        DataType::Decimal32(_, _)
+        | DataType::Decimal64(_, _)
+        | DataType::Decimal128(_, _)
+        | DataType::Decimal256(_, _) => input_type.clone(),
+        DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64 => DataType::Float64,
+        _ => return None,
+    };
+
+    Some(result_type)
+}
+
+fn extract_percentile_literal(expr: &Expr) -> Option<f64> {
+    match expr {
+        Expr::Literal(value, _) => literal_scalar_to_f64(value),
+        Expr::Alias(alias) => extract_percentile_literal(alias.expr.as_ref()),
+        Expr::Cast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+        Expr::TryCast(cast) => extract_percentile_literal(cast.expr.as_ref()),

Review Comment:
   How strictly necessary are these other arms? Is checking only for `Literal` 
not sufficient?



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) if (0.0..=1.0).contains(&value) => value,
+        _ => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, 
is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };
+
+    let value_expr = params.args[0].clone();
+    let input_type = match info.get_data_type(&value_expr) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };
+
+    let expected_return_type = match percentile_cont_result_type(&input_type) {
+        Some(data_type) => data_type,
+        None => return Ok(original_expr),
+    };
+
+    let udaf = match rewrite_target {
+        PercentileRewriteTarget::Min => min_udaf(),
+        PercentileRewriteTarget::Max => max_udaf(),
+    };
+
+    let mut agg_arg = value_expr;
+    if expected_return_type != input_type {
+        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), 
expected_return_type.clone()));
+    }
+
+    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+        udaf,
+        vec![agg_arg],
+        params.distinct,
+        params.filter.clone(),
+        vec![],
+        params.null_treatment,
+    ));
+    Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+    percentile_value: f64,
+    is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+    if nearly_equals_fraction(percentile_value, 0.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Max
+        } else {
+            PercentileRewriteTarget::Min
+        })
+    } else if nearly_equals_fraction(percentile_value, 1.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Min
+        } else {
+            PercentileRewriteTarget::Max
+        })
+    } else {
+        None
+    }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+    (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
+
+fn percentile_cont_result_type(input_type: &DataType) -> Option<DataType> {

Review Comment:
   We should reuse the logic from `return_type` if possible, instead of 
duplicating it here.
   
   
https://github.com/apache/datafusion/blob/f1ecaccd183367086ecb5b7736d93b3aba109e01/datafusion/functions-aggregate/src/percentile_cont.rs#L232-L261



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) if (0.0..=1.0).contains(&value) => value,
+        _ => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, 
is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };
+
+    let value_expr = params.args[0].clone();
+    let input_type = match info.get_data_type(&value_expr) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };
+
+    let expected_return_type = match percentile_cont_result_type(&input_type) {
+        Some(data_type) => data_type,
+        None => return Ok(original_expr),
+    };
+
+    let udaf = match rewrite_target {
+        PercentileRewriteTarget::Min => min_udaf(),
+        PercentileRewriteTarget::Max => max_udaf(),
+    };
+
+    let mut agg_arg = value_expr;
+    if expected_return_type != input_type {
+        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), 
expected_return_type.clone()));
+    }
+
+    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+        udaf,
+        vec![agg_arg],
+        params.distinct,
+        params.filter.clone(),
+        vec![],
+        params.null_treatment,
+    ));
+    Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+    percentile_value: f64,
+    is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+    if nearly_equals_fraction(percentile_value, 0.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Max
+        } else {
+            PercentileRewriteTarget::Min
+        })
+    } else if nearly_equals_fraction(percentile_value, 1.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Min
+        } else {
+            PercentileRewriteTarget::Max
+        })
+    } else {
+        None
+    }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+    (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
+
+fn percentile_cont_result_type(input_type: &DataType) -> Option<DataType> {
+    if !input_type.is_numeric() {
+        return None;
+    }
+
+    let result_type = match input_type {
+        DataType::Float16 | DataType::Float32 | DataType::Float64 => 
input_type.clone(),
+        DataType::Decimal32(_, _)
+        | DataType::Decimal64(_, _)
+        | DataType::Decimal128(_, _)
+        | DataType::Decimal256(_, _) => input_type.clone(),
+        DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64 => DataType::Float64,
+        _ => return None,
+    };
+
+    Some(result_type)
+}
+
+fn extract_percentile_literal(expr: &Expr) -> Option<f64> {
+    match expr {
+        Expr::Literal(value, _) => literal_scalar_to_f64(value),
+        Expr::Alias(alias) => extract_percentile_literal(alias.expr.as_ref()),
+        Expr::Cast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+        Expr::TryCast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+        _ => None,
+    }
+}
+
+fn literal_scalar_to_f64(value: &ScalarValue) -> Option<f64> {

Review Comment:
   Can we have percentiles that are not of type `Float64`? I thought the 
signature guarded us against this.
   
   
https://github.com/apache/datafusion/blob/f1ecaccd183367086ecb5b7736d93b3aba109e01/datafusion/functions-aggregate/src/percentile_cont.rs#L142-L154



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) if (0.0..=1.0).contains(&value) => value,
+        _ => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, 
is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };
+
+    let value_expr = params.args[0].clone();
+    let input_type = match info.get_data_type(&value_expr) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };

Review Comment:
   ```suggestion
       let input_type = info.get_data_type(&value_expr)?;
   ```



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -760,3 +914,80 @@ fn calculate_percentile<T: ArrowNumericType>(
         }
     }
 }
+
+#[cfg(test)]
+mod tests {

Review Comment:
   We should remove the unit tests if they duplicate the sqllogictests



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) if (0.0..=1.0).contains(&value) => value,
+        _ => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, 
is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };

Review Comment:
   I feel this should be folded directly into the `percentile_value` match on 
`extract_percentile_literal` (line 400 above), instead of splitting it like this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to