Jefffrey commented on code in PR #17729:
URL: https://github.com/apache/datafusion/pull/17729#discussion_r2377197488
##########
datafusion/spark/src/function/string/elt.rs:
##########
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{
+    Array, ArrayRef, AsArray, PrimitiveArray, StringArray, StringBuilder,
+};
+use arrow::compute::cast;
+use arrow::datatypes::DataType::Utf8;
+use arrow::datatypes::{DataType, Int32Type, Int64Type};
+use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
+use datafusion_expr::Volatility::Immutable;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
+use datafusion_functions::utils::make_scalar_function;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkElt {
+    signature: Signature,
+}
+
+impl Default for SparkElt {
+    fn default() -> Self {
+        SparkElt::new()
+    }
+}
+
+impl SparkElt {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkElt {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "elt"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(elt, vec![])(&args.args)
+    }
+}
+
+fn elt(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
+    if args.len() < 2 {
+        plan_datafusion_err!("elt expects at least 2 arguments: index, value1");
+    }
+
+    let n_rows = args[0].len();
+
+    let idx_i32: Option<&PrimitiveArray<Int32Type>> =
+        args[0].as_primitive_opt::<Int32Type>();
+    let idx_i64: Option<&PrimitiveArray<Int64Type>> =
+        args[0].as_primitive_opt::<Int64Type>();
+
+    if idx_i32.is_none() && idx_i64.is_none() {
+        plan_datafusion_err!(
+            "elt: first argument must be Int32 or Int64 (got {:?})",
+            args[0].data_type()
+        );
+    }

Review Comment:
   I suggest defining the signature as `Signature::user_defined` and using a `coerce_types()` implementation to ensure the first argument gets coerced to an `Int64Array`, so we don't need this logic to handle whether it is an `Int32Array` or an `Int64Array`.
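   A rough sketch of what that could look like (untested; coercing the index to `Int64` and every value argument to `Utf8` is an assumption about the desired semantics):
   ```rust
   use arrow::datatypes::DataType;
   use datafusion_common::{plan_err, Result};
   use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};

   impl SparkElt {
       pub fn new() -> Self {
           Self {
               // Let coerce_types() pick the argument types instead of
               // accepting anything via variadic_any.
               signature: Signature::user_defined(Volatility::Immutable),
           }
       }
   }

   impl ScalarUDFImpl for SparkElt {
       // ...the existing methods stay as they are...

       fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
           if arg_types.len() < 2 {
               return plan_err!("elt expects at least 2 arguments: index, value1");
           }
           // The index column becomes Int64 and every value column becomes Utf8,
           // so the kernel only ever sees an Int64Array followed by StringArrays.
           let mut coerced = vec![DataType::Int64];
           coerced.resize(arg_types.len(), DataType::Utf8);
           Ok(coerced)
       }
   }
   ```
   The kernel should then only ever see an `Int64Array` for the index, so the `Int32`/`Int64` branching can go away.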

##########
datafusion/spark/src/function/string/elt.rs:
##########
@@ -0,0 +1,257 @@
...
+    let k: usize = args.len() - 1;
+    let mut cols: Vec<Arc<StringArray>> = Vec::with_capacity(k);
+    for a in args.iter().skip(1) {
+        let casted = cast(a, &Utf8)?;
+        let sa = casted
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .ok_or_else(|| DataFusionError::Internal("downcast Utf8 failed".into()))?
+            .clone();
+        cols.push(Arc::new(sa));
+    }
+
+    let mut builder = StringBuilder::new();
+
+    for i in 0..n_rows {
+        let n_opt: Option<i64> = if let Some(idx) = idx_i32 {
+            if idx.is_null(i) {
+                None
+            } else {
+                Some(idx.value(i) as i64)
+            }
+        } else {
+            let idx = idx_i64.unwrap();
+            if idx.is_null(i) {
+                None
+            } else {
+                Some(idx.value(i))
+            }
+        };
+
+        let Some(n) = n_opt else {
+            builder.append_null();
+            continue;
+        };
+
+        let ansi_enable: bool = false;
+
+        if n < 1 || (n as usize) > k {
+            if !ansi_enable {
+                builder.append_null();
+                continue;
+            } else {
+                plan_datafusion_err!("ArrayIndexOutOfBoundsException");
+            }
+        }
+
+        let j = (n as usize) - 1;
+        let col = &cols[j];
+
+        if col.is_null(i) {
+            builder.append_null();
+        } else {
+            builder.append_value(col.value(i));
+        }
+    }
+
+    Ok(Arc::new(builder.finish()) as ArrayRef)
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::{Int32Array, Int64Array};
+    use datafusion_common::Result;
+
+    use super::*;
+
+    fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<ArrayRef> {
+        elt(&arrs)
+    }
+
+    #[test]
+    fn elt_utf8_basic() -> Result<()> {
+        let idx = Arc::new(Int32Array::from(vec![
+            Some(1),
+            Some(2),
+            Some(3),
+            Some(4),
+            Some(0),
+            None,
+        ]));
+        let v1 = Arc::new(StringArray::from(vec![
+            Some("a1"),
+            Some("a2"),
+            Some("a3"),
+            Some("a4"),
+            Some("a5"),
+            Some("a6"),
+        ]));
+        let v2 = Arc::new(StringArray::from(vec![
+            Some("b1"),
+            Some("b2"),
+            None,
+            Some("b4"),
+            Some("b5"),
+            Some("b6"),
+        ]));
+        let v3 = Arc::new(StringArray::from(vec![
+            Some("c1"),
+            Some("c2"),
+            Some("c3"),
+            None,
+            Some("c5"),
+            Some("c6"),
+        ]));
+
+        let out = run_elt_arrays(vec![idx, v1, v2, v3])?;
+        let out = out
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .ok_or_else(|| DataFusionError::Internal("expected Utf8".into()))?;
+        assert_eq!(out.len(), 6);
+        assert_eq!(out.value(0), "a1");
+        assert_eq!(out.value(1), "b2");
+        assert_eq!(out.value(2), "c3");
+        assert!(out.is_null(3));
+        assert!(out.is_null(4));
+        assert!(out.is_null(5));
+        Ok(())
+    }
+
+    #[test]
+    fn elt_int64_basic() -> Result<()> {
+        let idx = Arc::new(Int32Array::from(vec![Some(2), Some(1), Some(2)]));
+        let v1 = Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30)]));
+        let v2 = Arc::new(Int64Array::from(vec![Some(100), None, Some(300)]));
+
+        let out = run_elt_arrays(vec![idx, v1, v2])?;
+        let out = out
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .ok_or_else(|| DataFusionError::Internal("expected Utf8".into()))?;
+        assert_eq!(out.len(), 3);
+        assert_eq!(out.value(0), "100");
+        assert_eq!(out.value(1), "20");
+        assert_eq!(out.value(2), "300");
+        Ok(())
+    }
+
+    #[test]
+    fn elt_out_of_range_all_null() -> Result<()> {
+        let idx = Arc::new(Int32Array::from(vec![Some(5), Some(-1), Some(0)]));
+        let v1 = Arc::new(StringArray::from(vec![Some("x"), Some("y"), Some("z")]));
+        let v2 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")]));
+
+        let out = run_elt_arrays(vec![idx, v1, v2])?;
+        let out = out
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .ok_or_else(|| DataFusionError::Internal("expected Utf8".into()))?;
+        assert!(out.is_null(0));
+        assert!(out.is_null(1));
+        assert!(out.is_null(2));
+        Ok(())
+    }
+
+    #[test]
+    fn elt_utf8_returns_utf8view() -> Result<()> {

Review Comment:
   This test name seems incorrect, as it is not returning a `Utf8View`.

##########
datafusion/spark/src/function/string/elt.rs:
##########
@@ -0,0 +1,257 @@
...
+#[cfg(test)]
+mod tests {
+    use arrow::array::{Int32Array, Int64Array};
+    use datafusion_common::Result;
+
+    use super::*;
+
+    fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<ArrayRef> {
+        elt(&arrs)
+    }

Review Comment:
   Should downcast to `StringArray` here to simplify the tests below.
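   For example, something like:
   ```rust
   // Hypothetical tweak to the test helper: return the concrete StringArray
   // so each test doesn't have to repeat the downcast.
   fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<StringArray> {
       let out = elt(&arrs)?;
       // elt() always builds a Utf8 StringArray, so as_string::<i32>() is safe here.
       Ok(out.as_string::<i32>().clone())
   }
   ```
   The assertions in the tests below could then use `out.value(i)` / `out.is_null(i)` directly.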

##########
datafusion/spark/src/function/string/elt.rs:
##########
@@ -0,0 +1,257 @@
...
+    let k: usize = args.len() - 1;
+    let mut cols: Vec<Arc<StringArray>> = Vec::with_capacity(k);
+    for a in args.iter().skip(1) {
+        let casted = cast(a, &Utf8)?;
+        let sa = casted
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .ok_or_else(|| DataFusionError::Internal("downcast Utf8 failed".into()))?

Review Comment:
   We could use `as_string_array()` to avoid needing an error here:
   ```rust
   let sa = casted.as_string::<i32>().clone();
   ```
   Though I do wonder if it is possible to achieve this function without needing to downcast all input arrays first 🤔
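   With that change the whole loop could probably collapse to something like (untested):
   ```rust
   let cols: Vec<StringArray> = args
       .iter()
       .skip(1)
       // Cast every value argument to Utf8 and keep an owned StringArray per column.
       .map(|arg| Ok(cast(arg, &Utf8)?.as_string::<i32>().clone()))
       .collect::<Result<_>>()?;
   ```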

##########
datafusion/spark/src/function/string/elt.rs:
##########
@@ -0,0 +1,257 @@
...
+        let Some(n) = n_opt else {
+            builder.append_null();
+            continue;
+        };
+
+        let ansi_enable: bool = false;
+
+        if n < 1 || (n as usize) > k {
+            if !ansi_enable {
+                builder.append_null();
+                continue;
+            } else {
+                plan_datafusion_err!("ArrayIndexOutOfBoundsException");
+            }
+        }
+
+        let j = (n as usize) - 1;
+        let col = &cols[j];

Review Comment:
   We need more descriptive variable names than `n` and `j`.

##########
datafusion/spark/src/function/string/mod.rs:
##########
@@ -46,6 +48,11 @@ pub mod expr_fn {
         "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).",
         arg1
     ));
+    export_functions!((
+        elt,
+        "Returns the n-th input, e.g., returns input2 when n is 2. The function returns NULL if the index exceeds the length of the array and spark.sql.ansi.enabled is set to false.",

Review Comment:
   Do we have a way to set `spark.sql.ansi.enabled`? If not, then I don't think we should include that detail in the description.

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
