pepijnve commented on code in PR #19994:
URL: https://github.com/apache/datafusion/pull/19994#discussion_r2750218242
##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -1389,6 +1487,78 @@ fn replace_with_null(
Ok(with_null)
}
+fn safe_divide(
+ numerator: &ColumnarValue,
+ divisor: &ColumnarValue,
+) -> Result<ColumnarValue> {
+ if let ColumnarValue::Scalar(div_scalar) = divisor
+ && is_scalar_zero(div_scalar)
+ {
+ let data_type = numerator.data_type();
+ return match numerator {
+ ColumnarValue::Array(arr) => {
+ Ok(ColumnarValue::Array(new_null_array(&data_type, arr.len())))
+ }
+ ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(
+ ScalarValue::try_new_null(&data_type)?,
+ )),
+ };
+ }
+
+ let num_rows = match (numerator, divisor) {
+ (ColumnarValue::Array(arr), _) => arr.len(),
+ (_, ColumnarValue::Array(arr)) => arr.len(),
+ _ => 1,
+ };
+
+ let num_array = numerator.clone().into_array(num_rows)?;
+ let div_array = divisor.clone().into_array(num_rows)?;
+
+ let result = safe_divide_arrays(&num_array, &div_array)?;
+
+ if matches!(numerator, ColumnarValue::Scalar(_))
+ && matches!(divisor, ColumnarValue::Scalar(_))
+ {
+ Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+ &result, 0,
+ )?))
+ } else {
+ Ok(ColumnarValue::Array(result))
+ }
+}
+
+fn safe_divide_arrays(numerator: &ArrayRef, divisor: &ArrayRef) ->
Result<ArrayRef> {
+ use arrow::compute::kernels::cmp::eq;
+ use arrow::compute::kernels::numeric::div;
+
+ let zero = ScalarValue::new_zero(divisor.data_type())?.to_scalar()?;
+ let zero_mask = eq(divisor, &zero)?;
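Review Comment:
I think you need to make sure the condition from the WHEN expression is used
here in order to get the correct result. This rewrite only checks whether the
divisor is zero, so it is equivalent only when the condition is exactly
`d != 0`. With a condition such as `d > 0` or `d < 0`, rows where the divisor
has the other sign must still produce NULL. I added these SLTs, and not all of
them were passing.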
```
query I
SELECT CASE WHEN d != 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d)
----
1
NULL
-1

query I
SELECT CASE WHEN d > 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d)
----
1
NULL
NULL

query I
SELECT CASE WHEN d < 0 THEN n / d ELSE NULL END FROM (VALUES (1, 1), (1, 0), (1, -1)) t(n,d)
----
NULL
NULL
-1
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]