martin-g commented on code in PR #18837:
URL: https://github.com/apache/datafusion/pull/18837#discussion_r2546311350


##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,150 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[allow(clippy::needless_pass_by_value)]

Review Comment:
   ```suggestion
   #[expect(clippy::needless_pass_by_value)]
   ```



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,150 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[allow(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) => value,
+        None => return Ok(original_expr),
+    };
+
+    let is_descending = params
+        .order_by
+        .first()
+        .map(|sort| !sort.asc)
+        .unwrap_or(false);
+
+    let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+        Some(target) => target,
+        None => return Ok(original_expr),
+    };
+
+    let value_expr = params.args[0].clone();
+    let input_type = match info.get_data_type(&value_expr) {
+        Ok(data_type) => data_type,
+        Err(_) => return Ok(original_expr),
+    };
+
+    let expected_return_type = match percentile_cont_result_type(&input_type) {
+        Some(data_type) => data_type,
+        None => return Ok(original_expr),
+    };
+
+    let udaf = match rewrite_target {
+        PercentileRewriteTarget::Min => min_udaf(),
+        PercentileRewriteTarget::Max => max_udaf(),
+    };
+
+    let mut agg_arg = value_expr;
+    if expected_return_type != input_type {
+        agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone()));
+    }
+
+    let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+        udaf,
+        vec![agg_arg],
+        params.distinct,
+        params.filter.clone(),
+        vec![],
+        params.null_treatment,
+    ));
+    Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+    percentile_value: f64,
+    is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+    if nearly_equals_fraction(percentile_value, 0.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Max
+        } else {
+            PercentileRewriteTarget::Min
+        })
+    } else if nearly_equals_fraction(percentile_value, 1.0) {
+        Some(if is_descending {
+            PercentileRewriteTarget::Min
+        } else {
+            PercentileRewriteTarget::Max
+        })
+    } else {
+        None
+    }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+    (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
+
+fn percentile_cont_result_type(input_type: &DataType) -> Option<DataType> {
+    if !input_type.is_numeric() {
+        return None;
+    }
+
+    let result_type = match input_type {
+        DataType::Float16 | DataType::Float32 | DataType::Float64 => input_type.clone(),
+        DataType::Decimal32(_, _)
+        | DataType::Decimal64(_, _)
+        | DataType::Decimal128(_, _)
+        | DataType::Decimal256(_, _) => input_type.clone(),
+        DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64 => DataType::Float64,
+        _ => return None,
+    };
+
+    Some(result_type)
+}
+
+fn extract_percentile_literal(expr: &Expr) -> Option<f64> {
+    match expr {
+        Expr::Literal(value, _) => literal_scalar_to_f64(value),
+        Expr::Alias(alias) => extract_percentile_literal(alias.expr.as_ref()),
+        Expr::Cast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+        Expr::TryCast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+        _ => None,
+    }
+}
+
+fn literal_scalar_to_f64(value: &ScalarValue) -> Option<f64> {
+    match value {
+        ScalarValue::Float64(Some(v)) => Some(*v),
+        ScalarValue::Float32(Some(v)) => Some(*v as f64),

Review Comment:
   ```suggestion
           ScalarValue::Float32(Some(v)) => Some(*v as f64),
           ScalarValue::Float16(Some(v)) => Some(v.to_f64()),
   ```



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,150 @@ impl AggregateUDFImpl for PercentileCont {
     }
 }
 
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+    Min,
+    Max,
+}
+
+#[allow(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+    aggregate_function: AggregateFunction,
+    info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+    let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+    let params = &aggregate_function.params;
+
+    if params.args.len() != 2 {
+        return Ok(original_expr);
+    }
+
+    let percentile_value = match extract_percentile_literal(&params.args[1]) {
+        Some(value) => value,

Review Comment:
   ```suggestion
           Some(value) if value >= 0.0 && value <= 1.0 => value,
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to