This is an automated email from the ASF dual-hosted git repository.
maplefu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 4f91c8f144 GH-43759: [C++] Acero: Minor code enhancement for Join
(#43760)
4f91c8f144 is described below
commit 4f91c8f144125bd147c25cb49ac0071c8d28765c
Author: mwish <[email protected]>
AuthorDate: Thu Aug 29 23:38:41 2024 +0800
GH-43759: [C++] Acero: Minor code enhancement for Join (#43760)
### Rationale for this change
Minor code and style enhancements for the Acero hash join implementation and the `ResizableArrayData` helper it relies on.
### What changes are included in this PR?
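Minor code and style enhancements for the hash join:
- `ResizableArrayData::Init` now returns `Status` instead of `void`, and every call site checks the result (`RETURN_NOT_OK` / `ASSERT_OK`).
- `ResizableArrayData` now caches its `KeyColumnMetadata` in a new `column_metadata_` member. It no longer recomputes the metadata with `ColumnMetadataFromDataType(...).ValueOrDie()` on every resize and accessor call.
- Hash join code avoids unnecessary copies by using `const auto&` loop variables and `std::move`. It also uses `empty()` instead of `size() < 1`, and marks the `PayloadIsEmpty` helpers `const`.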
### Are these changes tested?
Covered by existing tests.
### Are there any user-facing changes?
No.
* GitHub Issue: #43759
Authored-by: mwish <[email protected]>
Signed-off-by: mwish <[email protected]>
---
cpp/src/arrow/acero/hash_join_dict.cc | 9 ++--
cpp/src/arrow/acero/hash_join_node.cc | 16 +++----
cpp/src/arrow/acero/hash_join_node.h | 6 +--
cpp/src/arrow/acero/swiss_join.cc | 7 +--
cpp/src/arrow/compute/light_array_internal.cc | 68 +++++++++++++--------------
cpp/src/arrow/compute/light_array_internal.h | 6 ++-
cpp/src/arrow/compute/light_array_test.cc | 4 +-
7 files changed, 57 insertions(+), 59 deletions(-)
diff --git a/cpp/src/arrow/acero/hash_join_dict.cc
b/cpp/src/arrow/acero/hash_join_dict.cc
index 3aef08e6e9..8db9dddb2c 100644
--- a/cpp/src/arrow/acero/hash_join_dict.cc
+++ b/cpp/src/arrow/acero/hash_join_dict.cc
@@ -225,21 +225,20 @@ Status HashJoinDictBuild::Init(ExecContext* ctx,
std::shared_ptr<Array> dictiona
return Status::OK();
}
- dictionary_ = dictionary;
+ dictionary_ = std::move(dictionary);
// Initialize encoder
RowEncoder encoder;
- std::vector<TypeHolder> encoder_types;
- encoder_types.emplace_back(value_type_);
+ std::vector<TypeHolder> encoder_types{value_type_};
encoder.Init(encoder_types, ctx);
// Encode all dictionary values
- int64_t length = dictionary->data()->length;
+ int64_t length = dictionary_->data()->length;
if (length >= std::numeric_limits<int32_t>::max()) {
return Status::Invalid(
"Dictionary length in hash join must fit into signed 32-bit integer.");
}
- RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary->data()},
length)));
+ RETURN_NOT_OK(encoder.EncodeAndAppend(ExecSpan({*dictionary_->data()},
length)));
std::vector<int32_t> entries_to_take;
diff --git a/cpp/src/arrow/acero/hash_join_node.cc
b/cpp/src/arrow/acero/hash_join_node.cc
index 67f902e64b..80dd163ced 100644
--- a/cpp/src/arrow/acero/hash_join_node.cc
+++ b/cpp/src/arrow/acero/hash_join_node.cc
@@ -61,30 +61,30 @@ Result<std::vector<FieldRef>>
HashJoinSchema::ComputePayload(
const std::vector<FieldRef>& filter, const std::vector<FieldRef>& keys) {
// payload = (output + filter) - keys, with no duplicates
std::unordered_set<int> payload_fields;
- for (auto ref : output) {
+ for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}
- for (auto ref : filter) {
+ for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.insert(match[0]);
}
- for (auto ref : keys) {
+ for (const auto& ref : keys) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
payload_fields.erase(match[0]);
}
std::vector<FieldRef> payload_refs;
- for (auto ref : output) {
+ for (const auto& ref : output) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
payload_fields.erase(match[0]);
}
}
- for (auto ref : filter) {
+ for (const auto& ref : filter) {
ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
if (payload_fields.find(match[0]) != payload_fields.end()) {
payload_refs.push_back(ref);
@@ -198,7 +198,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type,
const Schema& left_sc
return Status::Invalid("Different number of key fields on left (",
left_keys.size(),
") and right (", right_keys.size(), ") side of the
join");
}
- if (left_keys.size() < 1) {
+ if (left_keys.empty()) {
return Status::Invalid("Join key cannot be empty");
}
for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) {
@@ -432,7 +432,7 @@ Status
HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
indices[0] -= left_schema.num_fields();
FieldPath corrected_path(std::move(indices));
if (right_seen_paths.find(*path) == right_seen_paths.end()) {
- right_filter.push_back(corrected_path);
+ right_filter.emplace_back(corrected_path);
right_seen_paths.emplace(std::move(corrected_path));
}
} else if (left_seen_paths.find(*path) == left_seen_paths.end()) {
@@ -698,7 +698,7 @@ class HashJoinNode : public ExecNode, public TracedNode {
std::shared_ptr<Schema> output_schema,
std::unique_ptr<HashJoinSchema> schema_mgr, Expression filter,
std::unique_ptr<HashJoinImpl> impl)
- : ExecNode(plan, inputs, {"left", "right"},
+ : ExecNode(plan, std::move(inputs), {"left", "right"},
/*output_schema=*/std::move(output_schema)),
TracedNode(this),
join_type_(join_options.join_type),
diff --git a/cpp/src/arrow/acero/hash_join_node.h
b/cpp/src/arrow/acero/hash_join_node.h
index ad60019cea..19745b8675 100644
--- a/cpp/src/arrow/acero/hash_join_node.h
+++ b/cpp/src/arrow/acero/hash_join_node.h
@@ -65,9 +65,9 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
std::shared_ptr<Schema> MakeOutputSchema(const std::string&
left_field_name_suffix,
const std::string&
right_field_name_suffix);
- bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }
+ bool LeftPayloadIsEmpty() const { return PayloadIsEmpty(0); }
- bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }
+ bool RightPayloadIsEmpty() const { return PayloadIsEmpty(1); }
static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
@@ -88,7 +88,7 @@ class ARROW_ACERO_EXPORT HashJoinSchema {
const SchemaProjectionMap&
right_to_filter,
const Expression& filter);
- bool PayloadIsEmpty(int side) {
+ bool PayloadIsEmpty(int side) const {
assert(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}
diff --git a/cpp/src/arrow/acero/swiss_join.cc
b/cpp/src/arrow/acero/swiss_join.cc
index 4d0c8187ac..6c783110af 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -1667,7 +1667,7 @@ Result<std::shared_ptr<ArrayData>>
JoinResultMaterialize::FlushBuildColumn(
const std::shared_ptr<DataType>& data_type, const RowArray* row_array, int
column_id,
uint32_t* row_ids) {
ResizableArrayData output;
- output.Init(data_type, pool_, bit_util::Log2(num_rows_));
+ RETURN_NOT_OK(output.Init(data_type, pool_, bit_util::Log2(num_rows_)));
for (size_t i = 0; i <= null_ranges_.size(); ++i) {
int row_id_begin =
@@ -2247,8 +2247,9 @@ Result<ExecBatch>
JoinResidualFilter::MaterializeFilterInput(
build_schemas_->map(HashJoinProjection::FILTER,
HashJoinProjection::PAYLOAD);
for (int i = 0; i < num_build_cols; ++i) {
ResizableArrayData column_data;
- column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER,
i), pool_,
- bit_util::Log2(num_batch_rows));
+ RETURN_NOT_OK(
+
column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i),
+ pool_, bit_util::Log2(num_batch_rows)));
if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField)
{
RETURN_NOT_OK(build_keys_->DecodeSelected(&column_data, idx,
num_batch_rows,
key_ids_maybe_null, pool_));
diff --git a/cpp/src/arrow/compute/light_array_internal.cc
b/cpp/src/arrow/compute/light_array_internal.cc
index 4f235925d0..e4b1f1b8cd 100644
--- a/cpp/src/arrow/compute/light_array_internal.cc
+++ b/cpp/src/arrow/compute/light_array_internal.cc
@@ -118,10 +118,9 @@ Result<KeyColumnMetadata> ColumnMetadataFromDataType(
const std::shared_ptr<DataType>& type) {
const bool is_extension = type->id() == Type::EXTENSION;
const std::shared_ptr<DataType>& typ =
- is_extension
- ?
arrow::internal::checked_pointer_cast<ExtensionType>(type->GetSharedPtr())
- ->storage_type()
- : type;
+ is_extension ? arrow::internal::checked_cast<const
ExtensionType*>(type.get())
+ ->storage_type()
+ : type;
if (typ->id() == Type::DICTIONARY) {
auto bit_width =
@@ -205,22 +204,25 @@ Status ColumnArraysFromExecBatch(const ExecBatch& batch,
column_arrays);
}
-void ResizableArrayData::Init(const std::shared_ptr<DataType>& data_type,
- MemoryPool* pool, int log_num_rows_min) {
+Status ResizableArrayData::Init(const std::shared_ptr<DataType>& data_type,
+ MemoryPool* pool, int log_num_rows_min) {
#ifndef NDEBUG
if (num_rows_allocated_ > 0) {
- ARROW_DCHECK(data_type_ != NULLPTR);
- KeyColumnMetadata metadata_before =
- ColumnMetadataFromDataType(data_type_).ValueOrDie();
- KeyColumnMetadata metadata_after =
ColumnMetadataFromDataType(data_type).ValueOrDie();
+ ARROW_DCHECK(data_type_ != nullptr);
+ const KeyColumnMetadata& metadata_before = column_metadata_;
+ ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata_after,
+ ColumnMetadataFromDataType(data_type));
ARROW_DCHECK(metadata_before.is_fixed_length ==
metadata_after.is_fixed_length &&
metadata_before.fixed_length == metadata_after.fixed_length);
}
#endif
+ ARROW_DCHECK(data_type != nullptr);
+ ARROW_ASSIGN_OR_RAISE(column_metadata_,
ColumnMetadataFromDataType(data_type));
Clear(/*release_buffers=*/false);
log_num_rows_min_ = log_num_rows_min;
data_type_ = data_type;
pool_ = pool;
+ return Status::OK();
}
void ResizableArrayData::Clear(bool release_buffers) {
@@ -246,8 +248,6 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int
num_rows_new) {
num_rows_allocated_new *= 2;
}
- KeyColumnMetadata column_metadata =
ColumnMetadataFromDataType(data_type_).ValueOrDie();
-
if (buffers_[kFixedLengthBuffer] == NULLPTR) {
ARROW_DCHECK(buffers_[kValidityBuffer] == NULLPTR &&
buffers_[kVariableLengthBuffer] == NULLPTR);
@@ -258,8 +258,8 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int
num_rows_new) {
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes,
pool_));
memset(mutable_data(kValidityBuffer), 0,
bit_util::BytesForBits(num_rows_allocated_new) + kNumPaddingBytes);
- if (column_metadata.is_fixed_length) {
- if (column_metadata.fixed_length == 0) {
+ if (column_metadata_.is_fixed_length) {
+ if (column_metadata_.fixed_length == 0) {
ARROW_ASSIGN_OR_RAISE(
buffers_[kFixedLengthBuffer],
AllocateResizableBuffer(
@@ -271,7 +271,7 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int
num_rows_new) {
ARROW_ASSIGN_OR_RAISE(
buffers_[kFixedLengthBuffer],
AllocateResizableBuffer(
- num_rows_allocated_new * column_metadata.fixed_length +
kNumPaddingBytes,
+ num_rows_allocated_new * column_metadata_.fixed_length +
kNumPaddingBytes,
pool_));
}
} else {
@@ -300,15 +300,15 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int
num_rows_new) {
memset(mutable_data(kValidityBuffer) + bytes_for_bits_before, 0,
bytes_for_bits_after - bytes_for_bits_before);
- if (column_metadata.is_fixed_length) {
- if (column_metadata.fixed_length == 0) {
+ if (column_metadata_.is_fixed_length) {
+ if (column_metadata_.fixed_length == 0) {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
bit_util::BytesForBits(num_rows_allocated_new) +
kNumPaddingBytes));
memset(mutable_data(kFixedLengthBuffer) + bytes_for_bits_before, 0,
bytes_for_bits_after - bytes_for_bits_before);
} else {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
- num_rows_allocated_new * column_metadata.fixed_length +
kNumPaddingBytes));
+ num_rows_allocated_new * column_metadata_.fixed_length +
kNumPaddingBytes));
}
} else {
RETURN_NOT_OK(buffers_[kFixedLengthBuffer]->Resize(
@@ -323,10 +323,7 @@ Status ResizableArrayData::ResizeFixedLengthBuffers(int
num_rows_new) {
}
Status ResizableArrayData::ResizeVaryingLengthBuffer() {
- KeyColumnMetadata column_metadata;
- column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
-
- if (!column_metadata.is_fixed_length) {
+ if (!column_metadata_.is_fixed_length) {
int64_t min_new_size =
buffers_[kFixedLengthBuffer]->data_as<int32_t>()[num_rows_];
ARROW_DCHECK(var_len_buf_size_ > 0);
if (var_len_buf_size_ < min_new_size) {
@@ -343,23 +340,19 @@ Status ResizableArrayData::ResizeVaryingLengthBuffer() {
}
KeyColumnArray ResizableArrayData::column_array() const {
- KeyColumnMetadata column_metadata;
- column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
- return KeyColumnArray(column_metadata, num_rows_,
+ return KeyColumnArray(column_metadata_, num_rows_,
buffers_[kValidityBuffer]->mutable_data(),
buffers_[kFixedLengthBuffer]->mutable_data(),
buffers_[kVariableLengthBuffer]->mutable_data());
}
std::shared_ptr<ArrayData> ResizableArrayData::array_data() const {
- KeyColumnMetadata column_metadata;
- column_metadata = ColumnMetadataFromDataType(data_type_).ValueOrDie();
-
- auto valid_count = arrow::internal::CountSetBits(
- buffers_[kValidityBuffer]->data(), /*offset=*/0,
static_cast<int64_t>(num_rows_));
+ auto valid_count =
+ arrow::internal::CountSetBits(buffers_[kValidityBuffer]->data(),
/*bit_offset=*/0,
+ static_cast<int64_t>(num_rows_));
int null_count = static_cast<int>(num_rows_) - static_cast<int>(valid_count);
- if (column_metadata.is_fixed_length) {
+ if (column_metadata_.is_fixed_length) {
return ArrayData::Make(data_type_, num_rows_,
{buffers_[kValidityBuffer],
buffers_[kFixedLengthBuffer]},
null_count);
@@ -493,10 +486,12 @@ Status ExecBatchBuilder::AppendSelected(const
std::shared_ptr<ArrayData>& source
ARROW_DCHECK(num_rows_before >= 0);
int num_rows_after = num_rows_before + num_rows_to_append;
if (target->num_rows() == 0) {
- target->Init(source->type, pool, kLogNumRows);
+ RETURN_NOT_OK(target->Init(source->type, pool, kLogNumRows));
}
RETURN_NOT_OK(target->ResizeFixedLengthBuffers(num_rows_after));
+ // Since target->Init is called before, we can assume that the ColumnMetadata
+ // would never fail to be created
KeyColumnMetadata column_metadata =
ColumnMetadataFromDataType(source->type).ValueOrDie();
@@ -647,11 +642,12 @@ Status ExecBatchBuilder::AppendNulls(const
std::shared_ptr<DataType>& type,
int num_rows_before = target.num_rows();
int num_rows_after = num_rows_before + num_rows_to_append;
if (target.num_rows() == 0) {
- target.Init(type, pool, kLogNumRows);
+ RETURN_NOT_OK(target.Init(type, pool, kLogNumRows));
}
RETURN_NOT_OK(target.ResizeFixedLengthBuffers(num_rows_after));
- KeyColumnMetadata column_metadata =
ColumnMetadataFromDataType(type).ValueOrDie();
+ ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata column_metadata,
+ ColumnMetadataFromDataType(type));
// Process fixed length buffer
//
@@ -708,7 +704,7 @@ Status ExecBatchBuilder::AppendSelected(MemoryPool* pool,
const ExecBatch& batch
const Datum& data = batch.values[col_ids ? col_ids[i] : i];
ARROW_DCHECK(data.is_array());
const std::shared_ptr<ArrayData>& array_data = data.array();
- values_[i].Init(array_data->type, pool, kLogNumRows);
+ RETURN_NOT_OK(values_[i].Init(array_data->type, pool, kLogNumRows));
}
}
@@ -739,7 +735,7 @@ Status ExecBatchBuilder::AppendNulls(MemoryPool* pool,
if (values_.empty()) {
values_.resize(types.size());
for (size_t i = 0; i < types.size(); ++i) {
- values_[i].Init(types[i], pool, kLogNumRows);
+ RETURN_NOT_OK(values_[i].Init(types[i], pool, kLogNumRows));
}
}
diff --git a/cpp/src/arrow/compute/light_array_internal.h
b/cpp/src/arrow/compute/light_array_internal.h
index 995c421199..b8e48f096b 100644
--- a/cpp/src/arrow/compute/light_array_internal.h
+++ b/cpp/src/arrow/compute/light_array_internal.h
@@ -295,8 +295,8 @@ class ARROW_EXPORT ResizableArrayData {
/// \param pool The pool to make allocations on
/// \param log_num_rows_min All resize operations will allocate at least
enough
/// space for (1 << log_num_rows_min) rows
- void Init(const std::shared_ptr<DataType>& data_type, MemoryPool* pool,
- int log_num_rows_min);
+ Status Init(const std::shared_ptr<DataType>& data_type, MemoryPool* pool,
+ int log_num_rows_min);
/// \brief Resets the array back to an empty state
/// \param release_buffers If true then allocated memory is released and the
@@ -351,6 +351,8 @@ class ARROW_EXPORT ResizableArrayData {
static constexpr int64_t kNumPaddingBytes = 64;
int log_num_rows_min_;
std::shared_ptr<DataType> data_type_;
+ // Would be valid if data_type_ != NULLPTR.
+ KeyColumnMetadata column_metadata_{};
MemoryPool* pool_;
int num_rows_;
int num_rows_allocated_;
diff --git a/cpp/src/arrow/compute/light_array_test.cc
b/cpp/src/arrow/compute/light_array_test.cc
index cc02d489d1..98a1ab8b7a 100644
--- a/cpp/src/arrow/compute/light_array_test.cc
+++ b/cpp/src/arrow/compute/light_array_test.cc
@@ -295,7 +295,7 @@ TEST(ResizableArrayData, Basic) {
arrow::internal::checked_pointer_cast<FixedWidthType>(type)->bit_width() / 8;
{
ResizableArrayData array;
- array.Init(type, pool.get(), /*log_num_rows_min=*/16);
+ ASSERT_OK(array.Init(type, pool.get(), /*log_num_rows_min=*/16));
ASSERT_EQ(0, array.num_rows());
ASSERT_OK(array.ResizeFixedLengthBuffers(2));
ASSERT_EQ(2, array.num_rows());
@@ -330,7 +330,7 @@ TEST(ResizableArrayData, Binary) {
ARROW_SCOPED_TRACE("Type: ", type->ToString());
{
ResizableArrayData array;
- array.Init(type, pool.get(), /*log_num_rows_min=*/4);
+ ASSERT_OK(array.Init(type, pool.get(), /*log_num_rows_min=*/4));
ASSERT_EQ(0, array.num_rows());
ASSERT_OK(array.ResizeFixedLengthBuffers(2));
ASSERT_EQ(2, array.num_rows());