This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 049411524a GH-44615: [C++][Compute] Add extract_regex_span function
(#45577)
049411524a is described below
commit 049411524adfc5be0d8b5a25986e920394cc7be4
Author: Arash Andishgar <[email protected]>
AuthorDate: Tue Mar 11 20:31:04 2025 +0330
GH-44615: [C++][Compute] Add extract_regex_span function (#45577)
### Rationale for this change
While the `extract_regex` function returns substrings of the matching regex
captures, `extract_regex_span` returns (index, length) pairs of these
substrings relative to the original string values.
### Are these changes tested?
Yes, by dedicated unit tests.
### Are there any user-facing changes?
No, except a new compute function.
* GitHub Issue: #44615
Lead-authored-by: arash andishgar <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/compute/api_scalar.cc | 10 +
cpp/src/arrow/compute/api_scalar.h | 10 +
.../arrow/compute/kernels/scalar_string_ascii.cc | 205 ++++++++++++++++++---
.../arrow/compute/kernels/scalar_string_test.cc | 83 +++++++++
docs/source/cpp/compute.rst | 19 +-
python/pyarrow/_compute.pyx | 19 ++
python/pyarrow/compute.py | 1 +
python/pyarrow/includes/libarrow.pxd | 5 +
python/pyarrow/tests/test_compute.py | 11 ++
9 files changed, 329 insertions(+), 34 deletions(-)
diff --git a/cpp/src/arrow/compute/api_scalar.cc
b/cpp/src/arrow/compute/api_scalar.cc
index 61a16f5f5e..e36a7acabd 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -325,6 +325,9 @@ static auto kElementWiseAggregateOptionsType =
DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls));
static auto kExtractRegexOptionsType =
GetFunctionOptionsType<ExtractRegexOptions>(
DataMember("pattern", &ExtractRegexOptions::pattern));
+// Options type descriptor for ExtractRegexSpanOptions; it exposes the single
+// "pattern" data member and is registered in RegisterScalarOptions().
+static auto kExtractRegexSpanOptionsType =
+ GetFunctionOptionsType<ExtractRegexSpanOptions>(
+ DataMember("pattern", &ExtractRegexSpanOptions::pattern));
static auto kJoinOptionsType = GetFunctionOptionsType<JoinOptions>(
DataMember("null_handling", &JoinOptions::null_handling),
DataMember("null_replacement", &JoinOptions::null_replacement));
@@ -438,6 +441,12 @@ ExtractRegexOptions::ExtractRegexOptions(std::string
pattern)
ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {}
constexpr char ExtractRegexOptions::kTypeName[];
+// Construct options for "extract_regex_span" from the given regex pattern.
+ExtractRegexSpanOptions::ExtractRegexSpanOptions(std::string pattern)
+ : FunctionOptions(internal::kExtractRegexSpanOptionsType),
+ pattern(std::move(pattern)) {}
+// Default construction delegates to the constructor above with an empty pattern.
+ExtractRegexSpanOptions::ExtractRegexSpanOptions() :
ExtractRegexSpanOptions("") {}
+// Out-of-class definition of the static constexpr member declared in the header.
+constexpr char ExtractRegexSpanOptions::kTypeName[];
+
JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string
null_replacement)
: FunctionOptions(internal::kJoinOptionsType),
null_handling(null_handling),
@@ -684,6 +693,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexSpanOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListSliceOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType));
diff --git a/cpp/src/arrow/compute/api_scalar.h
b/cpp/src/arrow/compute/api_scalar.h
index 0e5a388b10..492ea05f6d 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -265,6 +265,16 @@ class ARROW_EXPORT ExtractRegexOptions : public
FunctionOptions {
std::string pattern;
};
+/// \brief Options for the "extract_regex_span" function
+class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {
+ public:
+ /// \param pattern regular expression to match against each input value
+ explicit ExtractRegexSpanOptions(std::string pattern);
+ /// Default-construct with an empty pattern
+ ExtractRegexSpanOptions();
+ static constexpr char const kTypeName[] = "ExtractRegexSpanOptions";
+
+ /// Regular expression with named capture fields; a pattern containing
+ /// unnamed capture groups is rejected when the function is called
+ std::string pattern;
+};
+
/// Options for IsIn and IndexIn functions
class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
public:
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
index e58f7b065a..6f02432d3d 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
@@ -22,6 +22,7 @@
#include <string>
#include "arrow/array/builder_nested.h"
+#include "arrow/array/builder_primitive.h"
#include "arrow/compute/kernels/scalar_string_internal.h"
#include "arrow/result.h"
#include "arrow/util/config.h"
@@ -2184,20 +2185,12 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry*
registry) {
using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>;
-// TODO cache this once per ExtractRegexOptions
-struct ExtractRegexData {
- // Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE)
- std::unique_ptr<RE2> regex;
- std::vector<std::string> group_names;
-
- static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
- bool is_utf8 = true) {
- ExtractRegexData data(options.pattern, is_utf8);
- RETURN_NOT_OK(RegexStatus(*data.regex));
-
- const int group_count = data.regex->NumberOfCapturingGroups();
- const auto& name_map = data.regex->CapturingGroupNames();
- data.group_names.reserve(group_count);
+struct BaseExtractRegexData {
+ Status Init() {
+ RETURN_NOT_OK(RegexStatus(*regex));
+ const int group_count = regex->NumberOfCapturingGroups();
+ const auto& name_map = regex->CapturingGroupNames();
+ group_names.reserve(group_count);
for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
@@ -2205,8 +2198,27 @@ struct ExtractRegexData {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
- data.group_names.emplace_back(item->second);
+ group_names.emplace_back(item->second);
}
+ return Status::OK();
+ }
+
+ int64_t num_groups() const { return
static_cast<int64_t>(group_names.size()); }
+
+ std::unique_ptr<RE2> regex;
+ std::vector<std::string> group_names;
+
+ protected:
+ explicit BaseExtractRegexData(const std::string& pattern, bool is_utf8 =
true)
+ : regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
+};
+
+// TODO cache this once per ExtractRegexOptions
+struct ExtractRegexData : public BaseExtractRegexData {
+ static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
+ bool is_utf8 = true) {
+ ExtractRegexData data(options.pattern, is_utf8);
+ ARROW_RETURN_NOT_OK(data.Init());
return data;
}
@@ -2220,7 +2232,7 @@ struct ExtractRegexData {
// of each field in the output struct type.
DCHECK(is_base_binary_like(input_type->id()));
FieldVector fields;
- fields.reserve(group_names.size());
+ fields.reserve(num_groups());
std::shared_ptr<DataType> owned_type = input_type->GetSharedPtr();
std::transform(group_names.begin(), group_names.end(),
std::back_inserter(fields),
[&](const std::string& name) { return field(name,
owned_type); });
@@ -2229,7 +2241,7 @@ struct ExtractRegexData {
private:
explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true)
- : regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
+ : BaseExtractRegexData(pattern, is_utf8) {}
};
Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
@@ -2240,7 +2252,7 @@ Result<TypeHolder>
ResolveExtractRegexOutput(KernelContext* ctx,
}
struct ExtractRegexBase {
- const ExtractRegexData& data;
+ const BaseExtractRegexData& data;
const int group_count;
std::vector<re2::StringPiece> found_values;
std::vector<RE2::Arg> args;
@@ -2248,9 +2260,9 @@ struct ExtractRegexBase {
const RE2::Arg** args_pointers_start;
const RE2::Arg* null_arg = nullptr;
- explicit ExtractRegexBase(const ExtractRegexData& data)
+ explicit ExtractRegexBase(const BaseExtractRegexData& data)
: data(data),
- group_count(static_cast<int>(data.group_names.size())),
+ group_count(static_cast<int>(data.num_groups())),
found_values(group_count) {
args.reserve(group_count);
args_pointers.reserve(group_count);
@@ -2280,25 +2292,23 @@ struct ExtractRegex : public ExtractRegexBase {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
out) {
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options,
Type::is_utf8));
- return ExtractRegex{data}.Extract(ctx, batch, out);
+ return ExtractRegex(data).Extract(ctx, batch, out);
}
Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
- // TODO: why is this needed? Type resolution should already be
- // done and the output type set in the output variable
- ARROW_ASSIGN_OR_RAISE(TypeHolder out_type,
data.ResolveOutputType(batch.GetTypes()));
- DCHECK_NE(out_type.type, nullptr);
- std::shared_ptr<DataType> type = out_type.GetSharedPtr();
-
- std::unique_ptr<ArrayBuilder> array_builder;
- RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder));
+ DCHECK_NE(out->array_data(), nullptr);
+ std::shared_ptr<DataType> type = out->array_data()->type;
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ArrayBuilder> array_builder,
+ MakeBuilder(type, ctx->memory_pool()));
StructBuilder* struct_builder =
checked_cast<StructBuilder*>(array_builder.get());
+ ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].length()));
std::vector<BuilderType*> field_builders;
field_builders.reserve(group_count);
for (int i = 0; i < group_count; i++) {
field_builders.push_back(
checked_cast<BuilderType*>(struct_builder->field_builder(i)));
+ RETURN_NOT_OK(field_builders.back()->Reserve(batch[0].length()));
}
auto visit_null = [&]() { return struct_builder->AppendNull(); };
@@ -2347,6 +2357,142 @@ void AddAsciiStringExtractRegex(FunctionRegistry*
registry) {
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+
+// Compiled regular expression and capture group names for "extract_regex_span".
+struct ExtractRegexSpanData : public BaseExtractRegexData {
+  // Compile `pattern` and collect its capture group names.
+  // Any error reported by Init() is propagated to the caller.
+  static Result<ExtractRegexSpanData> Make(const std::string& pattern,
+                                           bool is_utf8 = true) {
+    ExtractRegexSpanData result(pattern, is_utf8);
+    ARROW_RETURN_NOT_OK(result.Init());
+    return result;
+  }
+
+  // Compute the output type: a struct with one field per capture group, each
+  // field being a fixed-size list of two integers.
+  // Returns nullptr when the input type is not available.
+  Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
+    const DataType* input_type = types[0].type;
+    if (input_type == nullptr) return nullptr;
+    DCHECK(is_base_binary_like(input_type->id()));
+
+    // int32 when is_binary_like() holds for the input type, int64 otherwise
+    const auto index_type = is_binary_like(input_type->id()) ? int32() : int64();
+    FieldVector fields;
+    fields.reserve(num_groups());
+    std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields),
+                   [&](const std::string& name) {
+                     // list size is 2 as every span contains position and length
+                     return field(name, fixed_size_list(index_type, 2));
+                   });
+    return struct_(std::move(fields));
+  }
+
+ private:
+  ExtractRegexSpanData(const std::string& pattern, const bool is_utf8)
+      : BaseExtractRegexData(pattern, is_utf8) {}
+};
+
+// Kernel implementation of "extract_regex_span" for one binary/string input type.
+// For every input value it emits a struct holding one (index, length) pair per
+// capture group, or null when the value is null or does not match.
+template <typename Type>
+struct ExtractRegexSpan : ExtractRegexBase {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+  using OffsetArrowType = typename CTypeTraits<offset_type>::ArrowType;
+  using OffsetBuilderType = typename TypeTraits<OffsetArrowType>::BuilderType;
+  using OffsetCType = typename TypeTraits<OffsetArrowType>::CType;
+
+  using ExtractRegexBase::ExtractRegexBase;
+
+  // Kernel entry point: compile the pattern from the options, then run the
+  // extraction over the batch.
+  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+    auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx);
+    ARROW_ASSIGN_OR_RAISE(auto regex_data,
+                          ExtractRegexSpanData::Make(options.pattern, Type::is_utf8));
+    ExtractRegexSpan extractor{regex_data};
+    return extractor.Extract(ctx, batch, out);
+  }
+
+  // Build the output struct array with builders and store it in `out`.
+  Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+    DCHECK_NE(out->array_data(), nullptr);
+    const int64_t length = batch[0].length();
+    std::shared_ptr<DataType> out_type = out->array_data()->type;
+    ARROW_ASSIGN_OR_RAISE(auto out_builder, MakeBuilder(out_type, ctx->memory_pool()));
+    auto* struct_builder = checked_cast<StructBuilder*>(out_builder.get());
+    ARROW_RETURN_NOT_OK(struct_builder->Reserve(length));
+
+    // One fixed-size-list builder per capture group, plus its integer child
+    // builder which receives the two values of each span.
+    std::vector<FixedSizeListBuilder*> list_builders;
+    std::vector<OffsetBuilderType*> value_builders;
+    list_builders.reserve(group_count);
+    value_builders.reserve(group_count);
+    for (int group = 0; group < group_count; ++group) {
+      auto* list_builder =
+          checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(group));
+      auto* value_builder =
+          checked_cast<OffsetBuilderType*>(list_builder->value_builder());
+      RETURN_NOT_OK(list_builder->Reserve(length));
+      // Two values per span; this reservation backs the UnsafeAppend calls below
+      RETURN_NOT_OK(value_builder->Reserve(2 * length));
+      list_builders.push_back(list_builder);
+      value_builders.push_back(value_builder);
+    }
+
+    auto on_null = [&]() { return struct_builder->AppendNull(); };
+    auto on_value = [&](std::string_view element) -> Status {
+      if (!Match(element)) {
+        return struct_builder->AppendNull();
+      }
+      for (int group = 0; group < group_count; ++group) {
+        const auto& capture = found_values[group];
+        // A null data pointer is treated as "no capture for this group", see
+        // https://github.com/google/re2/issues/24#issuecomment-97653183
+        if (capture.data() == nullptr) {
+          ARROW_RETURN_NOT_OK(list_builders[group]->AppendNull());
+          continue;
+        }
+        // Position of the capture relative to the start of the input value
+        const int64_t begin = capture.data() - element.data();
+        const int64_t size = capture.size();
+        value_builders[group]->UnsafeAppend(static_cast<OffsetCType>(begin));
+        value_builders[group]->UnsafeAppend(static_cast<OffsetCType>(size));
+        ARROW_RETURN_NOT_OK(list_builders[group]->Append());
+      }
+      return struct_builder->Append();
+    };
+    ARROW_RETURN_NOT_OK(
+        VisitArraySpanInline<Type>(batch[0].array, on_value, on_null));
+
+    ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish());
+    out->value = std::move(out_array->data());
+    return Status::OK();
+  }
+};
+
+// User-facing documentation for the "extract_regex_span" function.
+// The function takes one argument ("strings") and its options are mandatory.
+const FunctionDoc extract_regex_span_doc(
+ "Extract string spans captured by a regex pattern",
+ ("For each string in strings, match the regular expression and, if\n"
+ "successful, emit a struct with field names and values coming from the\n"
+ "regular expression's named capture groups. Each struct field value\n"
+ "will be a fixed_size_list(offset_type, 2) where offset_type is int32\n"
+ "or int64, depending on the input string type. The two elements in\n"
+ "each fixed-size list are the index and the length of the substring\n"
+ "matched by the corresponding named capture group.\n"
+ "\n"
+ "If the input is null or the regular expression fails matching,\n"
+ "a null output value is emitted.\n"
+ "\n"
+ "Regular expression matching is done using the Google RE2 library."),
+ {"strings"}, "ExtractRegexSpanOptions", /*options_required=*/true);
+
+// Output type resolver for "extract_regex_span".
+//
+// Compiles the pattern held in the kernel state and derives the output struct
+// type from its capture group names and the input type. Returns an error if
+// the pattern cannot be compiled or contains unnamed groups.
+Result<TypeHolder> ResolveExtractRegexSpanOutputType(
+    KernelContext* ctx, const std::vector<TypeHolder>& types) {
+  auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(*ctx->state());
+  // Compile the pattern with the same encoding mode the kernel uses
+  // (Type::is_utf8 in ExtractRegexSpan::Exec), so that a pattern accepted by
+  // the binary kernels is not rejected here for not being valid UTF-8.
+  // Keep the previous default (UTF-8) when the input type is not available.
+  const DataType* input_type = types[0].type;
+  const bool is_utf8 = input_type == nullptr ||
+                       input_type->id() == Type::STRING ||
+                       input_type->id() == Type::LARGE_STRING;
+  ARROW_ASSIGN_OR_RAISE(auto span,
+                        ExtractRegexSpanData::Make(options.pattern, is_utf8));
+  return span.ResolveOutputType(types);
+}
+
+// Register the "extract_regex_span" scalar function, with one kernel per
+// type returned by BaseBinaryTypes().
+void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) {
+  auto func = std::make_shared<ScalarFunction>("extract_regex_span", Arity::Unary(),
+                                               extract_regex_span_doc);
+  OutputType output_type(ResolveExtractRegexSpanOutputType);
+  for (const auto& type : BaseBinaryTypes()) {
+    ScalarKernel kernel({type}, output_type,
+                        GenerateVarBinaryToVarBinary<ExtractRegexSpan>(type),
+                        OptionsWrapper<ExtractRegexSpanOptions>::Init);
+    // The kernel assembles its whole output (including validity) with array
+    // builders, so the executor must not preallocate anything.
+    kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+    kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+    DCHECK_OK(func->AddKernel(std::move(kernel)));
+  }
+  // `func` is not used afterwards: hand it over instead of copying the
+  // shared_ptr, as AddAsciiStringExtractRegex does.
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
#endif // ARROW_WITH_RE2
// ----------------------------------------------------------------------
@@ -3457,6 +3603,7 @@ void RegisterScalarStringAscii(FunctionRegistry*
registry) {
AddAsciiStringSplitWhitespace(registry);
#ifdef ARROW_WITH_RE2
AddAsciiStringSplitRegex(registry);
+ AddAsciiStringExtractRegexSpan(registry);
#endif
AddAsciiStringJoin(registry);
AddAsciiStringRepeat(registry);
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index 38455dc146..672839f3cc 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -314,6 +314,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8Regex) {
this->MakeArray({"\xfc\x40", "this \xfc\x40 that
\xfc\x40"}),
this->MakeArray({"bazz", "this bazz that \xfc\x40"}),
&options);
}
+ // TODO the following test is broken (GH-45735)
{
ExtractRegexOptions options("(?P<letter>[\\xfc])(?P<digit>\\d)");
auto null_bitmap = std::make_shared<Buffer>("0");
@@ -370,6 +371,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8WithNullRegex) {
this->template MakeArray<std::string>({{"\x00\x40", 2}}),
this->type(), R"(["bazz"])", &options);
}
+ // TODO the following test is broken (GH-45735)
{
ExtractRegexOptions options("(?P<null>[\\x00])(?P<digit>\\d)");
auto null_bitmap = std::make_shared<Buffer>("0");
@@ -1959,6 +1961,62 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) {
&options);
}
+TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) {
+ // Both capture groups are mandatory in this pattern, so every input value
+ // yields either a span for each field or a null struct.
+ ExtractRegexSpanOptions options{"(?P<letter>[ab]+)(?P<digit>\\d+)"};
+ // Span element type: int32 when is_binary_like() holds for the input type,
+ // int64 otherwise.
+ auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() :
int64();
+ auto out_type = struct_({field("letter",
fixed_size_list(type_fixe_size_list, 2)),
+ field("digit", fixed_size_list(type_fixe_size_list,
2))});
+ // Empty input gives empty output.
+ this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options);
+ // Null input and values without a letters-then-digits match give null structs.
+ this->CheckUnary("extract_regex_span", R"([
null,"123ab","cd123ab","cd123abef"])",
+ out_type, R"([null,null,null,null])", &options);
+ // Each expected span is [start index, length] within the input value.
+ this->CheckUnary(
+ "extract_regex_span",
+ R"(["a1", "b2", "c3",
null,"123ab","abb12","abc13","cedbb15","cedaabb125efg"])",
+ out_type,
+ R"([{"letter":[0,1], "digit":[1,1]},
+ {"letter":[0,1], "digit":[1,1]},
+ null,
+ null,
+ null,
+ {"letter":[0,3], "digit":[3,2]},
+ null,
+ {"letter":[3,2], "digit":[5,2]},
+ {"letter":[3,4], "digit":[7,3]}])",
+ &options);
+ // The match may start in the middle of the value and be followed by other text.
+ this->CheckUnary("extract_regex_span", R"([
"a3","b2","cdaa123","cdab123ef"])",
+ out_type,
+ R"([{"letter":[0,1], "digit":[1,1]},
+ {"letter":[0,1], "digit":[1,1]},
+ {"letter":[2,2], "digit":[4,3]},
+ {"letter":[2,2], "digit":[4,3]}])",
+ &options);
+}
+
+TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanCaptureOption) {
+ // Optional capture groups: a group that takes no part in a successful match
+ // is reported as a null span inside a non-null struct.
+ ExtractRegexSpanOptions options{"(?P<foo>foo)?(?P<digit>\\d+)?"};
+ auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() :
int64();
+ auto out_type = struct_({field("foo", fixed_size_list(type_fixe_size_list,
2)),
+ field("digit", fixed_size_list(type_fixe_size_list,
2))});
+ this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options);
+ // With both groups optional, every non-null value produces a non-null struct,
+ // possibly with all fields null.
+ this->CheckUnary("extract_regex_span",
R"(["foo","foo123","abcfoo123","abc",null])",
+ out_type,
+ R"([{"foo":[0,3],"digit":null},
+ {"foo":[0,3],"digit":[3,3]},
+ {"foo":null,"digit":null},
+ {"foo":null,"digit":null},
+ null])",
+ &options);
+ // With "foo" mandatory, values lacking it yield a null struct instead.
+ options = ExtractRegexSpanOptions{"(?P<foo>foo)(?P<digit>\\d+)?"};
+ this->CheckUnary("extract_regex_span",
R"(["foo123","foo","123","abc","abcfoo"])",
+ out_type,
+ R"([{"foo":[0,3],"digit":[3,3]},
+ {"foo":[0,3],"digit":null},
+ null,
+ null,
+ {"foo":[3,3],"digit":null}])",
+ &options);
+}
+
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) {
// XXX Should we accept this or is it a user error?
ExtractRegexOptions options{"foo"};
@@ -1967,11 +2025,24 @@ TYPED_TEST(TestBaseBinaryKernels,
ExtractRegexNoCapture) {
R"([{}, null, null])", &options);
}
+TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoCapture) {
+ // A pattern without capture groups gives an empty struct for matching values
+ // and null for non-matching or null values.
+ // XXX Should we accept this or is it a user error?
+ ExtractRegexSpanOptions options{"foo"};
+ auto type = struct_({});
+ this->CheckUnary("extract_regex_span", R"(["oofoo", "bar", null])", type,
+ R"([{}, null, null])", &options);
+}
+
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoOptions) {
Datum input = ArrayFromJSON(this->type(), "[]");
ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input}));
}
+TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoOptions) {
+ // Calling the function without options must raise Invalid.
+ Datum input = ArrayFromJSON(this->type(), "[]");
+ ASSERT_RAISES(Invalid, CallFunction("extract_regex_span", {input}));
+}
+
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) {
Datum input = ArrayFromJSON(this->type(), "[]");
ExtractRegexOptions options{"invalid["};
@@ -1985,6 +2056,18 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) {
CallFunction("extract_regex", {input}, &options));
}
+TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanInvalid) {
+ // Both a malformed pattern and a pattern with an unnamed capture group are
+ // rejected with Invalid, even for empty input.
+ Datum input = ArrayFromJSON(this->type(), "[]");
+ ExtractRegexSpanOptions options{"invalid["};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"),
+ CallFunction("extract_regex_span", {input}, &options));
+ options = ExtractRegexSpanOptions{"(.)"};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Regular expression contains unnamed
groups"),
+ CallFunction("extract_regex_span", {input}, &options));
+}
+
#endif
TYPED_TEST(TestStringKernels, Strptime) {
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 8825ffebf2..57673dfe1f 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -1128,17 +1128,26 @@ when a positive ``max_splits`` is given.
String component extraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~
-+---------------+-------+------------------------+-------------+-------------------------------+-------+
-| Function name | Arity | Input types            | Output type | Options class                 | Notes |
-+===============+=======+========================+=============+===============================+=======+
-| extract_regex | Unary | Binary- or String-like | Struct      | :struct:`ExtractRegexOptions` | \(1)  |
-+---------------+-------+------------------------+-------------+-------------------------------+-------+
++--------------------+-------+------------------------+-------------+-----------------------------------+-------+
+| Function name      | Arity | Input types            | Output type | Options class                     | Notes |
++====================+=======+========================+=============+===================================+=======+
+| extract_regex      | Unary | Binary- or String-like | Struct      | :struct:`ExtractRegexOptions`     | \(1)  |
++--------------------+-------+------------------------+-------------+-----------------------------------+-------+
+| extract_regex_span | Unary | Binary- or String-like | Struct      | :struct:`ExtractRegexSpanOptions` | \(2)  |
++--------------------+-------+------------------------+-------------+-----------------------------------+-------+
* \(1) Extract substrings defined by a regular expression using the Google RE2
library. The output struct field names refer to the named capture groups,
e.g. 'letter' and 'digit' for the regular expression
``(?P<letter>[ab])(?P<digit>\\d)``.
+* \(2) Extract the offset and length of substrings defined by a regular expression
+  using the Google RE2 library. The output struct field names refer to the named
+  capture groups, e.g. 'letter' and 'digit' for the regular expression
+  ``(?P<letter>[ab])(?P<digit>\\d)``. Each output struct field is a fixed-size list
+  of two integers: the index of the start of the captured substring and the length
+  of the captured substring, respectively.
+
String joining
~~~~~~~~~~~~~~
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 63370c938b..db6cf5b45d 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1222,6 +1222,25 @@ class ExtractRegexOptions(_ExtractRegexOptions):
self._set_options(pattern)
+# Cython-level base class holding the wrapped C++ ExtractRegexSpanOptions.
+cdef class _ExtractRegexSpanOptions(FunctionOptions):
+ def _set_options(self, pattern):
+ # The pattern goes through tobytes() before reaching the C++ constructor.
+ self.wrapped.reset(new CExtractRegexSpanOptions(tobytes(pattern)))
+
+
+class ExtractRegexSpanOptions(_ExtractRegexSpanOptions):
+ """
+ Options for the `extract_regex_span` function.
+
+ Parameters
+ ----------
+ pattern : str
+ Regular expression with named capture fields.
+ A pattern containing unnamed capture groups is rejected
+ when the function is called.
+ """
+
+ def __init__(self, pattern):
+ self._set_options(pattern)
+
+
cdef class _SliceOptions(FunctionOptions):
def _set_options(self, start, stop, step):
self.wrapped.reset(new CSliceOptions(start, stop, step))
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 8040cf9ff0..1809c74afc 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -40,6 +40,7 @@ from pyarrow._compute import ( # noqa
RunEndEncodeOptions,
ElementWiseAggregateOptions,
ExtractRegexOptions,
+ ExtractRegexSpanOptions,
FilterOptions,
IndexOptions,
JoinOptions,
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index c3ddaba88f..f9fa091171 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2500,6 +2500,11 @@ cdef extern from "arrow/compute/api.h" namespace
"arrow::compute" nogil:
CExtractRegexOptions(c_string pattern)
c_string pattern
+ # Cython declaration of the C++ class arrow::compute::ExtractRegexSpanOptions.
+ cdef cppclass CExtractRegexSpanOptions \
+ "arrow::compute::ExtractRegexSpanOptions"(CFunctionOptions):
+ CExtractRegexSpanOptions(c_string pattern)
+ c_string pattern
+
cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions):
CCastOptions()
CCastOptions(c_bool safe)
diff --git a/python/pyarrow/tests/test_compute.py
b/python/pyarrow/tests/test_compute.py
index 8a756a262b..73506fedfc 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -152,6 +152,7 @@ def test_option_class_equality(request):
pc.RunEndEncodeOptions(),
pc.ElementWiseAggregateOptions(skip_nulls=True),
pc.ExtractRegexOptions("pattern"),
+ pc.ExtractRegexSpanOptions("pattern"),
pc.FilterOptions(),
pc.IndexOptions(pa.scalar(1)),
pc.JoinOptions(),
@@ -1092,6 +1093,16 @@ def test_extract_regex():
assert struct.tolist() == expected
+def test_extract_regex_span():
+ # Each expected field value is [start index, length] of the captured text.
+ ar = pa.array(['a1', 'zb234z'])
+ expected = [{'letter': [0, 1], 'digit': [1, 1]},
+ {'letter': [1, 1], 'digit': [2, 3]}]
+ # The pattern can be passed by keyword ...
+ struct = pc.extract_regex_span(ar,
pattern=r'(?P<letter>[ab])(?P<digit>\d+)')
+ assert struct.tolist() == expected
+ # ... or positionally, with the same result.
+ struct = pc.extract_regex_span(ar, r'(?P<letter>[ab])(?P<digit>\d+)')
+ assert struct.tolist() == expected
+
+
def test_binary_join():
ar_list = pa.array([['foo', 'bar'], None, []])
expected = pa.array(['foo-bar', None, ''])