This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d9aeaa112b GH-49614: [C++] Report an error instead of silent truncation in base64_decode on invalid input (#49660)
d9aeaa112b is described below

commit d9aeaa112b15bf9201f2094eccbf7f59d9abe16d
Author: Aaditya Srinivasan <[email protected]>
AuthorDate: Tue Apr 14 13:24:49 2026 +0530

    GH-49614: [C++] Report an error instead of silent truncation in base64_decode on invalid input (#49660)
    
    ### Rationale for this change
    
    `arrow::util::base64_decode` previously allowed invalid input to be processed, which could result in silently truncated or incorrect output without signaling an error. This can lead to unintended data corruption.
    
    ### What changes are included in this PR?
    
    - Change `base64_decode` to return `arrow::Result<std::string>` instead of `std::string`
    - Add validation for:
      - invalid input length
      - invalid base64 characters
      - incorrect padding
    - Return an error (`Status::Invalid`) for invalid input instead of producing partial output
    - Update all call sites to handle `Result<std::string>`
    - Add unit tests covering valid and invalid inputs
    
    ### Are these changes tested?
    
    Yes. Unit tests have been added to verify:
    
    - valid decoding behavior
    - invalid input length
    - invalid characters
    - incorrect padding handling
    
    ### Are there any user-facing changes?
    
    - The API now returns `arrow::Result<std::string>` instead of `std::string`
    - Invalid base64 input now results in an error (`Status::Invalid`) instead of returning partial or incorrect output
    
    * GitHub Issue: #49614
    
    Authored-by: Aaditya Srinivasan <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 cpp/src/arrow/flight/flight_test.cc                |  3 +-
 cpp/src/arrow/util/CMakeLists.txt                  |  1 +
 cpp/src/arrow/util/base64.h                        |  3 +-
 cpp/src/arrow/util/base64_test.cc                  | 98 ++++++++++++++++++++++
 cpp/src/arrow/vendored/base64.cpp                  | 74 ++++++++++------
 cpp/src/gandiva/gdv_function_stubs.cc              | 10 ++-
 cpp/src/parquet/arrow/fuzz_internal.cc             |  5 +-
 cpp/src/parquet/arrow/schema.cc                    |  3 +-
 cpp/src/parquet/encryption/file_key_unwrapper.cc   |  6 +-
 cpp/src/parquet/encryption/key_toolkit_internal.cc |  3 +-
 10 files changed, 173 insertions(+), 33 deletions(-)

diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 16a4909828..eb93101190 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -620,7 +620,8 @@ void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username
                       std::string& password) {
   std::string encoded_credentials =
       FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
-  std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
+  ASSERT_OK_AND_ASSIGN(auto decoded, arrow::util::base64_decode(encoded_credentials));
+  std::stringstream decoded_stream(decoded);
   std::getline(decoded_stream, username, ':');
   std::getline(decoded_stream, password, ':');
 }
diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt
index 4352716ebd..deb3e9e3fb 100644
--- a/cpp/src/arrow/util/CMakeLists.txt
+++ b/cpp/src/arrow/util/CMakeLists.txt
@@ -49,6 +49,7 @@ add_arrow_test(utility-test
                SOURCES
                align_util_test.cc
                atfork_test.cc
+               base64_test.cc
                byte_size_test.cc
                byte_stream_split_test.cc
                cache_test.cc
diff --git a/cpp/src/arrow/util/base64.h b/cpp/src/arrow/util/base64.h
index 5b80e19d89..a575fee451 100644
--- a/cpp/src/arrow/util/base64.h
+++ b/cpp/src/arrow/util/base64.h
@@ -20,6 +20,7 @@
 #include <string>
 #include <string_view>
 
+#include "arrow/result.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -29,7 +30,7 @@ ARROW_EXPORT
 std::string base64_encode(std::string_view s);
 
 ARROW_EXPORT
-std::string base64_decode(std::string_view s);
+arrow::Result<std::string> base64_decode(std::string_view s);
 
 }  // namespace util
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/base64_test.cc b/cpp/src/arrow/util/base64_test.cc
new file mode 100644
index 0000000000..cd7f9bab31
--- /dev/null
+++ b/cpp/src/arrow/util/base64_test.cc
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/base64.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace util {
+
+TEST(Base64DecodeTest, ValidInputs) {
+  ASSERT_OK_AND_ASSIGN(auto empty, base64_decode(""));
+  EXPECT_EQ(empty, "");
+
+  ASSERT_OK_AND_ASSIGN(auto two_paddings, base64_decode("Zg=="));
+  EXPECT_EQ(two_paddings, "f");
+
+  ASSERT_OK_AND_ASSIGN(auto one_padding, base64_decode("Zm8="));
+  EXPECT_EQ(one_padding, "fo");
+
+  ASSERT_OK_AND_ASSIGN(auto no_padding, base64_decode("Zm9v"));
+  EXPECT_EQ(no_padding, "foo");
+
+  ASSERT_OK_AND_ASSIGN(auto multiblock, base64_decode("SGVsbG8gd29ybGQ="));
+  EXPECT_EQ(multiblock, "Hello world");
+}
+
+TEST(Base64DecodeTest, BinaryOutput) {
+  // 'A' maps to index 0 — same zero value used for padding slots
+  // verifies the 'A' bug is not present
+  ASSERT_OK_AND_ASSIGN(auto all_A, base64_decode("AAAA"));
+  EXPECT_EQ(all_A, std::string("\x00\x00\x00", 3));
+
+  // Arbitrary non-ASCII output bytes
+  ASSERT_OK_AND_ASSIGN(auto binary, base64_decode("AP8A"));
+  EXPECT_EQ(binary, std::string("\x00\xff\x00", 3));
+}
+
+TEST(Base64DecodeTest, InvalidLength) {
+  ASSERT_RAISES_WITH_MESSAGE(
+      Invalid, "Invalid: Invalid base64 input: length is not a multiple of 4",
+      base64_decode("abc"));
+}
+
+TEST(Base64DecodeTest, InvalidCharacters) {
+  ASSERT_RAISES_WITH_MESSAGE(
+      Invalid, "Invalid: Invalid base64 input: character is not valid base64 character",
+      base64_decode("ab$="));
+
+  // Non-ASCII byte
+  std::string non_ascii = std::string("abc") + static_cast<char>(0xFF);
+  ASSERT_RAISES_WITH_MESSAGE(
+      Invalid, "Invalid: Invalid base64 input: character is not valid base64 character",
+      base64_decode(non_ascii));
+
+  // Corruption mid-string across multiple blocks
+  ASSERT_RAISES_WITH_MESSAGE(
+      Invalid, "Invalid: Invalid base64 input: character is not valid base64 character",
+      base64_decode("aGVs$G8gd29ybGQ="));
+}
+
+TEST(Base64DecodeTest, InvalidPadding) {
+  // Padding in wrong position within block
+  ASSERT_RAISES_WITH_MESSAGE(Invalid,
+                             "Invalid: Invalid base64 input: padding in wrong position",
+                             base64_decode("ab=c"));
+
+  // 3 padding characters — exceeds maximum of 2
+  ASSERT_RAISES_WITH_MESSAGE(Invalid,
+                             "Invalid: Invalid base64 input: too many padding characters",
+                             base64_decode("a==="));
+
+  // 4 padding characters
+  ASSERT_RAISES_WITH_MESSAGE(Invalid,
+                             "Invalid: Invalid base64 input: too many padding characters",
+                             base64_decode("===="));
+
+  // Padding in non-final block across multiple blocks
+  ASSERT_RAISES_WITH_MESSAGE(Invalid,
+                             "Invalid: Invalid base64 input: padding in wrong position",
+                             base64_decode("Zm8=Zm8="));
+}
+
+}  // namespace util
+}  // namespace arrow
diff --git a/cpp/src/arrow/vendored/base64.cpp b/cpp/src/arrow/vendored/base64.cpp
index 6f53c0524e..db2f74ed98 100644
--- a/cpp/src/arrow/vendored/base64.cpp
+++ b/cpp/src/arrow/vendored/base64.cpp
@@ -40,11 +40,6 @@ static const std::string base64_chars =
              "abcdefghijklmnopqrstuvwxyz"
              "0123456789+/";
 
-
-static inline bool is_base64(unsigned char c) {
-  return (isalnum(c) || (c == '+') || (c == '/'));
-}
-
 static std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) {
   std::string ret;
   int i = 0;
@@ -93,38 +88,65 @@ std::string base64_encode(std::string_view string_to_encode) {
   return base64_encode(bytes_to_encode, in_len);
 }
 
-std::string base64_decode(std::string_view encoded_string) {
+Result<std::string> base64_decode(std::string_view encoded_string) {
   size_t in_len = encoded_string.size();
   int i = 0;
-  int j = 0;
-  int in_ = 0;
+  std::string_view::size_type in_ = 0;
+  int padding_count = 0;
+  int block_padding = 0;
+  bool padding_started = false;
   unsigned char char_array_4[4], char_array_3[3];
   std::string ret;
 
-  while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
-    char_array_4[i++] = encoded_string[in_]; in_++;
-    if (i ==4) {
-      for (i = 0; i <4; i++)
-        char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
+  if (encoded_string.size() % 4 != 0) {
+    return Status::Invalid("Invalid base64 input: length is not a multiple of 4");
+  }
 
-      char_array_3[0] = ( char_array_4[0] << 2       ) + ((char_array_4[1] & 0x30) >> 4);
-      char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
-      char_array_3[2] = ((char_array_4[2] & 0x3) << 6) +   char_array_4[3];
+  while (in_len--) {
+    unsigned char c = encoded_string[in_];
 
-      for (i = 0; (i < 3); i++)
-        ret += char_array_3[i];
-      i = 0;
+    if (c == '=') {
+      padding_started = true;
+      padding_count++;
+
+      if (padding_count > 2) {
+        return Status::Invalid("Invalid base64 input: too many padding characters");
+      }
+
+      char_array_4[i++] = 0;
+    } else {
+      if (padding_started) {
+        return Status::Invalid("Invalid base64 input: padding in wrong position");
+      }
+
+      if (base64_chars.find(c) == std::string::npos) {
+        return Status::Invalid("Invalid base64 input: character is not valid base64 character");
+      }
+
+      char_array_4[i++] = c;
     }
-  }
 
-  if (i) {
-    for (j = 0; j < i; j++)
-      char_array_4[j] = base64_chars.find(char_array_4[j]) & 0xff;
+    in_++;
+
+    if (i == 4) {
+      for (i = 0; i < 4; i++) {
+        if (char_array_4[i] != 0) {
+          char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
+        }
+      }
+
+      char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+      char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+      char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+      block_padding = padding_count;
 
-    char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
-    char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+      for (i = 0; i < 3 - block_padding; i++) {
+        ret += char_array_3[i];
+      }
 
-    for (j = 0; (j < i - 1); j++) ret += char_array_3[j];
+      i = 0;
+    }
   }
 
   return ret;
diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc
index 3eda4afadb..6fe3fa9a57 100644
--- a/cpp/src/gandiva/gdv_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_function_stubs.cc
@@ -269,7 +269,15 @@ const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t i
     return "";
   }
   // use arrow method to decode base64 string
-  std::string decoded_str = arrow::util::base64_decode(std::string_view(in, in_len));
+  auto result = arrow::util::base64_decode(std::string_view(in, in_len));
+  if (!result.ok()) {
+    gdv_fn_context_set_error_msg(context, result.status().message().c_str());
+    *out_len = 0;
+    return "";
+  }
+
+  std::string decoded_str = *result;
+
   *out_len = static_cast<int32_t>(decoded_str.length());
   // allocate memory for response
   char* ret = reinterpret_cast<char*>(
diff --git a/cpp/src/parquet/arrow/fuzz_internal.cc b/cpp/src/parquet/arrow/fuzz_internal.cc
index dfbb8ae161..f07b8fc7c3 100644
--- a/cpp/src/parquet/arrow/fuzz_internal.cc
+++ b/cpp/src/parquet/arrow/fuzz_internal.cc
@@ -83,8 +83,11 @@ class FuzzDecryptionKeyRetriever : public DecryptionKeyRetriever {
     }
     // Is it a key generated by MakeEncryptionKey?
     if (key_id.starts_with(kInlineKeyPrefix)) {
-      return SecureString(
+      PARQUET_ASSIGN_OR_THROW(
+          auto decoded_key,
           ::arrow::util::base64_decode(key_id.substr(kInlineKeyPrefix.length())));
+
+      return SecureString(std::move(decoded_key));
     }
     throw ParquetException("Unknown fuzz encryption key_id");
   }
diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc
index 11d5d13e4b..9c4c462c6b 100644
--- a/cpp/src/parquet/arrow/schema.cc
+++ b/cpp/src/parquet/arrow/schema.cc
@@ -953,7 +953,8 @@ Status GetOriginSchema(const std::shared_ptr<const KeyValueMetadata>& metadata,
   // The original Arrow schema was serialized using the store_schema option.
   // We deserialize it here and use it to inform read options such as
   // dictionary-encoded fields.
-  auto decoded = ::arrow::util::base64_decode(metadata->value(schema_index));
+  ARROW_ASSIGN_OR_RAISE(auto decoded,
+                        ::arrow::util::base64_decode(metadata->value(schema_index)));
   auto schema_buf = std::make_shared<Buffer>(decoded);
 
   ::arrow::ipc::DictionaryMemo dict_memo;
diff --git a/cpp/src/parquet/encryption/file_key_unwrapper.cc b/cpp/src/parquet/encryption/file_key_unwrapper.cc
index 4dc1492a0b..1cc0320137 100644
--- a/cpp/src/parquet/encryption/file_key_unwrapper.cc
+++ b/cpp/src/parquet/encryption/file_key_unwrapper.cc
@@ -19,6 +19,7 @@
 
 #include "arrow/util/utf8.h"
 
+#include "arrow/util/base64.h"
 #include "parquet/encryption/file_key_unwrapper.h"
 #include "parquet/encryption/key_metadata.h"
 
@@ -122,7 +123,10 @@ KeyWithMasterId FileKeyUnwrapper::GetDataEncryptionKey(const KeyMaterial& key_ma
         });
 
     // Decrypt the data key
-    std::string aad = ::arrow::util::base64_decode(encoded_kek_id);
+    PARQUET_ASSIGN_OR_THROW(auto decoded_kek,
+                            ::arrow::util::base64_decode(encoded_kek_id));
+
+    std::string aad = std::move(decoded_kek);
     data_key = internal::DecryptKeyLocally(encoded_wrapped_dek, kek_bytes, aad);
   }
 
diff --git a/cpp/src/parquet/encryption/key_toolkit_internal.cc b/cpp/src/parquet/encryption/key_toolkit_internal.cc
index d304041e3e..dd2a7e9f7d 100644
--- a/cpp/src/parquet/encryption/key_toolkit_internal.cc
+++ b/cpp/src/parquet/encryption/key_toolkit_internal.cc
@@ -52,7 +52,8 @@ std::string EncryptKeyLocally(const SecureString& key_bytes,
 
 SecureString DecryptKeyLocally(const std::string& encoded_encrypted_key,
                                const SecureString& master_key, const std::string& aad) {
-  std::string encrypted_key = ::arrow::util::base64_decode(encoded_encrypted_key);
+  PARQUET_ASSIGN_OR_THROW(auto encrypted_key,
+                          ::arrow::util::base64_decode(encoded_encrypted_key));
 
   AesDecryptor key_decryptor(ParquetCipher::AES_GCM_V1,
                              static_cast<int>(master_key.size()), false,

Reply via email to