This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b4a703202 GH-39231: [C++][Compute] Add binary_slice kernel for fixed 
size binary (#39245)
2b4a703202 is described below

commit 2b4a70320232647f730b19d2fea5746c3baec752
Author: Jin Shang <[email protected]>
AuthorDate: Fri Jan 12 01:56:46 2024 +0800

    GH-39231: [C++][Compute] Add binary_slice kernel for fixed size binary 
(#39245)
    
    
    
    ### Rationale for this change
    The `binary_slice` compute function had no dedicated kernel for
    fixed-size binary input. This change adds one.
    
    ### What changes are included in this PR?
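    * Register a `binary_slice` kernel for fixed-size binary input; its output
      is a fixed-size binary type whose width is computed from the slice options.
    * Make `FixedOutputSize` return `Result<int32_t>` so that an invalid
      option (a zero slice step) is reported as an error.
    * Take the input array offset into account in
      `FixedSizeBinaryTransformExecBase`.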
    
    ### Are these changes tested?
    Yes
    
    ### Are there any user-facing changes?
    No
    
    * Closes: #39231
    
    Lead-authored-by: Jin Shang <[email protected]>
    Co-authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 .../arrow/compute/kernels/scalar_string_ascii.cc   | 117 ++++++++++++-----
 .../arrow/compute/kernels/scalar_string_internal.h |   2 +
 .../arrow/compute/kernels/scalar_string_test.cc    | 146 +++++++++++++++++++--
 python/pyarrow/tests/test_compute.py               |  10 +-
 4 files changed, 233 insertions(+), 42 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc 
b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
index 6764845dfc..8fdc6172aa 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
@@ -95,7 +95,7 @@ struct FixedSizeBinaryTransformExecBase {
                           ctx->Allocate(output_width * input_nstrings));
     uint8_t* output_str = values_buffer->mutable_data();
 
-    const uint8_t* input_data = input.GetValues<uint8_t>(1);
+    const uint8_t* input_data = input.GetValues<uint8_t>(1, input.offset * 
input_width);
     for (int64_t i = 0; i < input_nstrings; i++) {
       if (!input.IsNull(i)) {
         const uint8_t* input_string = input_data + i * input_width;
@@ -132,7 +132,8 @@ struct FixedSizeBinaryTransformExecWithState
     DCHECK_EQ(1, types.size());
     const auto& options = State::Get(ctx);
     const int32_t input_width = types[0].type->byte_width();
-    const int32_t output_width = StringTransform::FixedOutputSize(options, 
input_width);
+    ARROW_ASSIGN_OR_RAISE(const int32_t output_width,
+                          StringTransform::FixedOutputSize(options, 
input_width));
     return fixed_size_binary(output_width);
   }
 };
@@ -2377,7 +2378,8 @@ struct BinaryReplaceSliceTransform : 
ReplaceStringSliceTransformBase {
     return output - output_start;
   }
 
-  static int32_t FixedOutputSize(const ReplaceSliceOptions& opts, int32_t 
input_width) {
+  static Result<int32_t> FixedOutputSize(const ReplaceSliceOptions& opts,
+                                         int32_t input_width) {
     int32_t before_slice = 0;
     int32_t after_slice = 0;
     const int32_t start = static_cast<int32_t>(opts.start);
@@ -2436,6 +2438,7 @@ void AddAsciiStringReplaceSlice(FunctionRegistry* 
registry) {
 
 namespace {
 struct SliceBytesTransform : StringSliceTransformBase {
+  using StringSliceTransformBase::StringSliceTransformBase;
   int64_t MaxCodeunits(int64_t ninputs, int64_t input_bytes) override {
     const SliceOptions& opt = *this->options;
     if ((opt.start >= 0) != (opt.stop >= 0)) {
@@ -2454,22 +2457,15 @@ struct SliceBytesTransform : StringSliceTransformBase {
     return SliceBackward(input, input_string_bytes, output);
   }
 
-  int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes,
-                       uint8_t* output) {
-    // Slice in forward order (step > 0)
-    const SliceOptions& opt = *this->options;
-    const uint8_t* begin = input;
-    const uint8_t* end = input + input_string_bytes;
-    const uint8_t* begin_sliced;
-    const uint8_t* end_sliced;
-
-    if (!input_string_bytes) {
-      return 0;
-    }
-    // First, compute begin_sliced and end_sliced
+  static std::pair<int64_t, int64_t> SliceForwardRange(const SliceOptions& opt,
+                                                       int64_t 
input_string_bytes) {
+    int64_t begin = 0;
+    int64_t end = input_string_bytes;
+    int64_t begin_sliced = 0;
+    int64_t end_sliced = 0;
     if (opt.start >= 0) {
       // start counting from the left
-      begin_sliced = std::min(begin + opt.start, end);
+      begin_sliced = std::min(opt.start, end);
       if (opt.stop > opt.start) {
         // continue counting from begin_sliced
         const int64_t length = opt.stop - opt.start;
@@ -2479,7 +2475,7 @@ struct SliceBytesTransform : StringSliceTransformBase {
         end_sliced = std::max(end + opt.stop, begin_sliced);
       } else {
         // zero length slice
-        return 0;
+        return {0, 0};
       }
     } else {
       // start counting from the right
@@ -2491,7 +2487,7 @@ struct SliceBytesTransform : StringSliceTransformBase {
         // and therefore we also need this
         if (end_sliced <= begin_sliced) {
           // zero length slice
-          return 0;
+          return {0, 0};
         }
       } else if ((opt.stop < 0) && (opt.stop > opt.start)) {
         // stop is negative, but larger than start, so we count again from the 
right
@@ -2501,12 +2497,30 @@ struct SliceBytesTransform : StringSliceTransformBase {
         end_sliced = std::max(end + opt.stop, begin_sliced);
       } else {
         // zero length slice
-        return 0;
+        return {0, 0};
       }
     }
+    return {begin_sliced, end_sliced};
+  }
+
+  int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes,
+                       uint8_t* output) {
+    // Slice in forward order (step > 0)
+    if (!input_string_bytes) {
+      return 0;
+    }
+
+    const SliceOptions& opt = *this->options;
+    auto [begin_index, end_index] = SliceForwardRange(opt, input_string_bytes);
+    const uint8_t* begin_sliced = input + begin_index;
+    const uint8_t* end_sliced = input + end_index;
+
+    if (begin_sliced == end_sliced) {
+      return 0;
+    }
 
     // Second, copy computed slice to output
-    DCHECK(begin_sliced <= end_sliced);
+    DCHECK(begin_sliced < end_sliced);
     if (opt.step == 1) {
       // fast case, where we simply can finish with a memcpy
       std::copy(begin_sliced, end_sliced, output);
@@ -2525,18 +2539,13 @@ struct SliceBytesTransform : StringSliceTransformBase {
     return dest - output;
   }
 
-  int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes,
-                        uint8_t* output) {
+  static std::pair<int64_t, int64_t> SliceBackwardRange(const SliceOptions& 
opt,
+                                                        int64_t 
input_string_bytes) {
     // Slice in reverse order (step < 0)
-    const SliceOptions& opt = *this->options;
-    const uint8_t* begin = input;
-    const uint8_t* end = input + input_string_bytes;
-    const uint8_t* begin_sliced = begin;
-    const uint8_t* end_sliced = end;
-
-    if (!input_string_bytes) {
-      return 0;
-    }
+    int64_t begin = 0;
+    int64_t end = input_string_bytes;
+    int64_t begin_sliced = begin;
+    int64_t end_sliced = end;
 
     if (opt.start >= 0) {
       // +1 because begin_sliced acts as as the end of a reverse iterator
@@ -2555,6 +2564,28 @@ struct SliceBytesTransform : StringSliceTransformBase {
     }
     end_sliced--;
 
+    if (begin_sliced <= end_sliced) {
+      // zero length slice
+      return {0, 0};
+    }
+
+    return {begin_sliced, end_sliced};
+  }
+
+  int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes,
+                        uint8_t* output) {
+    if (!input_string_bytes) {
+      return 0;
+    }
+
+    const SliceOptions& opt = *this->options;
+    auto [begin_index, end_index] = SliceBackwardRange(opt, 
input_string_bytes);
+    const uint8_t* begin_sliced = input + begin_index;
+    const uint8_t* end_sliced = input + end_index;
+
+    if (begin_sliced == end_sliced) {
+      return 0;
+    }
     // Copy computed slice to output
     uint8_t* dest = output;
     const uint8_t* i = begin_sliced;
@@ -2568,6 +2599,22 @@ struct SliceBytesTransform : StringSliceTransformBase {
 
     return dest - output;
   }
+
+  static Result<int32_t> FixedOutputSize(SliceOptions options, int32_t 
input_width_32) {
+    auto step = options.step;
+    if (step == 0) {
+      return Status::Invalid("Slice step cannot be zero");
+    }
+    if (step > 0) {
+      // forward slice
+      auto [begin_index, end_index] = SliceForwardRange(options, 
input_width_32);
+      return static_cast<int32_t>((end_index - begin_index + step - 1) / step);
+    } else {
+      // backward slice
+      auto [begin_index, end_index] = SliceBackwardRange(options, 
input_width_32);
+      return static_cast<int32_t>((end_index - begin_index + step + 1) / step);
+    }
+  }
 };
 
 template <typename Type>
@@ -2594,6 +2641,12 @@ void AddAsciiStringSlice(FunctionRegistry* registry) {
     DCHECK_OK(
         func->AddKernel({ty}, ty, std::move(exec), 
SliceBytesTransform::State::Init));
   }
+  using TransformExec = 
FixedSizeBinaryTransformExecWithState<SliceBytesTransform>;
+  ScalarKernel fsb_kernel({InputType(Type::FIXED_SIZE_BINARY)},
+                          OutputType(TransformExec::OutputType), 
TransformExec::Exec,
+                          StringSliceTransformBase::State::Init);
+  fsb_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+  DCHECK_OK(func->AddKernel(std::move(fsb_kernel)));
   DCHECK_OK(registry->AddFunction(std::move(func)));
 }
 
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_internal.h 
b/cpp/src/arrow/compute/kernels/scalar_string_internal.h
index 7a5d5a7c86..6723d11c8d 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_internal.h
+++ b/cpp/src/arrow/compute/kernels/scalar_string_internal.h
@@ -250,6 +250,8 @@ struct StringSliceTransformBase : public 
StringTransformBase {
   using State = OptionsWrapper<SliceOptions>;
 
   const SliceOptions* options;
+  StringSliceTransformBase() = default;
+  explicit StringSliceTransformBase(const SliceOptions& options) : 
options{&options} {}
 
   Status PreExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) 
override {
     options = &State::Get(ctx);
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc 
b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index 5dec16d89e..d7e35d0733 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -33,10 +33,10 @@
 #include "arrow/compute/kernels/test_util.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/type.h"
+#include "arrow/type_fwd.h"
 #include "arrow/util/value_parsing.h"
 
-namespace arrow {
-namespace compute {
+namespace arrow::compute {
 
 // interesting utf8 characters for testing (lower case / upper case):
 //  * ῦ / Υ͂ (3 to 4 code units) (Note, we don't support this yet, utf8proc 
does not use
@@ -712,11 +712,140 @@ TEST_F(TestFixedSizeBinaryKernels, BinaryLength) {
              "[6, null, 6]");
 }
 
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceEmpty) {
+  SliceOptions options{2, 4};
+  CheckScalarUnary("binary_slice", ArrayFromJSON(fixed_size_binary(0), 
R"([""])"),
+                   ArrayFromJSON(fixed_size_binary(0), R"([""])"), &options);
+
+  CheckScalarUnary("binary_slice",
+                   ArrayFromJSON(fixed_size_binary(0), R"(["", null, ""])"),
+                   ArrayFromJSON(fixed_size_binary(0), R"(["", null, ""])"), 
&options);
+
+  CheckUnary("binary_slice", R"([null, null])", fixed_size_binary(2), 
R"([null, null])",
+             &options);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceBasic) {
+  SliceOptions options{2, 4};
+  CheckUnary("binary_slice", R"(["abcdef", null, "foobaz"])", 
fixed_size_binary(2),
+             R"(["cd", null, "ob"])", &options);
+
+  SliceOptions options_edgecase_1{-3, 1};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(0),
+             R"(["", ""])", &options_edgecase_1);
+
+  SliceOptions options_edgecase_2{-10, -3};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz", null])", 
fixed_size_binary(3),
+             R"(["abc", "foo", null])", &options_edgecase_2);
+
+  auto input = ArrayFromJSON(this->type(), R"(["foobaz"])");
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid,
+      testing::HasSubstr("Function 'binary_slice' cannot be called without 
options"),
+      CallFunction("binary_slice", {input}));
+
+  SliceOptions options_invalid{2, 4, 0};
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("Slice step cannot be zero"),
+      CallFunction("binary_slice", {input}, &options_invalid));
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySlicePosPos) {
+  SliceOptions options_step{1, 5, 2};
+  CheckUnary("binary_slice", R"([null, "abcdef", "foobaz"])", 
fixed_size_binary(2),
+             R"([null, "bd", "ob"])", &options_step);
+
+  SliceOptions options_step_neg{5, 0, -2};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+             R"(["fdb", "zbo"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySlicePosNeg) {
+  SliceOptions options{2, -1};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+             R"(["cde", "oba"])", &options);
+
+  SliceOptions options_step{1, -1, 2};
+  CheckUnary("binary_slice", R"(["abcdef", null, "foobaz"])", 
fixed_size_binary(2),
+             R"(["bd", null, "ob"])", &options_step);
+
+  SliceOptions options_step_neg{5, -4, -2};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2),
+             R"(["fd", "zb"])", &options_step_neg);
+
+  options_step_neg.stop = -6;
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+             R"(["fdb", "zbo"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceNegNeg) {
+  SliceOptions options{-2, -1};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(1),
+             R"(["e", "a"])", &options);
+
+  SliceOptions options_step{-4, -1, 2};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz", null, null])", 
fixed_size_binary(2),
+             R"(["ce", "oa", null, null])", &options_step);
+
+  SliceOptions options_step_neg{-1, -3, -2};
+  CheckUnary("binary_slice", R"([null, "abcdef", null, "foobaz"])", 
fixed_size_binary(1),
+             R"([null, "f", null, "z"])", &options_step_neg);
+
+  options_step_neg.stop = -4;
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2),
+             R"(["fd", "zb"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceNegPos) {
+  SliceOptions options{-2, 4};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(0),
+             R"(["", ""])", &options);
+
+  SliceOptions options_step{-4, 5, 2};
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2),
+             R"(["ce", "oa"])", &options_step);
+
+  SliceOptions options_step_neg{-1, 1, -2};
+  CheckUnary("binary_slice", R"([null, "abcdef", "foobaz", null])", 
fixed_size_binary(2),
+             R"([null, "fd", "zb", null])", &options_step_neg);
+
+  options_step_neg.stop = 0;
+  CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+             R"(["fdb", "zbo"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceConsistentyWithVarLenBinary) {
+  std::string source_str = "abcdef";
+  for (size_t str_len = 0; str_len < source_str.size(); ++str_len) {
+    auto input_str = source_str.substr(0, str_len);
+    auto fixed_input = 
ArrayFromJSON(fixed_size_binary(static_cast<int32_t>(str_len)),
+                                     R"([")" + input_str + R"("])");
+    auto varlen_input = ArrayFromJSON(binary(), R"([")" + input_str + R"("])");
+    for (auto start = -6; start <= 6; ++start) {
+      for (auto stop = -6; stop <= 6; ++stop) {
+        for (auto step = -3; step <= 4; ++step) {
+          if (step == 0) {
+            continue;
+          }
+          SliceOptions options{start, stop, step};
+          auto expected =
+              CallFunction("binary_slice", {varlen_input}, 
&options).ValueOrDie();
+          auto actual =
+              CallFunction("binary_slice", {fixed_input}, 
&options).ValueOrDie();
+          actual = Cast(actual, binary()).ValueOrDie();
+          ASSERT_OK(actual.make_array()->ValidateFull());
+          AssertDatumsEqual(expected, actual);
+        }
+      }
+    }
+  }
+}
+
 TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) {
   ReplaceSliceOptions options{0, 1, "XX"};
   CheckUnary("binary_replace_slice", "[]", fixed_size_binary(7), "[]", 
&options);
-  CheckUnary("binary_replace_slice", R"([null, "abcdef"])", 
fixed_size_binary(7),
-             R"([null, "XXbcdef"])", &options);
+  CheckUnary("binary_replace_slice", R"(["foobaz", null, "abcdef"])",
+             fixed_size_binary(7), R"(["XXoobaz", null, "XXbcdef"])", 
&options);
 
   ReplaceSliceOptions options_shrink{0, 2, ""};
   CheckUnary("binary_replace_slice", R"([null, "abcdef"])", 
fixed_size_binary(4),
@@ -731,8 +860,8 @@ TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) {
              R"([null, "abXXef"])", &options_middle);
 
   ReplaceSliceOptions options_neg_start{-3, -2, "XX"};
-  CheckUnary("binary_replace_slice", R"([null, "abcdef"])", 
fixed_size_binary(7),
-             R"([null, "abcXXef"])", &options_neg_start);
+  CheckUnary("binary_replace_slice", R"(["foobaz", null, "abcdef"])",
+             fixed_size_binary(7), R"(["fooXXaz", null, "abcXXef"])", 
&options_neg_start);
 
   ReplaceSliceOptions options_neg_end{2, -2, "XX"};
   CheckUnary("binary_replace_slice", R"([null, "abcdef"])", 
fixed_size_binary(6),
@@ -807,7 +936,7 @@ TEST_F(TestFixedSizeBinaryKernels, 
CountSubstringIgnoreCase) {
       offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 1]", &options);
 
   MatchSubstringOptions options_empty{"", /*ignore_case=*/true};
-  CheckUnary("count_substring", R"(["      ", null, "abcABc"])", offset_type(),
+  CheckUnary("count_substring", R"(["      ", null, "abcdef"])", offset_type(),
              "[7, null, 7]", &options_empty);
 }
 
@@ -2382,5 +2511,4 @@ TEST(TestStringKernels, UnicodeLibraryAssumptions) {
 }
 #endif
 
-}  // namespace compute
-}  // namespace arrow
+}  // namespace arrow::compute
diff --git a/python/pyarrow/tests/test_compute.py 
b/python/pyarrow/tests/test_compute.py
index 7c5a134d33..d1eb605c71 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -561,7 +561,8 @@ def test_slice_compatibility():
 
 
 def test_binary_slice_compatibility():
-    arr = pa.array([b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"])
+    data = [b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"]
+    arr = pa.array(data)
     for start, stop, step in itertools.product(range(-6, 6),
                                                range(-6, 6),
                                                range(-3, 4)):
@@ -574,6 +575,13 @@ def test_binary_slice_compatibility():
         assert expected.equals(result)
         # Positional options
         assert pc.binary_slice(arr, start, stop, step) == result
+        # Fixed size binary input / output
+        for item in data:
+            fsb_scalar = pa.scalar(item, type=pa.binary(len(item)))
+            expected = item[start:stop:step]
+            actual = pc.binary_slice(fsb_scalar, start, stop, step)
+            assert actual.type == pa.binary(len(expected))
+            assert actual.as_py() == expected
 
 
 def test_split_pattern():

Reply via email to