pitrou commented on a change in pull request #10494: URL: https://github.com/apache/arrow/pull/10494#discussion_r648955618
########## File path: cpp/src/arrow/compute/kernels/scalar_string.cc ########## @@ -2288,6 +2288,164 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// ---------------------------------------------------------------------- +// Replace slice + +struct ReplaceSliceTransformBase : public StringTransformBase { + using State = OptionsWrapper<ReplaceSliceOptions>; + + const ReplaceSliceOptions* options; + + explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options) + : options{&options} {} + + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return ninputs * options->replacement.size() + input_ncodeunits; + } +}; + +struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + int64_t before_slice = 0; + int64_t after_slice = 0; + uint8_t* output_start = output; + + if (opts.start >= 0) { + // Count from left + before_slice = std::min<int64_t>(input_string_ncodeunits, opts.start); + } else { + // Count from right + before_slice = std::max<int64_t>(0, input_string_ncodeunits + opts.start); + } + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.stop >= 0) { + // Count from left + after_slice = + std::min<int64_t>(input_string_ncodeunits, std::max(opts.start, opts.stop)); + } else { + // Count from right + after_slice = std::max<int64_t>(before_slice, input_string_ncodeunits + opts.stop); + } + output = std::copy(input, input + before_slice, output); + output = std::copy(opts.replacement.begin(), opts.replacement.end(), output); + output = std::copy(input + after_slice, input + input_string_ncodeunits, output); + return std::distance(output_start, output); + } +}; + +struct Utf8ReplaceSliceTransform : ReplaceSliceTransformBase { + using 
ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + const uint8_t* begin = input; + const uint8_t* end = input + input_string_ncodeunits; + const uint8_t *begin_sliced, *end_sliced; + uint8_t* output_start = output; + + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.start >= 0) { + // Count from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opts.start)) { + return kTransformError; + } + if (opts.stop > options->start) { + // Continue counting from left + const int64_t length = opts.stop - options->start; + if (!arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length)) { + return kTransformError; + } + } else if (opts.stop < 0) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // Zero-length slice + end_sliced = begin_sliced; + } + } else { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced, + -opts.start)) { + return kTransformError; + } + if (opts.stop >= 0) { + // Restart counting from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opts.stop)) { + return kTransformError; + } + if (end_sliced <= begin_sliced) { + // Zero-length slice + end_sliced = begin_sliced; + } + } else if ((opts.stop < 0) && (options->stop > options->start)) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // zero-length slice + end_sliced = begin_sliced; + } + } + output = std::copy(begin, begin_sliced, output); + output = std::copy(opts.replacement.begin(), options->replacement.end(), output); + output = std::copy(end_sliced, end, output); + return std::distance(output_start, output); 
+ } +}; + +template <typename Type> +using AsciiReplaceSlice = StringTransformExecWithState<Type, AsciiReplaceSliceTransform>; +template <typename Type> +using Utf8ReplaceSlice = StringTransformExecWithState<Type, Utf8ReplaceSliceTransform>; + +const FunctionDoc ascii_replace_slice_doc( + "Replace a slice of a string with `replacement`", + ("For each string in `strings`, replace a slice of the string defined by `start`" + "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, " + "and both are measured in bytes.\n" + "Null values emit null."), + {"strings"}, "ReplaceSliceOptions"); + +const FunctionDoc utf8_replace_slice_doc( + "Replace a slice of a string with `replacement`", + ("For each string in `strings`, replace a slice of the string defined by `start`" + "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, " + "and both are measured in codeunits.\n" + "Null values emit null."), + {"strings"}, "ReplaceSliceOptions"); + +void AddReplaceSlice(FunctionRegistry* registry) { + { + auto func = std::make_shared<ScalarFunction>("ascii_replace_slice", Arity::Unary(), + &ascii_replace_slice_doc); Review comment: Should probably be called `binary_replace_slice` since it works on non-ASCII input as well (it just slices in byte units, not codepoints).
########## File path: cpp/src/arrow/compute/kernels/scalar_string.cc ########## @@ -2288,6 +2288,164 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// ---------------------------------------------------------------------- +// Replace slice + +struct ReplaceSliceTransformBase : public StringTransformBase { + using State = OptionsWrapper<ReplaceSliceOptions>; + + const ReplaceSliceOptions* options; + + explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options) + : options{&options} {} + + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return ninputs * options->replacement.size() + input_ncodeunits; + } +}; + +struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + int64_t before_slice = 0; + int64_t after_slice = 0; + uint8_t* output_start = output; + + if (opts.start >= 0) { + // Count from left + before_slice = std::min<int64_t>(input_string_ncodeunits, opts.start); + } else { + // Count from right + before_slice = std::max<int64_t>(0, input_string_ncodeunits + opts.start); + } + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.stop >= 0) { + // Count from left + after_slice = + std::min<int64_t>(input_string_ncodeunits, std::max(opts.start, opts.stop)); + } else { + // Count from right + after_slice = std::max<int64_t>(before_slice, input_string_ncodeunits + opts.stop); + } + output = std::copy(input, input + before_slice, output); + output = std::copy(opts.replacement.begin(), opts.replacement.end(), output); + output = std::copy(input + after_slice, input + input_string_ncodeunits, output); + return std::distance(output_start, output); Review comment: This looks a bit pedantic. Just `output - output_start`? 
########## File path: cpp/src/arrow/compute/kernels/scalar_string.cc ########## @@ -2288,6 +2288,164 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// ---------------------------------------------------------------------- +// Replace slice + +struct ReplaceSliceTransformBase : public StringTransformBase { + using State = OptionsWrapper<ReplaceSliceOptions>; + + const ReplaceSliceOptions* options; + + explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options) + : options{&options} {} + + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return ninputs * options->replacement.size() + input_ncodeunits; + } +}; + +struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + int64_t before_slice = 0; + int64_t after_slice = 0; + uint8_t* output_start = output; + + if (opts.start >= 0) { + // Count from left + before_slice = std::min<int64_t>(input_string_ncodeunits, opts.start); + } else { + // Count from right + before_slice = std::max<int64_t>(0, input_string_ncodeunits + opts.start); + } + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.stop >= 0) { + // Count from left + after_slice = + std::min<int64_t>(input_string_ncodeunits, std::max(opts.start, opts.stop)); + } else { + // Count from right + after_slice = std::max<int64_t>(before_slice, input_string_ncodeunits + opts.stop); + } + output = std::copy(input, input + before_slice, output); + output = std::copy(opts.replacement.begin(), opts.replacement.end(), output); + output = std::copy(input + after_slice, input + input_string_ncodeunits, output); + return std::distance(output_start, output); + } +}; + +struct Utf8ReplaceSliceTransform : ReplaceSliceTransformBase { + using 
ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + const uint8_t* begin = input; + const uint8_t* end = input + input_string_ncodeunits; + const uint8_t *begin_sliced, *end_sliced; + uint8_t* output_start = output; + + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.start >= 0) { + // Count from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opts.start)) { + return kTransformError; + } + if (opts.stop > options->start) { + // Continue counting from left + const int64_t length = opts.stop - options->start; + if (!arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length)) { + return kTransformError; + } + } else if (opts.stop < 0) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // Zero-length slice + end_sliced = begin_sliced; + } + } else { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced, + -opts.start)) { + return kTransformError; + } + if (opts.stop >= 0) { + // Restart counting from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opts.stop)) { + return kTransformError; + } + if (end_sliced <= begin_sliced) { + // Zero-length slice + end_sliced = begin_sliced; + } + } else if ((opts.stop < 0) && (options->stop > options->start)) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // zero-length slice + end_sliced = begin_sliced; + } + } + output = std::copy(begin, begin_sliced, output); + output = std::copy(opts.replacement.begin(), options->replacement.end(), output); + output = std::copy(end_sliced, end, output); + return std::distance(output_start, output); 
+ } +}; + +template <typename Type> +using AsciiReplaceSlice = StringTransformExecWithState<Type, AsciiReplaceSliceTransform>; +template <typename Type> +using Utf8ReplaceSlice = StringTransformExecWithState<Type, Utf8ReplaceSliceTransform>; + +const FunctionDoc ascii_replace_slice_doc( + "Replace a slice of a string with `replacement`", Review comment: "binary string"? ########## File path: cpp/src/arrow/compute/api_scalar.h ########## @@ -77,6 +77,18 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { std::string pattern; }; +struct ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { + explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement) + : start(start), stop(stop), replacement(std::move(replacement)) {} + + /// Index to start slicing at + int64_t start = 0; + /// Index to stop slicing at + int64_t stop = std::numeric_limits<int64_t>::max(); Review comment: Hmm... I'm not sure the default values will be picked up. Is it just for documentation? 
########## File path: python/pyarrow/tests/test_compute.py ########## @@ -693,6 +693,24 @@ def test_string_py_compat_boolean(function_name, variant): assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() +def test_replace_slice(): + arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd']) + res = pc.ascii_replace_slice(arr, start=1, stop=3, replacement='XX') + assert res.tolist() == [None, 'XX', 'aXX', 'aXX', 'aXX', 'aXXd'] + res = pc.ascii_replace_slice(arr, start=-2, stop=3, replacement='XX') + assert res.tolist() == [None, 'XX', 'XX', 'XX', 'aXX', 'abXXd'] + res = pc.ascii_replace_slice(arr, start=-3, stop=-2, replacement='XX') + assert res.tolist() == [None, 'XX', 'XXa', 'XXab', 'XXbc', 'aXXcd'] + + arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd']) + res = pc.utf8_replace_slice(arr, start=1, stop=3, replacement='χχ') + assert res.tolist() == [None, 'χχ', 'πχχ', 'πχχ', 'πχχ', 'πχχd'] + res = pc.utf8_replace_slice(arr, start=-2, stop=3, replacement='χχ') + assert res.tolist() == [None, 'χχ', 'χχ', 'χχ', 'πχχ', 'πbχχd'] + res = pc.utf8_replace_slice(arr, start=-3, stop=-2, replacement='χχ') + assert res.tolist() == [None, 'χχ', 'χχπ', 'χχπb', 'χχbθ', 'πχχθd'] Review comment: It's not really useful to re-write the same tests in Python. What you could do is generate slices as in `test_slice_compatibility` to check Pandas compatibility. 
########## File path: cpp/src/arrow/compute/kernels/scalar_string.cc ########## @@ -2288,6 +2288,164 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// ---------------------------------------------------------------------- +// Replace slice + +struct ReplaceSliceTransformBase : public StringTransformBase { + using State = OptionsWrapper<ReplaceSliceOptions>; + + const ReplaceSliceOptions* options; + + explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options) + : options{&options} {} + + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return ninputs * options->replacement.size() + input_ncodeunits; + } +}; + +struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + int64_t before_slice = 0; + int64_t after_slice = 0; + uint8_t* output_start = output; + + if (opts.start >= 0) { + // Count from left + before_slice = std::min<int64_t>(input_string_ncodeunits, opts.start); + } else { + // Count from right + before_slice = std::max<int64_t>(0, input_string_ncodeunits + opts.start); + } + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.stop >= 0) { + // Count from left + after_slice = + std::min<int64_t>(input_string_ncodeunits, std::max(opts.start, opts.stop)); Review comment: `opts.start` can be negative, so you should perhaps use `before_slice` instead. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org