This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-1.2-lts by this push:
new ed4d5bf951f [fix](hash join) fix stack overflow caused by evaluating a
case expr on a huge build block (#28871)
ed4d5bf951f is described below
commit ed4d5bf951f8c827fb0e778010775a2a20bce7c1
Author: TengJianPing <[email protected]>
AuthorDate: Fri Dec 22 17:13:38 2023 +0800
[fix](hash join) fix stack overflow caused by evaluating a case expr on a huge
build block (#28871)
---
be/src/vec/columns/column_vector.cpp | 3 ++-
be/src/vec/exec/join/vhash_join_node.cpp | 11 ++++++-----
be/src/vec/exec/join/vhash_join_node.h | 1 +
be/src/vec/functions/function_case.h | 20 ++++++++++----------
be/src/vec/functions/function_string.cpp | 28 +++++++++++++++++++++++++---
be/src/vec/functions/multiply.cpp | 3 ++-
6 files changed, 46 insertions(+), 20 deletions(-)
diff --git a/be/src/vec/columns/column_vector.cpp
b/be/src/vec/columns/column_vector.cpp
index e656b97f2e7..93b59bce7ad 100644
--- a/be/src/vec/columns/column_vector.cpp
+++ b/be/src/vec/columns/column_vector.cpp
@@ -479,7 +479,8 @@ ColumnPtr ColumnVector<T>::replicate(const
IColumn::Offsets& offsets) const {
res_data.reserve(offsets.back());
// vectorized this code to speed up
- IColumn::Offset counts[size];
+ auto counts_uptr = std::unique_ptr<IColumn::Offset[]>(new
IColumn::Offset[size]);
+ IColumn::Offset* counts = counts_uptr.get();
for (ssize_t i = 0; i < size; ++i) {
counts[i] = offsets[i] - offsets[i - 1];
}
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp
b/be/src/vec/exec/join/vhash_join_node.cpp
index 3efea5d8c6a..b3c7794965e 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -656,7 +656,7 @@ Status HashJoinNode::_materialize_build_side(RuntimeState*
state) {
RETURN_IF_ERROR(child(1)->open(state));
SCOPED_TIMER(_build_timer);
- MutableBlock mutable_block(child(1)->row_desc().tuple_descriptors());
+ MutableBlock mutable_block;
uint8_t index = 0;
int64_t last_mem_used = 0;
@@ -669,6 +669,7 @@ Status HashJoinNode::_materialize_build_side(RuntimeState*
state) {
Block block;
// If eos or have already met a null value using short-circuit
strategy, we do not need to pull
// data from data.
+ _build_col_ids.resize(_build_expr_ctxs.size());
while (!eos && !_short_circuit_for_null_in_probe_side) {
block.clear_column_data();
RETURN_IF_CANCELLED(state);
@@ -679,6 +680,8 @@ Status HashJoinNode::_materialize_build_side(RuntimeState*
state) {
_mem_used += block.allocated_bytes();
if (block.rows() != 0) {
+ RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs,
*_build_expr_call_timer,
+ _build_col_ids));
SCOPED_TIMER(_build_side_merge_block_timer);
RETURN_IF_CATCH_BAD_ALLOC(mutable_block.merge(block));
}
@@ -889,8 +892,6 @@ Status HashJoinNode::_process_build_block(RuntimeState*
state, Block& block, uin
ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size());
ColumnUInt8::MutablePtr null_map_val;
- std::vector<int> res_col_ids(_build_expr_ctxs.size());
- RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs,
*_build_expr_call_timer, res_col_ids));
if (_join_op == TJoinOp::LEFT_OUTER_JOIN || _join_op ==
TJoinOp::FULL_OUTER_JOIN) {
_convert_block_to_null(block);
}
@@ -898,7 +899,7 @@ Status HashJoinNode::_process_build_block(RuntimeState*
state, Block& block, uin
// so we have to initialize this flag by the first build block.
if (!_has_set_need_null_map_for_build) {
_has_set_need_null_map_for_build = true;
- _set_build_ignore_flag(block, res_col_ids);
+ _set_build_ignore_flag(block, _build_col_ids);
}
if (_short_circuit_for_null_in_build_side || _build_side_ignore_null) {
null_map_val = ColumnUInt8::create();
@@ -906,7 +907,7 @@ Status HashJoinNode::_process_build_block(RuntimeState*
state, Block& block, uin
}
// Get the key column that needs to be built
- Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs,
res_col_ids);
+ Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs,
_build_col_ids);
st = std::visit(
Overload {
diff --git a/be/src/vec/exec/join/vhash_join_node.h
b/be/src/vec/exec/join/vhash_join_node.h
index 97af3818b8d..376ed8f377c 100644
--- a/be/src/vec/exec/join/vhash_join_node.h
+++ b/be/src/vec/exec/join/vhash_join_node.h
@@ -341,6 +341,7 @@ private:
std::unordered_map<const Block*, std::vector<int>> _inserted_rows;
std::vector<IRuntimeFilter*> _runtime_filters;
+ std::vector<int> _build_col_ids;
};
} // namespace vectorized
} // namespace doris
diff --git a/be/src/vec/functions/function_case.h
b/be/src/vec/functions/function_case.h
index 6b628339d6c..487e1c99503 100644
--- a/be/src/vec/functions/function_case.h
+++ b/be/src/vec/functions/function_case.h
@@ -139,9 +139,9 @@ public:
int rows_count = column_holder.rows_count;
// `then` data index corresponding to each row of results, 0
represents `else`.
- int then_idx[rows_count];
- int* __restrict then_idx_ptr = then_idx;
- memset(then_idx_ptr, 0, sizeof(then_idx));
+ auto then_idx_uptr = std::unique_ptr<int[]>(new int[rows_count]);
+ int* __restrict then_idx_ptr = then_idx_uptr.get();
+ memset(then_idx_ptr, 0, rows_count * sizeof(int));
for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) {
for (int i = 1; i < column_holder.pair_count; i++) {
@@ -169,7 +169,7 @@ public:
}
auto result_column_ptr = data_type->create_column();
- update_result_normal(result_column_ptr, then_idx, column_holder);
+ update_result_normal(result_column_ptr, then_idx_ptr, column_holder);
block.replace_by_position(result, std::move(result_column_ptr));
return Status::OK();
}
@@ -185,9 +185,9 @@ public:
int rows_count = column_holder.rows_count;
// `then` data index corresponding to each row of results, 0
represents `else`.
- uint8_t then_idx[rows_count];
- uint8_t* __restrict then_idx_ptr = then_idx;
- memset(then_idx_ptr, 0, sizeof(then_idx));
+ auto then_idx_uptr = std::unique_ptr<uint8_t[]>(new
uint8_t[rows_count]);
+ uint8_t* __restrict then_idx_ptr = then_idx_uptr.get();
+ memset(then_idx_ptr, 0, rows_count);
auto case_column_ptr = column_holder.when_ptrs[0].value_or(nullptr);
@@ -224,13 +224,13 @@ public:
}
}
- return execute_update_result<ColumnType, then_null>(data_type, result,
block, then_idx,
+ return execute_update_result<ColumnType, then_null>(data_type, result,
block, then_idx_ptr,
column_holder);
}
template <typename ColumnType, bool then_null>
Status execute_update_result(const DataTypePtr& data_type, size_t result,
Block& block,
- uint8* then_idx, CaseWhenColumnHolder&
column_holder) {
+ const uint8* then_idx, CaseWhenColumnHolder&
column_holder) {
auto result_column_ptr = data_type->create_column();
if constexpr (std::is_same_v<ColumnType, ColumnString> ||
@@ -253,7 +253,7 @@ public:
}
template <typename IndexType>
- void update_result_normal(MutableColumnPtr& result_column_ptr, IndexType*
then_idx,
+ void update_result_normal(MutableColumnPtr& result_column_ptr, const
IndexType* __restrict then_idx,
CaseWhenColumnHolder& column_holder) {
for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) {
if constexpr (!has_else) {
diff --git a/be/src/vec/functions/function_string.cpp
b/be/src/vec/functions/function_string.cpp
index 52066e38130..5bc98c742e9 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -438,6 +438,7 @@ public:
}
};
+static constexpr int MAX_STACK_CIPHER_LEN = 1024 * 64;
struct UnHexImpl {
static constexpr auto name = "unhex";
using ReturnType = DataTypeString;
@@ -509,9 +510,16 @@ struct UnHexImpl {
StringOP::push_empty_string(i, dst_data, dst_offsets);
continue;
}
+ char dst_array[MAX_STACK_CIPHER_LEN];
+ char* dst = dst_array;
int cipher_len = srclen / 2;
- char dst[cipher_len];
+ std::unique_ptr<char[]> dst_uptr;
+ if (cipher_len > MAX_STACK_CIPHER_LEN) {
+ dst_uptr.reset(new char[cipher_len]);
+ dst = dst_uptr.get();
+ }
+
int outlen = hex_decode(source, srclen, dst);
if (outlen < 0) {
@@ -581,8 +589,15 @@ struct ToBase64Impl {
continue;
}
+ char dst_array[MAX_STACK_CIPHER_LEN];
+ char* dst = dst_array;
+
int cipher_len = (int)(4.0 * ceil((double)srclen / 3.0));
- char dst[cipher_len];
+ std::unique_ptr<char[]> dst_uptr;
+ if (cipher_len > MAX_STACK_CIPHER_LEN) {
+ dst_uptr.reset(new char[cipher_len]);
+ dst = dst_uptr.get();
+ }
int outlen = base64_encode((const unsigned char*)source, srclen,
(unsigned char*)dst);
if (outlen < 0) {
@@ -621,8 +636,15 @@ struct FromBase64Impl {
continue;
}
+ char dst_array[MAX_STACK_CIPHER_LEN];
+ char* dst = dst_array;
+
int cipher_len = srclen;
- char dst[cipher_len];
+ std::unique_ptr<char[]> dst_uptr;
+ if (cipher_len > MAX_STACK_CIPHER_LEN) {
+ dst_uptr.reset(new char[cipher_len]);
+ dst = dst_uptr.get();
+ }
int outlen = base64_decode(source, srclen, dst);
if (outlen < 0) {
diff --git a/be/src/vec/functions/multiply.cpp
b/be/src/vec/functions/multiply.cpp
index e55a8549e56..6b72cba49e6 100644
--- a/be/src/vec/functions/multiply.cpp
+++ b/be/src/vec/functions/multiply.cpp
@@ -45,7 +45,8 @@ struct MultiplyImpl {
const ColumnDecimal128::Container& b,
ColumnDecimal128::Container& c) {
size_t size = c.size();
- int8 sgn[size];
+ auto sng_uptr = std::unique_ptr<int8[]>(new int8[size]);
+ int8* sgn = sng_uptr.get();
for (int i = 0; i < size; i++) {
sgn[i] = ((DecimalV2Value(a[i]).value() > 0) &&
(DecimalV2Value(b[i]).value() > 0)) ||
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]