This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ea9f354123 fix: derive custom nullability for spark `bit_shift`
(#19222)
ea9f354123 is described below
commit ea9f354123715cccef84fe0e706a5f09f3eae21c
Author: Kumar Ujjawal <[email protected]>
AuthorDate: Tue Dec 9 18:04:11 2025 +0530
fix: derive custom nullability for spark `bit_shift` (#19222)
## Which issue does this PR close?
<!--
We generally require a GitHub issue to be filed for all bug fixes and
enhancements and this helps us generate change logs for our releases.
You can link an issue to this PR using the GitHub syntax. For example
`Closes #123` indicates that this PR will close issue #123.
-->
- Closes #19149.
- Part of #19144
## Rationale for this change
As stated in the original issue, the UDF uses the default `is_nullable`,
which always reports true; that is not accurate for this function.
<!--
Why are you proposing this change? If this is already explained clearly
in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->
## What changes are included in this PR?
- `bit_shift` now reports schema using `return_field_from_args`
- Added unit tests
<!--
There is no need to duplicate the description in the issue here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->
## Are these changes tested?
- All original tests pass
- Added new unit tests for the changes
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code
If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
## Are there any user-facing changes?
<!--
If there are user-facing changes then we may require documentation to be
updated before approving the PR.
-->
<!--
If there are any breaking changes to public APIs, please add the `api
change` label.
-->
---
datafusion/spark/src/function/bitwise/bit_shift.rs | 72 ++++++++++++++++++++--
1 file changed, 67 insertions(+), 5 deletions(-)
diff --git a/datafusion/spark/src/function/bitwise/bit_shift.rs
b/datafusion/spark/src/function/bitwise/bit_shift.rs
index 65df048580..ff7f7662ec 100644
--- a/datafusion/spark/src/function/bitwise/bit_shift.rs
+++ b/datafusion/spark/src/function/bitwise/bit_shift.rs
@@ -21,7 +21,8 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array,
PrimitiveArray};
use arrow::compute;
use arrow::datatypes::{
- ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type,
+ ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type,
UInt32Type,
+ UInt64Type,
};
use datafusion_common::types::{
logical_int16, logical_int32, logical_int64, logical_int8, logical_uint16,
@@ -30,8 +31,8 @@ use datafusion_common::types::{
use datafusion_common::utils::take_function_args;
use datafusion_common::{internal_err, Result};
use datafusion_expr::{
- Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature,
- TypeSignatureClass, Volatility,
+ Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl,
+ Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_functions::utils::make_scalar_function;
@@ -275,8 +276,14 @@ impl ScalarUDFImpl for SparkBitShift {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(arg_types[0].clone())
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ internal_err!("return_field_from_args should be used instead")
+ }
+
+ fn return_field_from_args(&self, args: ReturnFieldArgs) ->
Result<FieldRef> {
+ let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+ let data_type = args.arg_fields[0].data_type().clone();
+ Ok(Arc::new(Field::new(self.name(), data_type, nullable)))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
@@ -286,3 +293,58 @@ impl ScalarUDFImpl for SparkBitShift {
make_scalar_function(inner, vec![])(&args.args)
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow::datatypes::Field;
+ use datafusion_expr::ReturnFieldArgs;
+
+ #[test]
+ fn test_bit_shift_nullability() -> Result<()> {
+ let func = SparkBitShift::left();
+
+ let non_nullable_value: FieldRef =
+ Arc::new(Field::new("value", DataType::Int64, false));
+ let non_nullable_shift: FieldRef =
+ Arc::new(Field::new("shift", DataType::Int32, false));
+
+ let out = func.return_field_from_args(ReturnFieldArgs {
+ arg_fields: &[
+ Arc::clone(&non_nullable_value),
+ Arc::clone(&non_nullable_shift),
+ ],
+ scalar_arguments: &[None, None],
+ })?;
+
+ assert_eq!(out.data_type(), non_nullable_value.data_type());
+ assert!(
+ !out.is_nullable(),
+ "shift result should be non-nullable when both inputs are
non-nullable"
+ );
+
+ let nullable_value: FieldRef =
+ Arc::new(Field::new("value", DataType::Int64, true));
+ let out_nullable_value = func.return_field_from_args(ReturnFieldArgs {
+ arg_fields: &[Arc::clone(&nullable_value),
Arc::clone(&non_nullable_shift)],
+ scalar_arguments: &[None, None],
+ })?;
+ assert!(
+ out_nullable_value.is_nullable(),
+ "shift result should be nullable when value is nullable"
+ );
+
+ let nullable_shift: FieldRef =
+ Arc::new(Field::new("shift", DataType::Int32, true));
+ let out_nullable_shift = func.return_field_from_args(ReturnFieldArgs {
+ arg_fields: &[non_nullable_value, nullable_shift],
+ scalar_arguments: &[None, None],
+ })?;
+ assert!(
+ out_nullable_shift.is_nullable(),
+ "shift result should be nullable when shift is nullable"
+ );
+
+ Ok(())
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]