This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 22384d8ad9 GH-49753: [C++][Gandiva] Fix overflow in string functions (#49813)
22384d8ad9 is described below
commit 22384d8ad9d11a2a9f9bb32cf9acb341099335be
Author: Abel Thomas <[email protected]>
AuthorDate: Wed Jun 17 04:02:00 2026 +0200
GH-49753: [C++][Gandiva] Fix overflow in string functions (#49813)
### Rationale for this change
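Several Gandiva string functions computed output-buffer sizes with unchecked `int32_t` arithmetic. Examples are `2 * data_len`, `4 * in_len`, `2 * in_len + 2`, `2 * text_len + 1`, and the summed word and separator lengths in `concat_ws`. These expressions can overflow for large inputs, and several of these functions also accepted negative lengths without complaint.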
### What changes are included in this PR?
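The following functions now compute buffer sizes with `arrow::internal::MultiplyWithOverflow` / `AddWithOverflow` and set an error instead of overflowing:
- `gdv_fn_lower_utf8`, `gdv_fn_upper_utf8`, `gdv_fn_initcap_utf8`
- `translate_utf8_utf8_utf8`
- `quote_utf8`, `to_hex_binary`
- the `concat_ws_*` functions

Negative lengths are now rejected with an error in the UTF-8 case-conversion functions, `gdv_fn_substring_index`, `to_hex_binary`, and the `concat_ws_*` functions. The `concat_ws_*` variants now share a single implementation.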
### Are these changes tested?
Yes. Overflow and negative-length cases were added to the relevant GoogleTest suites, and the tests were run locally.
### Are there any user-facing changes?
No
**This PR contains a "Critical Fix".** Unchecked `int32_t` overflow in output-buffer size computations (e.g. `2 * data_len`) could request an undersized or invalid arena allocation. The function would then write into that buffer, risking a crash or memory corruption.
* GitHub Issue: #49753
Lead-authored-by: Abel Tom <[email protected]>
Co-authored-by: Sutou Kouhei <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
cpp/src/gandiva/gdv_function_stubs_test.cc | 73 +++++
cpp/src/gandiva/gdv_string_function_stubs.cc | 93 ++++--
cpp/src/gandiva/precompiled/string_ops.cc | 380 ++++++++++++-------------
cpp/src/gandiva/precompiled/string_ops_test.cc | 94 ++++++
4 files changed, 422 insertions(+), 218 deletions(-)
diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc
b/cpp/src/gandiva/gdv_function_stubs_test.cc
index 2eb43689d8..3067a5f275 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -589,6 +589,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
std::numeric_limits<int32_t>::min(),
&out_len);
EXPECT_EQ(std::string(out_str, out_len), "a.b.c");
EXPECT_FALSE(ctx.has_error());
+
+ out_str = gdv_fn_substring_index(ctx_ptr, "a", -2, ".", -1, -50, &out_len);
+ EXPECT_STREQ(out_str, "");
+ EXPECT_EQ(out_len, 0);
}
TEST(TestGdvFnStubs, TestUpper) {
@@ -642,6 +646,26 @@ TEST(TestGdvFnStubs, TestUpper) {
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr(
"unexpected byte \\c3 encountered while decoding utf8
string"));
+
+ ctx.Reset();
+
+ // Max Len Test
+ out_len = -1;
+ int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+ const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+ // Expect failure
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out, "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Would overflow maximum output size"));
+ ctx.Reset();
+
+ // Negative length test
+ out_len = -1;
+ out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out, "");
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data
length"));
ctx.Reset();
std::string e(
@@ -699,6 +723,26 @@ TEST(TestGdvFnStubs, TestLower) {
out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "");
EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
+
+ // Max Len Test
+ out_len = -1;
+ int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+ const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+ // Expect failure
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out, "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Would overflow maximum output size"));
+ ctx.Reset();
+
+ // Negative length test
+ out_len = -1;
+ out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out, "");
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data
length"));
+ ctx.Reset();
std::string d("AbOJjÜoß\xc3");
out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()),
&out_len);
@@ -798,6 +842,25 @@ TEST(TestGdvFnStubs, TestInitCap) {
"unexpected byte \\c3 encountered while decoding utf8
string"));
ctx.Reset();
+ // Max Len Test
+ out_len = -1;
+ int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+ const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
+ // Expect failure
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out, "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Would overflow maximum output size"));
+ ctx.Reset();
+
+ // Negative length test
+ out_len = -1;
+ out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out, "");
+ EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data
length"));
+ ctx.Reset();
+
std::string e(
"åbÑg\xe0\xa0"
"åBUå");
@@ -1129,6 +1192,16 @@ TEST(TestGdvFnStubs, TestTranslate) {
result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9,
"0123456789",
10, &out_len);
EXPECT_EQ(expected, std::string(result, out_len));
+
+ int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
+ out_len = -1;
+ const unsigned char bad_in_array[] = {0x80, 0x12, 0x13, 0x14};
+ result = translate_utf8_utf8_utf8(ctx_ptr, reinterpret_cast<const
char*>(bad_in_array),
+ bad_in_len, "B", 1, "C", 1, &out_len);
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(result, "");
+ EXPECT_THAT(ctx.get_error(),
+ ::testing::HasSubstr("Would overflow maximum output size"));
}
TEST(TestGdvFnStubs, TestToUtcTimezone) {
diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc
b/cpp/src/gandiva/gdv_string_function_stubs.cc
index d271834fb4..b3159a2d74 100644
--- a/cpp/src/gandiva/gdv_string_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_string_function_stubs.cc
@@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) {
return 0;
}
+static inline bool is_datalen_valid(int64_t context, int32_t data_len,
int32_t* alloc_len,
+ int32_t* out_len) {
+ // Reject negative lengths
+ if (ARROW_PREDICT_FALSE(data_len < 0)) {
+ gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
+ *out_len = 0;
+ return false;
+ }
+
+ // Check overflow: 2 * data_len
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
+ gdv_fn_context_set_error_msg(context, "Would overflow maximum output
size");
+ *out_len = 0;
+ return false;
+ }
+ return true;
+}
+
// Convert an utf8 string to its corresponding lowercase string
GANDIVA_EXPORT
const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t
data_len,
@@ -222,10 +241,15 @@ const char* gdv_fn_lower_utf8(int64_t context, const
char* data, int32_t data_le
return "";
}
+ int32_t alloc_length = 0;
+ if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length,
out_len))) {
+ return "";
+ }
+
// If it is a single-byte character (ASCII), corresponding lowercase is
always 1-byte
// long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so
length of
// the output can be at most twice the length of the input
- char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 *
data_len));
+ char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
*out_len = 0;
@@ -294,10 +318,15 @@ const char* gdv_fn_upper_utf8(int64_t context, const
char* data, int32_t data_le
return "";
}
+ int32_t alloc_length = 0;
+ if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length,
out_len))) {
+ return "";
+ }
+
// If it is a single-byte character (ASCII), corresponding uppercase is
always 1-byte
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so
length of
// the output can be at most twice the length of the input
- char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 *
data_len));
+ char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
*out_len = 0;
@@ -367,6 +396,17 @@ const char* gdv_fn_substring_index(int64_t context, const
char* txt, int32_t txt
return "";
}
+ if (ARROW_PREDICT_FALSE(txt_len < 0)) {
+ gdv_fn_context_set_error_msg(context, "Input string length cannot be
negative");
+ *out_len = 0;
+ return "";
+ }
+ if (ARROW_PREDICT_FALSE(pat_len < 0)) {
+ gdv_fn_context_set_error_msg(context, "Pattern string length cannot be
negative");
+ *out_len = 0;
+ return "";
+ }
+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
txt_len));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
@@ -445,8 +485,8 @@ const char* gdv_fn_substring_index(int64_t context, const
char* txt, int32_t txt
return out;
} else {
+ memcpy(out, txt, static_cast<size_t>(txt_len));
*out_len = txt_len;
- memcpy(out, txt, txt_len);
return out;
}
}
@@ -480,10 +520,15 @@ const char* gdv_fn_initcap_utf8(int64_t context, const
char* data, int32_t data_
return "";
}
+ int32_t alloc_length = 0;
+ if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length,
out_len))) {
+ return "";
+ }
+
// If it is a single-byte character (ASCII), corresponding uppercase is
always 1-byte
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so
length of
// the output can be at most twice the length of the input
- char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 *
data_len));
+ char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
*out_len = 0;
@@ -579,15 +624,17 @@ const char* translate_utf8_utf8_utf8(int64_t context,
const char* in, int32_t in
return in;
}
+ int32_t alloc_length = 0;
+
// This variable is to control if there are multi-byte utf8 entries
bool has_multi_byte = false;
// This variable is to store the final result
char* result;
- int result_len;
+ int32_t result_len;
// Searching multi-bytes in In
- for (int i = 0; i < in_len; i++) {
+ for (int32_t i = 0; i < in_len; i++) {
unsigned char char_single_byte = in[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
@@ -598,7 +645,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const
char* in, int32_t in
// Searching multi-bytes in From
if (!has_multi_byte) {
- for (int i = 0; i < from_len; i++) {
+ for (int32_t i = 0; i < from_len; i++) {
unsigned char char_single_byte = from[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
@@ -610,7 +657,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const
char* in, int32_t in
// Searching multi-bytes in To
if (!has_multi_byte) {
- for (int i = 0; i < to_len; i++) {
+ for (int32_t i = 0; i < to_len; i++) {
unsigned char char_single_byte = to[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
@@ -638,7 +685,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const
char* in, int32_t in
// This variable is for controlling the position in entry TO, for never
repeat the
// changes
- int start_compare;
+ int32_t start_compare;
if (to_len > 0) {
start_compare = 0;
@@ -650,7 +697,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const
char* in, int32_t in
// list, to mark deletion positions
const char empty = '\0';
- for (int in_for = 0; in_for < in_len; in_for++) {
+ for (int32_t in_for = 0; in_for < in_len; in_for++) {
if (subs_list.find(in[in_for]) != subs_list.end()) {
if (subs_list[in[in_for]] != empty) {
// If exist in map, only add the correspondent value in result
@@ -658,7 +705,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const
char* in, int32_t in
result_len++;
}
} else {
- for (int from_for = 0; from_for <= from_len; from_for++) {
+ for (int32_t from_for = 0; from_for <= from_len; from_for++) {
if (from_for == from_len) {
// If it's not in the FROM list, just add it to the map and the
result.
subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
@@ -686,10 +733,18 @@ const char* translate_utf8_utf8_utf8(int64_t context,
const char* in, int32_t in
}
}
}
- } else { // If there are no multibytes in the input, work with std::strings
+ } else {
+ // Check overflow: 4 * in_len
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
+ gdv_fn_context_set_error_msg(context, "Would overflow maximum output
size");
+ *out_len = 0;
+ return "";
+ }
+ // If there are multibytes in the input, work with std::strings
// This variable is for receive the substitutions, malloc is in_len * 4 to
receive
// possible inputs with 4 bytes
- result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
in_len * 4));
+ result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
alloc_length));
if (result == nullptr) {
gdv_fn_context_set_error_msg(context,
@@ -704,7 +759,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const
char* in, int32_t in
// This variable is for controlling the position in entry TO, for never
repeat the
// changes
- int start_compare;
+ int32_t start_compare;
if (to_len > 0) {
start_compare = 0;
@@ -717,11 +772,11 @@ const char* translate_utf8_utf8_utf8(int64_t context,
const char* in, int32_t in
const std::string empty = "";
// This variables is to control len of multi-bytes entries
- int len_char_in = 0;
- int len_char_from = 0;
- int len_char_to = 0;
+ int32_t len_char_in = 0;
+ int32_t len_char_from = 0;
+ int32_t len_char_to = 0;
- for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
+ for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
// Updating len to char in this position
len_char_in = gdv_fn_utf8_char_length(in[in_for]);
// Making copy to std::string with length for this char position
@@ -734,7 +789,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const
char* in, int32_t in
result_len += static_cast<int>(subs_list[insert_copy_key].length());
}
} else {
- for (int from_for = 0; from_for <= from_len; from_for +=
len_char_from) {
+ for (int32_t from_for = 0; from_for <= from_len; from_for +=
len_char_from) {
// Updating len to char in this position
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
// Making copy to std::string with length for this char position
diff --git a/cpp/src/gandiva/precompiled/string_ops.cc
b/cpp/src/gandiva/precompiled/string_ops.cc
index ff0fa026c8..6c9c392a82 100644
--- a/cpp/src/gandiva/precompiled/string_ops.cc
+++ b/cpp/src/gandiva/precompiled/string_ops.cc
@@ -1961,11 +1961,29 @@ const char* quote_utf8(gdv_int64 context, const char*
in, gdv_int32 in_len,
*out_len = 0;
return "";
}
+
+ int32_t double_len = 0;
+ // Test multiply overflow for in_len
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::MultiplyWithOverflow(2, in_len, &double_len))) {
+ gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t alloc_length = 0;
+ // Test add overflow for in_len
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::AddWithOverflow(2, double_len, &alloc_length))) {
+ gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+ *out_len = 0;
+ return "";
+ }
+
// try to allocate double size output string (worst case)
- auto out =
- reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, (in_len *
2) + 2));
+ auto out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
alloc_length));
if (out == nullptr) {
- gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string.");
*out_len = 0;
return "";
}
@@ -2494,118 +2512,166 @@ void concat_word(char* out_buf, int* out_idx, const
char* in_buf, int in_len,
*out_idx += in_len;
}
-FORCE_INLINE
-const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
- int32_t separator_len, bool separator_validity,
- const char* word1, int32_t word1_len, bool
word1_validity,
- const char* word2, int32_t word2_len, bool
word2_validity,
- bool* out_valid, int32_t* out_len) {
- *out_len = 0;
- int numValidInput = 0;
- // If separator is null, always return null
- if (!separator_validity) {
- *out_len = 0;
- *out_valid = false;
- return "";
- }
+// Helper structure to maintain state during safe length accumulation
+struct SafeLengthState {
+ int32_t total_len = 0;
+ int32_t num_valid = 0;
+ bool overflow = false;
+};
+
+// Helper to safely add a word length
+static inline bool safe_accumulate_word(int64_t context, SafeLengthState&
state,
+ int32_t word_len, bool word_validity) {
+ if (!word_validity) return true;
- if (word1_validity) {
- *out_len += word1_len;
- numValidInput++;
+ if (word_len < 0) {
+ gdv_fn_context_set_error_msg(context, "Invalid word length.");
+ return false;
}
- if (word2_validity) {
- *out_len += word2_len;
- numValidInput++;
+
+ int32_t temp = 0;
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::AddWithOverflow(state.total_len, word_len, &temp)))
{
+ gdv_fn_context_set_error_msg(context, "Overflow in addition detected.");
+ state.overflow = true;
+ return false;
}
+ state.total_len = temp;
+ state.num_valid++;
+ return true;
+}
- *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
- if (*out_len == 0) {
- *out_valid = true;
- return "";
+// Helper to safely add separators based on number of valid words
+static inline bool safe_add_separators(int64_t context, SafeLengthState* state,
+ int32_t separator_len) {
+ if (state->num_valid <= 1) return true;
+
+ int32_t sep_total = 0;
+ int32_t temp = 0;
+
+ if (ARROW_PREDICT_FALSE(arrow::internal::MultiplyWithOverflow(
+ separator_len, state->num_valid - 1, &sep_total))) {
+ gdv_fn_context_set_error_msg(context, "Overflow in multiplication
detected.");
+ state->overflow = true;
+ return false;
}
- char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
*out_len));
- if (out == nullptr) {
- gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
- *out_len = 0;
- *out_valid = false;
- return "";
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::AddWithOverflow(state->total_len, sep_total,
&temp))) {
+ gdv_fn_context_set_error_msg(context, "Overflow in addition detected.");
+ state->overflow = true;
+ return false;
}
- char* tmp = out;
- int out_idx = 0;
- bool seenAnyValidInput = false;
+ state->total_len = temp;
+ return true;
+}
- concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator,
separator_len,
- &seenAnyValidInput);
+// Helper to handle overflow failure (sets output parameters and returns empty
string)
+static inline const char* handle_overflow_failure(bool* out_valid, int32_t*
out_len) {
+ *out_len = 0;
+ *out_valid = false;
+ return "";
+}
+// Helper to handle empty result (all words invalid)
+static inline const char* handle_empty_result(bool* out_valid, int32_t*
out_len) {
+ *out_len = 0;
*out_valid = true;
- *out_len = out_idx;
- return out;
+ return "";
}
-FORCE_INLINE
-const char* concat_ws_utf8_utf8_utf8(
- int64_t context, const char* separator, int32_t separator_len,
- bool separator_validity, const char* word1, int32_t word1_len, bool
word1_validity,
- const char* word2, int32_t word2_len, bool word2_validity, const char*
word3,
- int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len)
{
+struct WordArg {
+ const char* data;
+ int32_t len;
+ bool valid;
+};
+
+static inline const char* concat_ws_impl(int64_t context, const char*
separator,
+ int32_t separator_len, bool
separator_validity,
+ bool* out_valid, int32_t* out_len,
+ std::initializer_list<WordArg> words)
{
*out_len = 0;
- int numValidInput = 0;
- // If separator is null, always return null
+
+ // Separator validity check
if (!separator_validity) {
- *out_len = 0;
*out_valid = false;
return "";
}
-
- if (word1_validity) {
- *out_len += word1_len;
- numValidInput++;
- }
- if (word2_validity) {
- *out_len += word2_len;
- numValidInput++;
+ if (separator_len < 0) {
+ gdv_fn_context_set_error_msg(context, "Separator length cannot be
negative");
+ *out_valid = false;
+ return "";
}
- if (word3_validity) {
- *out_len += word3_len;
- numValidInput++;
+
+ SafeLengthState state;
+
+ // Accumulate all word lengths safely
+ for (const WordArg& w : words) {
+ if (!safe_accumulate_word(context, state, w.len, w.valid)) {
+ *out_len = 0;
+ *out_valid = false;
+ return "";
+ }
}
- *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
+ // Add separator lengths
+ if (!safe_add_separators(context, &state, separator_len)) {
+ return handle_overflow_failure(out_valid, out_len);
+ }
- if (*out_len == 0) {
- *out_len = 0;
- *out_valid = true;
- return "";
+ // Empty result
+ if (state.total_len == 0) {
+ return handle_empty_result(out_valid, out_len);
}
- char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
*out_len));
+ // Allocate memory
+ char* out =
+ reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
state.total_len));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
- *out_len = 0;
- *out_valid = false;
- return "";
+ return handle_overflow_failure(out_valid, out_len);
}
+ // Concatenate all words
char* tmp = out;
int out_idx = 0;
bool seenAnyValidInput = false;
- concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word3, word3_len, word3_validity, separator,
separator_len,
- &seenAnyValidInput);
+ for (const WordArg& w : words) {
+ concat_word(tmp, &out_idx, w.data, w.len, w.valid, separator,
separator_len,
+ &seenAnyValidInput);
+ }
*out_valid = true;
*out_len = out_idx;
return out;
}
+FORCE_INLINE
+const char* concat_ws_utf8_utf8(int64_t context, const char* separator,
+ int32_t separator_len, bool separator_validity,
+ const char* word1, int32_t word1_len, bool
word1_validity,
+ const char* word2, int32_t word2_len, bool
word2_validity,
+ bool* out_valid, int32_t* out_len) {
+ return concat_ws_impl(
+ context, separator, separator_len, separator_validity, out_valid,
out_len,
+ {{word1, word1_len, word1_validity}, {word2, word2_len,
word2_validity}});
+}
+
+FORCE_INLINE
+const char* concat_ws_utf8_utf8_utf8(
+ int64_t context, const char* separator, int32_t separator_len,
+ bool separator_validity, const char* word1, int32_t word1_len, bool
word1_validity,
+ const char* word2, int32_t word2_len, bool word2_validity, const char*
word3,
+ int32_t word3_len, bool word3_validity, bool* out_valid, int32_t* out_len)
{
+ return concat_ws_impl(context, separator, separator_len, separator_validity,
out_valid,
+ out_len,
+ {{word1, word1_len, word1_validity},
+ {word2, word2_len, word2_validity},
+ {word3, word3_len, word3_validity}});
+}
+
FORCE_INLINE
const char* concat_ws_utf8_utf8_utf8_utf8(
int64_t context, const char* separator, int32_t separator_len,
@@ -2613,63 +2679,12 @@ const char* concat_ws_utf8_utf8_utf8_utf8(
const char* word2, int32_t word2_len, bool word2_validity, const char*
word3,
int32_t word3_len, bool word3_validity, const char* word4, int32_t
word4_len,
bool word4_validity, bool* out_valid, int32_t* out_len) {
- *out_len = 0;
- int numValidInput = 0;
- // If separator is null, always return null
- if (!separator_validity) {
- *out_len = 0;
- *out_valid = false;
- return "";
- }
- if (word1_validity) {
- *out_len += word1_len;
- numValidInput++;
- }
- if (word2_validity) {
- *out_len += word2_len;
- numValidInput++;
- }
- if (word3_validity) {
- *out_len += word3_len;
- numValidInput++;
- }
- if (word4_validity) {
- *out_len += word4_len;
- numValidInput++;
- }
-
- *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
-
- if (*out_len == 0) {
- *out_len = 0;
- *out_valid = true;
- return "";
- }
-
- char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
*out_len));
- if (out == nullptr) {
- gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
- *out_valid = false;
- *out_len = 0;
- return "";
- }
-
- char* tmp = out;
- int out_idx = 0;
- bool seenAnyValidInput = false;
-
- concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word3, word3_len, word3_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word4, word4_len, word4_validity, separator,
separator_len,
- &seenAnyValidInput);
-
- *out_valid = true;
- *out_len = out_idx;
- return out;
+ return concat_ws_impl(context, separator, separator_len, separator_validity,
out_valid,
+ out_len,
+ {{word1, word1_len, word1_validity},
+ {word2, word2_len, word2_validity},
+ {word3, word3_len, word3_validity},
+ {word4, word4_len, word4_validity}});
}
FORCE_INLINE
@@ -2680,69 +2695,13 @@ const char* concat_ws_utf8_utf8_utf8_utf8_utf8(
int32_t word3_len, bool word3_validity, const char* word4, int32_t
word4_len,
bool word4_validity, const char* word5, int32_t word5_len, bool
word5_validity,
bool* out_valid, int32_t* out_len) {
- *out_len = 0;
- int numValidInput = 0;
- // If separator is null, always return null
- if (!separator_validity) {
- *out_len = 0;
- *out_valid = false;
- return "";
- }
- if (word1_validity) {
- *out_len += word1_len;
- numValidInput++;
- }
- if (word2_validity) {
- *out_len += word2_len;
- numValidInput++;
- }
- if (word3_validity) {
- *out_len += word3_len;
- numValidInput++;
- }
- if (word4_validity) {
- *out_len += word4_len;
- numValidInput++;
- }
- if (word5_validity) {
- *out_len += word5_len;
- numValidInput++;
- }
-
- *out_len += separator_len * (numValidInput > 1 ? numValidInput - 1 : 0);
-
- if (*out_len == 0) {
- *out_len = 0;
- *out_valid = true;
- return "";
- }
-
- char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
*out_len));
- if (out == nullptr) {
- gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
- *out_len = 0;
- *out_valid = false;
- return "";
- }
-
- char* tmp = out;
- int out_idx = 0;
- bool seenAnyValidInput = false;
-
- concat_word(tmp, &out_idx, word1, word1_len, word1_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word2, word2_len, word2_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word3, word3_len, word3_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word4, word4_len, word4_validity, separator,
separator_len,
- &seenAnyValidInput);
- concat_word(tmp, &out_idx, word5, word5_len, word5_validity, separator,
separator_len,
- &seenAnyValidInput);
-
- *out_valid = true;
- *out_len = out_idx;
- return out;
+ return concat_ws_impl(context, separator, separator_len, separator_validity,
out_valid,
+ out_len,
+ {{word1, word1_len, word1_validity},
+ {word2, word2_len, word2_validity},
+ {word3, word3_len, word3_validity},
+ {word4, word4_len, word4_validity},
+ {word5, word5_len, word5_validity}});
}
FORCE_INLINE
@@ -2879,8 +2838,31 @@ const char* to_hex_binary(int64_t context, const char*
text, int32_t text_len,
return "";
}
- auto ret =
- reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, text_len *
2 + 1));
+ if (ARROW_PREDICT_FALSE(text_len < 0)) {
+ gdv_fn_context_set_error_msg(context, "Text length invalid (negative).");
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t double_len = 0;
+ // Check multiply overflow for text_len
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::MultiplyWithOverflow(2, text_len, &double_len))) {
+ gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+ *out_len = 0;
+ return "";
+ }
+
+ int32_t alloc_length = 0;
+ // Check add overflow for text_len
+ if (ARROW_PREDICT_FALSE(
+ arrow::internal::AddWithOverflow(1, double_len, &alloc_length))) {
+ gdv_fn_context_set_error_msg(context, "Memory allocation size too large.");
+ *out_len = 0;
+ return "";
+ }
+
+ auto ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
alloc_length));
if (ret == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for
output string");
diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc
b/cpp/src/gandiva/precompiled/string_ops_test.cc
index a31683c65a..5a317d4595 100644
--- a/cpp/src/gandiva/precompiled/string_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/string_ops_test.cc
@@ -1145,26 +1145,46 @@ TEST(TestStringOps, TestQuote) {
out_str = quote_utf8(ctx_ptr, "dont", 4, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "\'dont\'");
EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
out_str = quote_utf8(ctx_ptr, "abc", 3, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "\'abc\'");
EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
out_str = quote_utf8(ctx_ptr, "don't", 5, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "\'don\\'t\'");
EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
out_str = quote_utf8(ctx_ptr, "", 0, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "");
EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
out_str = quote_utf8(ctx_ptr, "'", 1, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "'\\''");
EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
out_str = quote_utf8(ctx_ptr, "'''''''''", 9, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "'\\'\\'\\'\\'\\'\\'\\'\\'\\''");
EXPECT_FALSE(ctx.has_error());
+ ctx.Reset();
+
+ int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 1;
+ out_str = quote_utf8(ctx_ptr, "YYZ", bad_in_len, &out_len);
+ EXPECT_EQ(ctx.get_error(), "Memory allocation size too large.");
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out_str, "");
+ ctx.Reset();
+
+ bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 20;
+ out_str = quote_utf8(ctx_ptr, "ABCDE", bad_in_len, &out_len);
+ EXPECT_EQ(ctx.get_error(), "Memory allocation size too large.");
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out_str, "");
+ ctx.Reset();
}
TEST(TestStringOps, TestLtrim) {
@@ -2298,11 +2318,42 @@ TEST(TestStringOps, TestConcatWs) {
EXPECT_EQ(std::string(out, out_len), "hey");
EXPECT_EQ(out_result, true);
+ // Max word1 len
+ out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
+ std::numeric_limits<int32_t>::max(), true, word2,
word2_len,
+ true, &out_result, &out_len);
+ EXPECT_STREQ(out, "");
+ EXPECT_EQ(out_len, 0);
+ EXPECT_EQ(out_result, false);
+
+ // Max word2 len
+ out = concat_ws_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
word1_len, true,
+ word2, std::numeric_limits<int32_t>::max(), true,
&out_result,
+ &out_len);
+ EXPECT_STREQ(out, "");
+ EXPECT_EQ(out_len, 0);
+ EXPECT_EQ(out_result, false);
+
+ // Max separator len
+ out = concat_ws_utf8_utf8(ctx_ptr, separator,
std::numeric_limits<int32_t>::max(), true,
+ word1, word1_len, true, word2, word2_len, true,
&out_result,
+ &out_len);
+ EXPECT_STREQ(out, "");
+ EXPECT_EQ(out_len, 0);
+ EXPECT_EQ(out_result, false);
+
separator = "#";
sep_len = static_cast<int32_t>(strlen(separator));
const char* word3 = "wow";
int32_t word3_len = static_cast<int32_t>(strlen(word3));
+ out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator,
std::numeric_limits<int32_t>::max(),
+ true, word1, word1_len, true, word2,
word2_len, true,
+ word3, word3_len, true, &out_result,
&out_len);
+ EXPECT_STREQ(out, "");
+ EXPECT_EQ(out_len, 0);
+ EXPECT_EQ(out_result, false);
+
out = concat_ws_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
word1_len,
true, word2, word2_len, true, word3,
word3_len, true,
&out_result, &out_len);
@@ -2344,6 +2395,14 @@ TEST(TestStringOps, TestConcatWs) {
const char* word4 = "awesome";
int32_t word4_len = static_cast<int32_t>(strlen(word4));
+ out = concat_ws_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true, word1,
+ std::numeric_limits<int32_t>::max(),
true, word2,
+ word2_len, true, word3, word3_len, true,
word4,
+ word4_len, true, &out_result, &out_len);
+ EXPECT_STREQ(out, "");
+ EXPECT_EQ(out_len, 0);
+ EXPECT_EQ(out_result, false);
+
out = concat_ws_utf8_utf8_utf8_utf8(
ctx_ptr, separator, sep_len, true, word1, word1_len, true, word2,
word2_len, true,
word3, word3_len, true, word4, word4_len, true, &out_result, &out_len);
@@ -2355,6 +2414,20 @@ TEST(TestStringOps, TestConcatWs) {
const char* word5 = "super";
int32_t word5_len = static_cast<int32_t>(strlen(word5));
+ out = concat_ws_utf8_utf8_utf8_utf8_utf8(
+ ctx_ptr, separator, sep_len, true, word1, word1_len, true, word2,
word2_len, true,
+ word3, word3_len, true, word4, std::numeric_limits<int32_t>::max(),
true, word5,
+ std::numeric_limits<int32_t>::max(), true, &out_result, &out_len);
+ EXPECT_STREQ(out, "");
+ EXPECT_EQ(out_result, false);
+
+ out = concat_ws_utf8_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true,
word1,
+ word1_len, true, word2, word2_len,
true, word3,
+ word3_len, true, word4, -25, true,
word5,
+ word5_len, true, &out_result,
&out_len);
+ EXPECT_STREQ(out, "");
+ EXPECT_EQ(out_result, false);
+
out = concat_ws_utf8_utf8_utf8_utf8_utf8(ctx_ptr, separator, sep_len, true,
word1,
word1_len, true, word2, word2_len,
true, word3,
word3_len, true, word4, word4_len,
true, word5,
@@ -2498,6 +2571,27 @@ TEST(TestStringOps, TestToHex) {
output = std::string(out_str, out_len);
EXPECT_EQ(out_len, 2 * in_len);
EXPECT_EQ(output, "090A090A090A090A0A0A092061206C657474405D6572");
+ ctx.Reset();
+
+ int32_t bad_text_len = std::numeric_limits<int32_t>::max() / 2 + 20;
+ out_str = to_hex_binary(ctx_ptr, binary_string, bad_text_len, &out_len);
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out_str, "");
+ ctx.Reset();
+
+ bad_text_len = (std::numeric_limits<int32_t>::max() / 2) + 1;
+ out_str = to_hex_binary(ctx_ptr, binary_string, bad_text_len, &out_len);
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out_str, "");
+ EXPECT_EQ(ctx.get_error(), "Memory allocation size too large.");
+ ctx.Reset();
+
+ int32_t neg_in_len = -20;
+ out_str = to_hex_binary(ctx_ptr, binary_string, neg_in_len, &out_len);
+ EXPECT_EQ(out_len, 0);
+ EXPECT_STREQ(out_str, "");
+ EXPECT_EQ(ctx.get_error(), "Text length invalid (negative).");
+ ctx.Reset();
}
TEST(TestStringOps, TestToHexInt64) {