This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 22384d8ad9 GH-49753: [C++][Gandiva] Fix overflow in string functions (#49813)
22384d8ad9 is described below

commit 22384d8ad9d11a2a9f9bb32cf9acb341099335be
Author: Abel Thomas <[email protected]>
AuthorDate: Wed Jun 17 04:02:00 2026 +0200

    GH-49753: [C++][Gandiva] Fix overflow in string functions (#49813)
    
    ### Rationale for this change
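    Fix integer overflow in the output-buffer size computations of several Gandiva string functions (e.g. `2 * data_len` in `lower`/`upper`/`initcap`, `4 * in_len` in `translate`, `in_len * 2 + 2` in `quote`, `text_len * 2 + 1` in `to_hex`, and the summed word/separator lengths in `concat_ws`).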
    
    ### What changes are included in this PR?
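    Size computations now use `arrow::internal::MultiplyWithOverflow` / `AddWithOverflow` and report an error instead of silently overflowing. Negative input lengths are rejected with an error where applicable. The `concat_ws_*` variants are refactored onto a single shared, overflow-checked implementation.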
    
    ### Are these changes tested?
    Yes. Added overflow and negative-length cases to the relevant GoogleTest suites and ran them locally.
    
    ### Are there any user-facing changes?
    No
    
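    **This PR contains a "Critical Fix".** (If the changes fix either (a) a security vulnerability, (b) a bug that caused incorrect or invalid data to be produced, or (c) a bug that causes a crash (even when the API contract is upheld), please provide explanation. If not, you can remove this.)
    Explanation: these output-buffer sizes were previously computed with unchecked 32-bit signed arithmetic. A sufficiently large (or negative) input length could therefore overflow and yield an invalid or undersized allocation for the output buffer.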
    
    * GitHub Issue: #49753
    
    Lead-authored-by: Abel Tom <[email protected]>
    Co-authored-by: Sutou Kouhei <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 cpp/src/gandiva/gdv_function_stubs_test.cc     |  73 +++++
 cpp/src/gandiva/gdv_string_function_stubs.cc   |  93 ++++--
 cpp/src/gandiva/precompiled/string_ops.cc      | 380 ++++++++++++-------------
 cpp/src/gandiva/precompiled/string_ops_test.cc |  94 ++++++
 4 files changed, 422 insertions(+), 218 deletions(-)

diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc 
b/cpp/src/gandiva/gdv_function_stubs_test.cc
index 2eb43689d8..3067a5f275 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -589,6 +589,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
                                    std::numeric_limits<int32_t>::min(), 
&out_len);
   EXPECT_EQ(std::string(out_str, out_len), "a.b.c");
   EXPECT_FALSE(ctx.has_error());
+
+  out_str = gdv_fn_substring_index(ctx_ptr, "a", -2, ".", -1, -50, &out_len);
+  EXPECT_STREQ(out_str, "");
+  EXPECT_EQ(out_len, 0);
 }
 
 TEST(TestGdvFnStubs, TestUpper) {
@@ -642,6 +646,26 @@ TEST(TestGdvFnStubs, TestUpper) {
   EXPECT_THAT(ctx.get_error(),
               ::testing::HasSubstr(
                   "unexpected byte \\c3 encountered while decoding utf8 
string"));
+
+  ctx.Reset();
+
+  // Max Len Test
+  out_len = -1;
+  int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+  // Expect failure
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
+  ctx.Reset();
+
+  // Negative length test
+  out_len = -1;
+  out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data 
length"));
   ctx.Reset();
 
   std::string e(
@@ -699,6 +723,26 @@ TEST(TestGdvFnStubs, TestLower) {
   out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
+
+  // Max Len Test
+  out_len = -1;
+  int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+  // Expect failure
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
+  ctx.Reset();
+
+  // Negative length test
+  out_len = -1;
+  out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data 
length"));
+  ctx.Reset();
 
   std::string d("AbOJjÜoß\xc3");
   out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), 
&out_len);
@@ -798,6 +842,25 @@ TEST(TestGdvFnStubs, TestInitCap) {
                   "unexpected byte \\c3 encountered while decoding utf8 
string"));
   ctx.Reset();
 
+  // Max Len Test
+  out_len = -1;
+  int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+  // Expect failure
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
+  ctx.Reset();
+
+  // Negative length test
+  out_len = -1;
+  out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out, "");
+  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data 
length"));
+  ctx.Reset();
+
   std::string e(
       "åbÑg\xe0\xa0"
       "åBUå");
@@ -1129,6 +1192,16 @@ TEST(TestGdvFnStubs, TestTranslate) {
   result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9, 
"0123456789",
                                     10, &out_len);
   EXPECT_EQ(expected, std::string(result, out_len));
+
+  int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
+  out_len = -1;
+  const unsigned char bad_in_array[] = {0x80, 0x12, 0x13, 0x14};
+  result = translate_utf8_utf8_utf8(ctx_ptr, reinterpret_cast<const 
char*>(bad_in_array),
+                                    bad_in_len, "B", 1, "C", 1, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(result, "");
+  EXPECT_THAT(ctx.get_error(),
+              ::testing::HasSubstr("Would overflow maximum output size"));
 }
 
 TEST(TestGdvFnStubs, TestToUtcTimezone) {
diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc 
b/cpp/src/gandiva/gdv_string_function_stubs.cc
index d271834fb4..b3159a2d74 100644
--- a/cpp/src/gandiva/gdv_string_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_string_function_stubs.cc
@@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) {
   return 0;
 }
 
+static inline bool is_datalen_valid(int64_t context, int32_t data_len, 
int32_t* alloc_len,
+                                    int32_t* out_len) {
+  // Reject negative lengths
+  if (ARROW_PREDICT_FALSE(data_len < 0)) {
+    gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
+    *out_len = 0;
+    return false;
+  }
+
+  // Check overflow: 2 * data_len
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
+    gdv_fn_context_set_error_msg(context, "Would overflow maximum output 
size");
+    *out_len = 0;
+    return false;
+  }
+  return true;
+}
+
 // Convert an utf8 string to its corresponding lowercase string
 GANDIVA_EXPORT
 const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t 
data_len,
@@ -222,10 +241,15 @@ const char* gdv_fn_lower_utf8(int64_t context, const 
char* data, int32_t data_le
     return "";
   }
 
+  int32_t alloc_length = 0;
+  if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length, 
out_len))) {
+    return "";
+  }
+
   // If it is a single-byte character (ASCII), corresponding lowercase is 
always 1-byte
   // long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so 
length of
   // the output can be at most twice the length of the input
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * 
data_len));
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
alloc_length));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
     *out_len = 0;
@@ -294,10 +318,15 @@ const char* gdv_fn_upper_utf8(int64_t context, const 
char* data, int32_t data_le
     return "";
   }
 
+  int32_t alloc_length = 0;
+  if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length, 
out_len))) {
+    return "";
+  }
+
   // If it is a single-byte character (ASCII), corresponding uppercase is 
always 1-byte
   // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so 
length of
   // the output can be at most twice the length of the input
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * 
data_len));
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
alloc_length));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
     *out_len = 0;
@@ -367,6 +396,17 @@ const char* gdv_fn_substring_index(int64_t context, const 
char* txt, int32_t txt
     return "";
   }
 
+  if (ARROW_PREDICT_FALSE(txt_len < 0)) {
+    gdv_fn_context_set_error_msg(context, "Input string length cannot be 
negative");
+    *out_len = 0;
+    return "";
+  }
+  if (ARROW_PREDICT_FALSE(pat_len < 0)) {
+    gdv_fn_context_set_error_msg(context, "Pattern string length cannot be 
negative");
+    *out_len = 0;
+    return "";
+  }
+
   char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
txt_len));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
@@ -445,8 +485,8 @@ const char* gdv_fn_substring_index(int64_t context, const 
char* txt, int32_t txt
     return out;
 
   } else {
+    memcpy(out, txt, static_cast<size_t>(txt_len));
     *out_len = txt_len;
-    memcpy(out, txt, txt_len);
     return out;
   }
 }
@@ -480,10 +520,15 @@ const char* gdv_fn_initcap_utf8(int64_t context, const 
char* data, int32_t data_
     return "";
   }
 
+  int32_t alloc_length = 0;
+  if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length, 
out_len))) {
+    return "";
+  }
+
   // If it is a single-byte character (ASCII), corresponding uppercase is 
always 1-byte
   // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so 
length of
   // the output can be at most twice the length of the input
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * 
data_len));
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
alloc_length));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
     *out_len = 0;
@@ -579,15 +624,17 @@ const char* translate_utf8_utf8_utf8(int64_t context, 
const char* in, int32_t in
     return in;
   }
 
+  int32_t alloc_length = 0;
+
   // This variable is to control if there are multi-byte utf8 entries
   bool has_multi_byte = false;
 
   // This variable is to store the final result
   char* result;
-  int result_len;
+  int32_t result_len;
 
   // Searching multi-bytes in In
-  for (int i = 0; i < in_len; i++) {
+  for (int32_t i = 0; i < in_len; i++) {
     unsigned char char_single_byte = in[i];
     if (char_single_byte > 127) {
       // found a multi-byte utf-8 char
@@ -598,7 +645,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const 
char* in, int32_t in
 
   // Searching multi-bytes in From
   if (!has_multi_byte) {
-    for (int i = 0; i < from_len; i++) {
+    for (int32_t i = 0; i < from_len; i++) {
       unsigned char char_single_byte = from[i];
       if (char_single_byte > 127) {
         // found a multi-byte utf-8 char
@@ -610,7 +657,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const 
char* in, int32_t in
 
   // Searching multi-bytes in To
   if (!has_multi_byte) {
-    for (int i = 0; i < to_len; i++) {
+    for (int32_t i = 0; i < to_len; i++) {
       unsigned char char_single_byte = to[i];
       if (char_single_byte > 127) {
         // found a multi-byte utf-8 char
@@ -638,7 +685,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const 
char* in, int32_t in
 
     // This variable is for controlling the position in entry TO, for never 
repeat the
     // changes
-    int start_compare;
+    int32_t start_compare;
 
     if (to_len > 0) {
       start_compare = 0;
@@ -650,7 +697,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const 
char* in, int32_t in
     // list, to mark deletion positions
     const char empty = '\0';
 
-    for (int in_for = 0; in_for < in_len; in_for++) {
+    for (int32_t in_for = 0; in_for < in_len; in_for++) {
       if (subs_list.find(in[in_for]) != subs_list.end()) {
         if (subs_list[in[in_for]] != empty) {
           // If exist in map, only add the correspondent value in result
@@ -658,7 +705,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const 
char* in, int32_t in
           result_len++;
         }
       } else {
-        for (int from_for = 0; from_for <= from_len; from_for++) {
+        for (int32_t from_for = 0; from_for <= from_len; from_for++) {
           if (from_for == from_len) {
             // If it's not in the FROM list, just add it to the map and the 
result.
             subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
@@ -686,10 +733,18 @@ const char* translate_utf8_utf8_utf8(int64_t context, 
const char* in, int32_t in
         }
       }
     }
-  } else {  // If there are no multibytes in the input, work with std::strings
+  } else {
+    // Check overflow: 4 * in_len
+    if (ARROW_PREDICT_FALSE(
+            arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
+      gdv_fn_context_set_error_msg(context, "Would overflow maximum output 
size");
+      *out_len = 0;
+      return "";
+    }
+    // If there are multibytes in the input, work with std::strings
     // This variable is for receive the substitutions, malloc is in_len * 4 to 
receive
     // possible inputs with 4 bytes
-    result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
in_len * 4));
+    result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
alloc_length));
 
     if (result == nullptr) {
       gdv_fn_context_set_error_msg(context,
@@ -704,7 +759,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const 
char* in, int32_t in
 
     // This variable is for controlling the position in entry TO, for never 
repeat the
     // changes
-    int start_compare;
+    int32_t start_compare;
 
     if (to_len > 0) {
       start_compare = 0;
@@ -717,11 +772,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, 
const char* in, int32_t in
     const std::string empty = "";
 
     // This variables is to control len of multi-bytes entries
-    int len_char_in = 0;
-    int len_char_from = 0;
-    int len_char_to = 0;
+    int32_t len_char_in = 0;
+    int32_t len_char_from = 0;
+    int32_t len_char_to = 0;
 
-    for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
+    for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
       // Updating len to char in this position
       len_char_in = gdv_fn_utf8_char_length(in[in_for]);
       // Making copy to std::string with length for this char position
@@ -734,7 +789,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const 
char* in, int32_t in
           result_len += static_cast<int>(subs_list[insert_copy_key].length());
         }
       } else {
-        for (int from_for = 0; from_for <= from_len; from_for += 
len_char_from) {
+        for (int32_t from_for = 0; from_for <= from_len; from_for += 
len_char_from) {
           // Updating len to char in this position
           len_char_from = gdv_fn_utf8_char_length(from[from_for]);
           // Making copy to std::string with length for this char position
diff --git a/cpp/src/gandiva/precompiled/string_ops.cc 
b/cpp/src/gandiva/precompiled/string_ops.cc
index ff0fa026c8..6c9c392a82 100644
--- a/cpp/src/gandiva/precompiled/string_ops.cc
+++ b/cpp/src/gandiva/precompiled/string_ops.cc
@@ -1961,11 +1961,29 @@ const char* quote_utf8(gdv_int64 context, const char* 
in, gdv_int32 in_len,
     *out_len = 0;
     return "";
   }
+
+  int32_t double_len = 0;
+  // Test multiply overflow for in_len
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::MultiplyWithOverflow(2, in_len, &double_len))) {
+    gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+    *out_len = 0;
+    return "";
+  }
+
+  int32_t alloc_length = 0;
+  // Test add overflow for in_len
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::AddWithOverflow(2, double_len, &alloc_length))) {
+    gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+    *out_len = 0;
+    return "";
+  }
+
   // try to allocate double size output string (worst case)
-  auto out =
-      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, (in_len * 
2) + 2));
+  auto out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
alloc_length));
   if (out == nullptr) {
-    gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
+    gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string.");
     *out_len = 0;
     return "";
   }
@@ -2494,118 +2512,166 @@ void concat_word(char* out_buf, int* out_idx, const 
char* in_buf, int in_len,
   *out_idx += in_len;
 }
 
-FORCE_INLINE
-const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
-                                int32_t separator_len, bool separator_validity,
-                                const char* word1, int32_t word1_len, bool 
word1_validity,
-                                const char* word2, int32_t word2_len, bool 
word2_validity,
-                                bool* out_valid, int32_t* out_len) {
-  *out_len = 0;
-  int numValidInput = 0;
-  // If separator is null, always return null
-  if (!separator_validity) {
-    *out_len = 0;
-    *out_valid = false;
-    return "";
-  }
+// Helper structure to maintain state during safe length accumulation
+struct SafeLengthState {
+  int32_t total_len = 0;
+  int32_t num_valid = 0;
+  bool overflow = false;
+};
+
+// Helper to safely add a word length
+static inline bool safe_accumulate_word(int64_t context, SafeLengthState& 
state,
+                                        int32_t word_len, bool word_validity) {
+  if (!word_validity) return true;
 
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
+  if (word_len < 0) {
+    gdv_fn_context_set_error_msg(context, "Invalid word length.");
+    return false;
   }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
+
+  int32_t temp = 0;
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::AddWithOverflow(state.total_len, word_len, &temp))) 
{
+    gdv_fn_context_set_error_msg(context, "Overflow in addition detected.");
+    state.overflow = true;
+    return false;
   }
+  state.total_len = temp;
+  state.num_valid++;
+  return true;
+}
 
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
-  if (*out_len == 0) {
-    *out_valid = true;
-    return "";
+// Helper to safely add separators based on number of valid words
+static inline bool safe_add_separators(int64_t context, SafeLengthState* state,
+                                       int32_t separator_len) {
+  if (state->num_valid <= 1) return true;
+
+  int32_t sep_total = 0;
+  int32_t temp = 0;
+
+  if (ARROW_PREDICT_FALSE(arrow::internal::MultiplyWithOverflow(
+          separator_len, state->num_valid - 1, &sep_total))) {
+    gdv_fn_context_set_error_msg(context, "Overflow in multiplication 
detected.");
+    state->overflow = true;
+    return false;
   }
 
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
*out_len));
-  if (out == nullptr) {
-    gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
-    *out_len = 0;
-    *out_valid = false;
-    return "";
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::AddWithOverflow(state->total_len, sep_total, 
&temp))) {
+    gdv_fn_context_set_error_msg(context, "Overflow in addition detected.");
+    state->overflow = true;
+    return false;
   }
 
-  char* tmp = out;
-  int out_idx = 0;
-  bool seenAnyValidInput = false;
+  state->total_len = temp;
+  return true;
+}
 
-  concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator, 
separator_len,
-              &seenAnyValidInput);
+// Helper to handle overflow failure (sets output parameters and returns empty 
string)
+static inline const char* handle_overflow_failure(bool* out_valid, int32_t* 
out_len) {
+  *out_len = 0;
+  *out_valid = false;
+  return "";
+}
 
+// Helper to handle empty result (all words invalid)
+static inline const char* handle_empty_result(bool* out_valid, int32_t* 
out_len) {
+  *out_len = 0;
   *out_valid = true;
-  *out_len = out_idx;
-  return out;
+  return "";
 }
 
-FORCE_INLINE
-const char* concat_ws_utf8_utf8_utf8(
-    int64_t context, const char* separator, int32_t separator_len,
-    bool separator_validity, const char* word1, int32_t word1_len, bool 
word1_validity,
-    const char* word2, int32_t word2_len, bool word2_validity, const char* 
word3,
-    int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len) 
{
+struct WordArg {
+  const char* data;
+  int32_t len;
+  bool valid;
+};
+
+static inline const char* concat_ws_impl(int64_t context, const char* 
separator,
+                                         int32_t separator_len, bool 
separator_validity,
+                                         bool* out_valid, int32_t* out_len,
+                                         std::initializer_list<WordArg> words) 
{
   *out_len = 0;
-  int numValidInput = 0;
-  // If separator is null, always return null
+
+  // Separator validity check
   if (!separator_validity) {
-    *out_len = 0;
     *out_valid = false;
     return "";
   }
-
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
-  }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
+  if (separator_len < 0) {
+    gdv_fn_context_set_error_msg(context, "Separator length cannot be 
negative");
+    *out_valid = false;
+    return "";
   }
-  if (word3_validity) {
-    *out_len += word3_len;
-    numValidInput++;
+
+  SafeLengthState state;
+
+  // Accumulate all word lengths safely
+  for (const WordArg& w : words) {
+    if (!safe_accumulate_word(context, state, w.len, w.valid)) {
+      *out_len = 0;
+      *out_valid = false;
+      return "";
+    }
   }
 
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
+  // Add separator lengths
+  if (!safe_add_separators(context, &state, separator_len)) {
+    return handle_overflow_failure(out_valid, out_len);
+  }
 
-  if (*out_len == 0) {
-    *out_len = 0;
-    *out_valid = true;
-    return "";
+  // Empty result
+  if (state.total_len == 0) {
+    return handle_empty_result(out_valid, out_len);
   }
 
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
*out_len));
+  // Allocate memory
+  char* out =
+      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
state.total_len));
   if (out == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
-    *out_len = 0;
-    *out_valid = false;
-    return "";
+    return handle_overflow_failure(out_valid, out_len);
   }
 
+  // Concatenate all words
   char* tmp = out;
   int out_idx = 0;
   bool seenAnyValidInput = false;
 
-  concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word3, word3_len, word3_validity, separator, 
separator_len,
-              &seenAnyValidInput);
+  for (const WordArg& w : words) {
+    concat_word(tmp, &out_idx, w.data, w.len, w.valid, separator, 
separator_len,
+                &seenAnyValidInput);
+  }
 
   *out_valid = true;
   *out_len = out_idx;
   return out;
 }
 
+FORCE_INLINE
+const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
+                                int32_t separator_len, bool separator_validity,
+                                const char* word1, int32_t word1_len, bool 
word1_validity,
+                                const char* word2, int32_t word2_len, bool 
word2_validity,
+                                bool* out_valid, int32_t* out_len) {
+  return concat_ws_impl(
+      context, separator, separator_len, separator_validity, out_valid, 
out_len,
+      {{word1, word1_len, word1_validity}, {word2, word2_len, 
word2_validity}});
+}
+
+FORCE_INLINE
+const char* concat_ws_utf8_utf8_utf8(
+    int64_t context, const char* separator, int32_t separator_len,
+    bool separator_validity, const char* word1, int32_t word1_len, bool 
word1_validity,
+    const char* word2, int32_t word2_len, bool word2_validity, const char* 
word3,
+    int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len) 
{
+  return concat_ws_impl(context, separator, separator_len, separator_validity, 
out_valid,
+                        out_len,
+                        {{word1, word1_len, word1_validity},
+                         {word2, word2_len, word2_validity},
+                         {word3, word3_len, word3_validity}});
+}
+
 FORCE_INLINE
 const char* concat_ws_utf8_utf8_utf8_utf8(
     int64_t context, const char* separator, int32_t separator_len,
@@ -2613,63 +2679,12 @@ const char* concat_ws_utf8_utf8_utf8_utf8(
     const char* word2, int32_t word2_len, bool word2_validity, const char* 
word3,
     int32_t word3_len, bool word3_validity, const char* word4, int32_t 
word4_len,
     bool word4_validity, bool* out_valid, int32_t* out_len) {
-  *out_len = 0;
-  int numValidInput = 0;
-  // If separator is null, always return null
-  if (!separator_validity) {
-    *out_len = 0;
-    *out_valid = false;
-    return "";
-  }
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
-  }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
-  }
-  if (word3_validity) {
-    *out_len += word3_len;
-    numValidInput++;
-  }
-  if (word4_validity) {
-    *out_len += word4_len;
-    numValidInput++;
-  }
-
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
-
-  if (*out_len == 0) {
-    *out_len = 0;
-    *out_valid = true;
-    return "";
-  }
-
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
*out_len));
-  if (out == nullptr) {
-    gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
-    *out_valid = false;
-    *out_len = 0;
-    return "";
-  }
-
-  char* tmp = out;
-  int out_idx = 0;
-  bool seenAnyValidInput = false;
-
-  concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word3, word3_len, word3_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word4, word4_len, word4_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-
-  *out_valid = true;
-  *out_len = out_idx;
-  return out;
+  return concat_ws_impl(context, separator, separator_len, separator_validity, 
out_valid,
+                        out_len,
+                        {{word1, word1_len, word1_validity},
+                         {word2, word2_len, word2_validity},
+                         {word3, word3_len, word3_validity},
+                         {word4, word4_len, word4_validity}});
 }
 
 FORCE_INLINE
@@ -2680,69 +2695,13 @@ const char* concat_ws_utf8_utf8_utf8_utf8_utf8(
     int32_t word3_len, bool word3_validity, const char* word4, int32_t 
word4_len,
     bool word4_validity, const char* word5, int32_t word5_len, bool 
word5_validity,
     bool* out_valid, int32_t* out_len) {
-  *out_len = 0;
-  int numValidInput = 0;
-  // If separator is null, always return null
-  if (!separator_validity) {
-    *out_len = 0;
-    *out_valid = false;
-    return "";
-  }
-  if (word1_validity) {
-    *out_len += word1_len;
-    numValidInput++;
-  }
-  if (word2_validity) {
-    *out_len += word2_len;
-    numValidInput++;
-  }
-  if (word3_validity) {
-    *out_len += word3_len;
-    numValidInput++;
-  }
-  if (word4_validity) {
-    *out_len += word4_len;
-    numValidInput++;
-  }
-  if (word5_validity) {
-    *out_len += word5_len;
-    numValidInput++;
-  }
-
-  *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
-
-  if (*out_len == 0) {
-    *out_len = 0;
-    *out_valid = true;
-    return "";
-  }
-
-  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
*out_len));
-  if (out == nullptr) {
-    gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
-    *out_len = 0;
-    *out_valid = false;
-    return "";
-  }
-
-  char* tmp = out;
-  int out_idx = 0;
-  bool seenAnyValidInput = false;
-
-  concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word3, word3_len, word3_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word4, word4_len, word4_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-  concat_word(tmp, &out_idx, word5, word5_len, word5_validity, separator, 
separator_len,
-              &seenAnyValidInput);
-
-  *out_valid = true;
-  *out_len = out_idx;
-  return out;
+  return concat_ws_impl(context, separator, separator_len, separator_validity, 
out_valid,
+                        out_len,
+                        {{word1, word1_len, word1_validity},
+                         {word2, word2_len, word2_validity},
+                         {word3, word3_len, word3_validity},
+                         {word4, word4_len, word4_validity},
+                         {word5, word5_len, word5_validity}});
 }
 
 FORCE_INLINE
@@ -2879,8 +2838,31 @@ const char* to_hex_binary(int64_t context, const char* 
text, int32_t text_len,
     return "";
   }
 
-  auto ret =
-      reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, text_len * 
2 + 1));
+  if (ARROW_PREDICT_FALSE(text_len < 0)) {
+    gdv_fn_context_set_error_msg(context, "Text length invalid (negative).");
+    *out_len = 0;
+    return "";
+  }
+
+  int32_t double_len = 0;
+  // Check multiply overflow for text_len
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::MultiplyWithOverflow(2, text_len, &double_len))) {
+    gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+    *out_len = 0;
+    return "";
+  }
+
+  int32_t alloc_length = 0;
+  // Check add overflow for text_len
+  if (ARROW_PREDICT_FALSE(
+          arrow::internal::AddWithOverflow(1, double_len, &alloc_length))) {
+    gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+    *out_len = 0;
+    return "";
+  }
+
+  auto ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 
alloc_length));
 
   if (ret == nullptr) {
     gdv_fn_context_set_error_msg(context, "Could not allocate memory for 
output string");
diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc 
b/cpp/src/gandiva/precompiled/string_ops_test.cc
index a31683c65a..5a317d4595 100644
--- a/cpp/src/gandiva/precompiled/string_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/string_ops_test.cc
@@ -1145,26 +1145,46 @@ TEST(TestStringOps, TestQuote) {
   out_str = quote_utf8(ctx_ptr, "dont", 4, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "\'dont\'");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
 
   out_str = quote_utf8(ctx_ptr, "abc", 3, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "\'abc\'");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
 
   out_str = quote_utf8(ctx_ptr, "don't", 5, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "\'don\\'t\'");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
 
   out_str = quote_utf8(ctx_ptr, "", 0, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
 
   out_str = quote_utf8(ctx_ptr, "'", 1, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "'\\''");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
 
   out_str = quote_utf8(ctx_ptr, "'''''''''", 9, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "'\\'\\'\\'\\'\\'\\'\\'\\'\\''");
   EXPECT_FALSE(ctx.has_error());
+  ctx.Reset();
+
+  int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+  out_str = quote_utf8(ctx_ptr, "YYZ", bad_in_len, &out_len);
+  EXPECT_EQ(ctx.get_error(), "Memory allocation size too large.");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  ctx.Reset();
+
+  bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 20;
+  out_str = quote_utf8(ctx_ptr, "ABCDE", bad_in_len, &out_len);
+  EXPECT_EQ(ctx.get_error(), "Memory allocation size too large.");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  ctx.Reset();
 }
 
 TEST(TestStringOps, TestLtrim) {
@@ -2298,11 +2318,42 @@ TEST(TestStringOps, TestConcatWs) {
   EXPECT_EQ(std::string(out, out_len), "hey");
   EXPECT_EQ(out_result, true);
 
+  // Max word1 len
+  out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
+                            std::numeric_limits<int32_t>::max(), true, word2, 
word2_len,
+                            true, &out_result, &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
+  // Max word2 len
+  out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1, 
word1_len, true,
+                            word2, std::numeric_limits<int32_t>::max(), true, 
&out_result,
+                            &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
+  // Max separator len
+  out = concat_ws_utf8_utf8(ctx_ptr, separator, 
std::numeric_limits<int32_t>::max(), true,
+                            word1, word1_len, true, word2, word2_len, true, 
&out_result,
+                            &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
   separator = "#";
   sep_len = static_cast<int32_t>(strlen(separator));
   const char* word3 = "wow";
   int32_t word3_len = static_cast<int32_t>(strlen(word3));
 
+  out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, 
std::numeric_limits<int32_t>::max(),
+                                 true, word1, word1_len, true, word2, 
word2_len, true,
+                                 word3, word3_len, true, &out_result, 
&out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
   out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, word1, 
word1_len,
                                  true, word2, word2_len, true, word3, 
word3_len, true,
                                  &out_result, &out_len);
@@ -2344,6 +2395,14 @@ TEST(TestStringOps, TestConcatWs) {
   const char* word4 = "awesome";
   int32_t word4_len = static_cast<int32_t>(strlen(word4));
 
+  out = concat_ws_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
+                                      std::numeric_limits<int32_t>::max(), 
true, word2,
+                                      word2_len, true, word3, word3_len, true, 
word4,
+                                      word4_len, true, &out_result, &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_len, 0);
+  EXPECT_EQ(out_result, false);
+
   out = concat_ws_utf8_utf8_utf8_utf8(
       ctx_ptr, separator, sep_len, true, word1, word1_len, true, word2, 
word2_len, true,
       word3, word3_len, true, word4, word4_len, true, &out_result, &out_len);
@@ -2355,6 +2414,20 @@ TEST(TestStringOps, TestConcatWs) {
   const char* word5 = "super";
   int32_t word5_len = static_cast<int32_t>(strlen(word5));
 
+  out = concat_ws_utf8_utf8_utf8_utf8_utf8(
+      ctx_ptr, separator, sep_len, true, word1, word1_len, true, word2, 
word2_len, true,
+      word3, word3_len, true, word4, std::numeric_limits<int32_t>::max(), 
true, word5,
+      std::numeric_limits<int32_t>::max(), true, &out_result, &out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_result, false);
+
+  out = concat_ws_utf8_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, 
word1,
+                                           word1_len, true, word2, word2_len, 
true, word3,
+                                           word3_len, true, word4, -25, true, 
word5,
+                                           word5_len, true, &out_result, 
&out_len);
+  EXPECT_STREQ(out, "");
+  EXPECT_EQ(out_result, false);
+
   out = concat_ws_utf8_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, 
word1,
                                            word1_len, true, word2, word2_len, 
true, word3,
                                            word3_len, true, word4, word4_len, 
true, word5,
@@ -2498,6 +2571,27 @@ TEST(TestStringOps, TestToHex) {
   output = std::string(out_str, out_len);
   EXPECT_EQ(out_len, 2 * in_len);
   EXPECT_EQ(output, "090A090A090A090A0A0A092061206C657474405D6572");
+  ctx.Reset();
+
+  int32_t bad_text_len = std::numeric_limits<int32_t>::max() / 2 + 20;
+  out_str = to_hex_binary(ctx_ptr, binary_string, bad_text_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  ctx.Reset();
+
+  bad_text_len = (std::numeric_limits<int32_t>::max() / 2) + 1;
+  out_str = to_hex_binary(ctx_ptr, binary_string, bad_text_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  EXPECT_EQ(ctx.get_error(), "Memory allocation size too large.");
+  ctx.Reset();
+
+  int32_t neg_in_len = -20;
+  out_str = to_hex_binary(ctx_ptr, binary_string, neg_in_len, &out_len);
+  EXPECT_EQ(out_len, 0);
+  EXPECT_STREQ(out_str, "");
+  EXPECT_EQ(ctx.get_error(), "Text length invalid (negative).");
+  ctx.Reset();
 }
 
 TEST(TestStringOps, TestToHexInt64) {


Reply via email to