Jefffrey commented on code in PR #19976:
URL: https://github.com/apache/datafusion/pull/19976#discussion_r2725736559


##########
datafusion/functions/src/string/repeat.rs:
##########
@@ -99,39 +100,121 @@ impl ScalarUDFImpl for RepeatFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(repeat, vec![])(&args.args)
+        let [string_arg, count_arg] = take_function_args(self.name(), 
args.args)?;
+
+        // Helper to create null result with correct type (follows 
utf8_to_str_type)
+        let null_result = |dt: &DataType| -> ColumnarValue {
+            let scalar = if matches!(dt, LargeUtf8) {
+                ScalarValue::LargeUtf8(None)
+            } else {
+                ScalarValue::Utf8(None)
+            };
+            ColumnarValue::Scalar(scalar)
+        };
+
+        // Early return if either argument is a scalar null
+        if let ColumnarValue::Scalar(s) = &string_arg
+            && s.is_null()
+        {
+            return Ok(null_result(&s.data_type()));
+        }
+        if let ColumnarValue::Scalar(c) = &count_arg
+            && c.is_null()
+        {
+            let dt = match &string_arg {
+                ColumnarValue::Scalar(s) => s.data_type(),
+                ColumnarValue::Array(a) => a.data_type().clone(),
+            };

Review Comment:
   We can get the return type more easily from `ScalarFunctionArgs` (e.g. `args.return_field.data_type()`) instead of re-deriving it from the argument types.



##########
datafusion/functions/src/string/repeat.rs:
##########
@@ -99,39 +100,121 @@ impl ScalarUDFImpl for RepeatFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(repeat, vec![])(&args.args)
+        let [string_arg, count_arg] = take_function_args(self.name(), 
args.args)?;
+
+        // Helper to create null result with correct type (follows 
utf8_to_str_type)
+        let null_result = |dt: &DataType| -> ColumnarValue {
+            let scalar = if matches!(dt, LargeUtf8) {
+                ScalarValue::LargeUtf8(None)
+            } else {
+                ScalarValue::Utf8(None)
+            };
+            ColumnarValue::Scalar(scalar)
+        };
+
+        // Early return if either argument is a scalar null
+        if let ColumnarValue::Scalar(s) = &string_arg
+            && s.is_null()
+        {
+            return Ok(null_result(&s.data_type()));
+        }
+        if let ColumnarValue::Scalar(c) = &count_arg
+            && c.is_null()
+        {
+            let dt = match &string_arg {
+                ColumnarValue::Scalar(s) => s.data_type(),
+                ColumnarValue::Array(a) => a.data_type().clone(),
+            };
+            return Ok(null_result(&dt));
+        }
+
+        match (&string_arg, &count_arg) {
+            (
+                ColumnarValue::Scalar(string_scalar),
+                ColumnarValue::Scalar(count_scalar),
+            ) => {
+                let count = match count_scalar {
+                    ScalarValue::Int64(Some(n)) => *n,
+                    _ => {
+                        return internal_err!(
+                            "Unexpected data type {:?} for repeat count",
+                            count_scalar.data_type()
+                        );
+                    }
+                };
+
+                let result = match string_scalar {
+                    ScalarValue::Utf8(Some(s)) | 
ScalarValue::Utf8View(Some(s)) => {
+                        ScalarValue::Utf8(Some(compute_repeat(s, count)?))
+                    }
+                    ScalarValue::LargeUtf8(Some(s)) => {
+                        ScalarValue::LargeUtf8(Some(compute_repeat(s, count)?))
+                    }
+                    _ => {
+                        return internal_err!(
+                            "Unexpected data type {:?} for function repeat",
+                            string_scalar.data_type()
+                        );
+                    }
+                };
+
+                Ok(ColumnarValue::Scalar(result))
+            }
+            _ => {
+                let string_array = string_arg.to_array(args.number_rows)?;
+                let count_array = count_arg.to_array(args.number_rows)?;
+                Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
+            }
+        }
     }
 
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
 }
 
+/// Computes repeat for a single string value
+#[inline]
+fn compute_repeat(s: &str, count: i64) -> Result<String> {
+    if count <= 0 {
+        return Ok(String::new());
+    }
+    let result_len = s.len().saturating_mul(count as usize);
+    if result_len > i32::MAX as usize {
+        return exec_err!(
+            "string size overflow on repeat, max size is {}, but got {}",
+            i32::MAX,

Review Comment:
   Technically this limit holds only for Utf8/Utf8View; for LargeUtf8 we'd check
   against i64::MAX, I believe. We could look into making this function generic
   over the offset type.



##########
datafusion/functions/src/string/repeat.rs:
##########
@@ -181,37 +264,52 @@ where
     // Reusable buffer to avoid allocations in string.repeat()
     let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
 
-    string_array
-        .iter()
-        .zip(number_array.iter())
-        .for_each(|(string, number)| {
+    // Helper function to repeat a string into a buffer using doubling strategy
+    // count must be > 0
+    #[inline]
+    fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
+        buffer.clear();
+        if !string.is_empty() {
+            let src = string.as_bytes();
+            buffer.extend_from_slice(src);
+            while buffer.len() < src.len() * count {
+                let copy_len = buffer.len().min(src.len() * count - 
buffer.len());
+                buffer.extend_from_within(..copy_len);
+            }
+        }
+    }
+
+    // Fast path: no nulls in either array
+    if string_array.null_count() == 0 && number_array.null_count() == 0 {
+        for i in 0..string_array.len() {
+            // SAFETY: i is within bounds (0..len) and null_count() == 0 
guarantees valid value
+            let string = unsafe { string_array.value_unchecked(i) };
+            let count = number_array.value(i);
+            if count > 0 {
+                repeat_to_buffer(&mut buffer, string, count as usize);
+                // SAFETY: buffer contains valid UTF-8 since we only copy from 
a valid &str
+                builder.append_value(unsafe { 
std::str::from_utf8_unchecked(&buffer) });
+            } else {
+                builder.append_value("");
+            }
+        }
+    } else {
+        // Slow path: handle nulls
+        for (string, number) in string_array.iter().zip(number_array.iter()) {
             match (string, number) {
-                (Some(string), Some(number)) if number >= 0 => {
-                    buffer.clear();
-                    let count = number as usize;
-                    if count > 0 && !string.is_empty() {
-                        let src = string.as_bytes();
-                        // Initial copy
-                        buffer.extend_from_slice(src);
-                        // Doubling strategy: copy what we have so far until 
we reach the target
-                        while buffer.len() < src.len() * count {
-                            let copy_len =
-                                buffer.len().min(src.len() * count - 
buffer.len());
-                            // SAFETY: we're copying valid UTF-8 bytes that we 
already verified

Review Comment:
   It would be nice not to lose these original comments (the "Initial copy" note, the doubling-strategy explanation, and the `// SAFETY:` justification).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to